[MLIR][XeGPU] Support order attribute and add pattern for vector.transpose in WgToSg Pass #165307
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Patch is 67.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165307.diff

8 Files Affected:
 diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e909548fe0b..dfd4093905875 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -270,26 +270,76 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 FailureOr<SmallVector<Value>>
 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
-  // delinearizeSubgroupId is only available for
-  // workgroup-level layout attribute
   if (!isForWorkgroup())
     return failure();
 
-  // TODO: handle order attribute
-  auto hasDefaultOrder = [&]() {
-    DenseI32ArrayAttr order = getOrder();
-    return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
-                         llvm::reverse(order.asArrayRef())));
-  };
-  if (!hasDefaultOrder())
-    return mlir::emitError(loc, "order attribute is currently not supported.");
+  SmallVector<int64_t> sgLayoutInt = getEffectiveSgLayoutAsInt();
+  DenseI32ArrayAttr orderAttr = getOrder();
 
-  auto dims =
-      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
-        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-      });
+  // Handle order attribute
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::to_vector(
+        llvm::map_range(orderAttr.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+  } else {
+    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+  }
 
-  return affine::delinearizeIndex(builder, loc, linearId, dims);
+  if (order.size() != sgLayoutInt.size()) {
+    return failure();
+  }
+
+  SmallVector<Value> result(sgLayoutInt.size());
+  Value remaining = linearId;
+
+  /// Process dimensions in the order they appear in the order array
+  /// The first dimension in order is the fastest-changing
+  ///
+  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+  /// 
+  /// Initial: remaining=22, result=[?,?,?]
+  /// 
+  /// i=0 (process columns, dimIdx=2, dimSize=4):
+  ///   result[2] = 22 % 4 = 2  (column coordinate)
+  ///   remaining = 22 / 4 = 5  (5 complete groups of 4 columns processed)
+  /// 
+  /// i=1 (process rows, dimIdx=1, dimSize=4):
+  ///   result[1] = 5 % 4 = 1   (row coordinate) 
+  ///   remaining = 5 / 4 = 1   (1 complete group of 4 rows processed)
+  /// 
+  /// i=2 (process layers, dimIdx=0, dimSize=2):
+  ///   result[0] = 1 % 2 = 1   (layer coordinate)
+  ///   (no remaining update - last iteration)
+  /// 
+  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    int64_t dimSize = sgLayoutInt[dimIdx];
+
+    Value dimSizeVal =
+        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    /// Extract the coordinate for this dimension using modulo operation
+    /// This gives us "how far within this dimension" we are
+    /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within this
+    /// dimension)
+    result[dimIdx] =
+        builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by removing what we've already
+    /// processed. Division tells us "how many complete groups of this dimension
+    /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
+    /// completed 5 groups of 4) Skip this for the last iteration since there's
+    /// no next dimension to process
+    if (i < order.size() - 1) {
+      remaining =
+          builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+    }
+  }
+  return result;
 }
 
 /// Implements DistributeLayoutAttr::getOffsets to generate
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..88d3ed743628e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1217,6 +1217,93 @@ struct WgToSgMultiDimReductionOp
   }
 };
 
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+    : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getVector());
+    if (!sourceLayout || !sourceLayout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sourceSgLayout =
+        sourceLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
+    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+    DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+    if (!sourceOrder || !resultOrder) {
+      return rewriter.notifyMatchFailure(
+          op, "Both source and result must have order attributes");
+    }
+
+    SmallVector<int64_t> sourceOrderVec = llvm::to_vector(
+        llvm::map_range(sourceOrder.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+    SmallVector<int64_t> resultOrderVec = llvm::to_vector(
+        llvm::map_range(resultOrder.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    size_t expectedSize = permutation.size();
+    if (sourceSgLayout.size() != expectedSize ||
+        sourceSgData.size() != expectedSize ||
+        resultSgLayout.size() != expectedSize ||
+        resultSgData.size() != expectedSize ||
+        sourceOrderVec.size() != expectedSize ||
+        resultOrderVec.size() != expectedSize) {
+      return rewriter.notifyMatchFailure(
+          op, "All layouts and permutation must have the same rank");
+    }
+
+    // Check that sgLayout, sgData & order are properly transposed for operand
+    // and result
+    for (size_t i = 0; i < permutation.size(); ++i) {
+      int64_t srcDim = permutation[i];
+      if (resultSgLayout[i] != sourceSgLayout[srcDim] ||
+          resultSgData[i] != sourceSgData[srcDim] ||
+          resultOrderVec[i] != sourceOrderVec[srcDim]) {
+        return rewriter.notifyMatchFailure(
+            op, "Result layout is not a valid transpose of source layout "
+                "according to permutation");
+      }
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newTransposeOps;
+    for (auto src : adaptor.getVector()) {
+      auto newTranspose = vector::TransposeOp::create(
+          rewriter, op.getLoc(), newResultType, src, permutation);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newTransposeOps.push_back(newTranspose.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1231,7 +1318,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp>(patterns.getContext());
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1358,7 +1446,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1377,16 +1467,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index b73bc69393dab..02c5f71d5c83d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -1,33 +1,32 @@
 // RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
 
-//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test {
   gpu.func @slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return %step : vector<128xindex>
   }
 
   gpu.func @nested_slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+    // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
     %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
     gpu.return %0 : vector<128xindex>
   }
 
-}
\ No newline at end of file
+}
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 09df1e4da43e2..9580769d37313 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: arith.negf
     %negf = arith.negf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: math.powf
     %powf = math.powf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d2d250cbe0f66..01134d8eaabec 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,14 +1,10 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-#map = affine_map<()[s0] -> (s0 floordiv 4)>
-#map1 = affine_map<()[s0] -> (s0 mod 4)>
-
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
-      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -16,22 +12,23 @@ gpu.module @test_round_robin_assignment {
     }
 
   // CHECK-LABEL: create_nd_tdesc_with_shared_data
-  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
-    //CHECK: [[C16:%.+]] = arith.constant 16 : index
-    //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
-    //CHECK: [[C64:%.+]] = arith.constant 64 : index
-    //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
-    //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
-    //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
-    //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
+    // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
+    // CHECK: %[[C64_1:.*]] = arith.constant 64 : index
+    // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
     gpu.return
@@ -42,9 +39,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
-      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+      // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
       // CHECK-NOT: xegpu.load_nd
       %load =  xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,9 +52,8 @@ gpu.module @test_round_robin_assignment {
   gpu.func @store_nd(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
-      // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-NOT : xegpu.store_nd
+      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.store_nd
       %load = xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
         -> vector<256x128xf32>
@@ -73,8 +67,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @update_nd(%src: memref<256x128xf32>){
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<2...
[truncated]
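For readers skimming the truncated diff: the new `delinearizeSubgroupId` builds the decomposition below out of `index.remu`/`index.divu` ops. The following is a minimal standalone sketch of the same arithmetic in plain C++ (illustrative names, not part of the patch), assuming `order[0]` names the fastest-varying dimension of `sg_layout`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the order-aware delinearization that the new
// LayoutAttr::delinearizeSubgroupId emits as index ops: order[0] is the
// fastest-varying dimension of sgLayout, order.back() the slowest.
std::vector<int64_t> delinearize(int64_t linearId,
                                 const std::vector<int64_t> &sgLayout,
                                 const std::vector<int64_t> &order) {
  assert(order.size() == sgLayout.size() && "rank mismatch");
  std::vector<int64_t> result(sgLayout.size());
  int64_t remaining = linearId;
  for (std::size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];            // dimension to peel off next
    int64_t dimSize = sgLayout[dimIdx];   // its extent in sg_layout
    result[dimIdx] = remaining % dimSize; // coordinate within that dimension
    remaining /= dimSize;                 // complete groups consumed so far
  }
  return result;
}

int main() {
  // Matches the walkthrough in the patch comments:
  // linearId=22, sgLayout=[2,4,4], order=[2,1,0] -> [1,1,2].
  std::vector<int64_t> ids = delinearize(22, {2, 4, 4}, {2, 1, 0});
  assert(ids[0] == 1 && ids[1] == 1 && ids[2] == 2);
  return 0;
}
```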
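Likewise, a hedged sketch of the compatibility check that `WgToSgVectorTransposeOp` performs: the result layout must equal the source layout permuted by the transpose permutation, for `sg_layout`, `sg_data`, and `order` alike. The concrete shapes in `main` are assumed for illustration only:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the layout check in WgToSgVectorTransposeOp: for each
// result dimension i, the result layout entry must come from source
// dimension perm[i].
bool isTransposedLayout(const std::vector<int64_t> &srcSgLayout,
                        const std::vector<int64_t> &srcSgData,
                        const std::vector<int64_t> &srcOrder,
                        const std::vector<int64_t> &resSgLayout,
                        const std::vector<int64_t> &resSgData,
                        const std::vector<int64_t> &resOrder,
                        const std::vector<int64_t> &perm) {
  std::size_t rank = perm.size();
  if (srcSgLayout.size() != rank || srcSgData.size() != rank ||
      srcOrder.size() != rank || resSgLayout.size() != rank ||
      resSgData.size() != rank || resOrder.size() != rank)
    return false; // all layouts and the permutation must share one rank
  for (std::size_t i = 0; i < rank; ++i) {
    int64_t srcDim = perm[i];
    if (resSgLayout[i] != srcSgLayout[srcDim] ||
        resSgData[i] != srcSgData[srcDim] ||
        resOrder[i] != srcOrder[srcDim])
      return false; // result layout is not the permuted source layout
  }
  return true;
}

int main() {
  // Hypothetical 2D transpose with perm = [1, 0]: source sg_layout = [4, 8],
  // sg_data = [8, 8], order = [1, 0] must become result sg_layout = [8, 4],
  // sg_data = [8, 8], order = [0, 1].
  bool ok = isTransposedLayout({4, 8}, {8, 8}, {1, 0},
                               {8, 4}, {8, 8}, {0, 1}, {1, 0});
  return ok ? 0 : 1;
}
```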
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM
LGTM % comments.
/// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
///
/// Initial: remaining=22, result=[?,?,?]
///
nit: consider adding a comment: dimIdx = order[i], dimSize = sgLayout[dimIdx]
LGTM
This PR does the following: