Skip to content

[MLIR][XeGPU] Distribute create_nd_desc op without offset from Wg to Sg #152351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 12, 2025

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Aug 6, 2025

This PR adds pattern to distribute the create_nd_desc op without offsets from workgroup (Wg) IR to subgroup (Sg) IR.
The round robin distribution logic (involves offset calculation) now will happen in load/store/prefetch nd ops instead of create_nd.

@nbpatel nbpatel requested a review from chencha3 August 6, 2025 22:53
@nbpatel nbpatel marked this pull request as ready for review August 6, 2025 22:53
@nbpatel nbpatel requested a review from Jianhui-Li August 6, 2025 22:54
@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

This PR adds pattern to distribute the create_nd_desc op without offsets from workgroup (Wg) IR to subgroup (Sg) IR.


Patch is 34.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152351.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+64-5)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+41-19)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+70-50)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..b2eaa436ac76e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -161,6 +161,18 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Ensure that the op has explicit offsets specified (either dynamic or
+    // constant).
+    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+    if (offsetSize == 0) {
+      auto constOffsetsAttr = op.getConstOffsetsAttr();
+      if (!constOffsetsAttr || constOffsetsAttr.empty() ||
+          llvm::all_of(constOffsetsAttr.asArrayRef(),
+                       [](auto v) { return v == 0; }))
+        return failure();
+    }
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
@@ -250,6 +262,52 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+    if (offsetSize != 0 || (op.getConstOffsetsAttr() &&
+                            llvm::any_of(op.getConstOffsetsAttr().asArrayRef(),
+                                         [](auto v) { return v != 0; })))
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+
+    SmallVector<Value> newCreateNdOps;
+    for (int i = 0; i < count; ++i) {
+      auto newOp = xegpu::CreateNdDescOp::create(
+          rewriter, loc, newTdescTy, op.getSource(), ValueRange(), ValueRange(),
+          ValueRange(), DenseI64ArrayAttr(), DenseI64ArrayAttr(),
+          DenseI64ArrayAttr());
+      newCreateNdOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -654,11 +712,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a4857d1253..f1b68c0decdc2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -7,7 +7,20 @@ gpu.module @test_round_robin_assignment {
       // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
       // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+    }
+
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref<256x128xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
     }
@@ -15,7 +28,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
       // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -30,7 +44,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd(%src: memref<256x128xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
       // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -46,7 +61,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: update_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @update_nd(%src: memref<256x128xf32>){
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
@@ -69,12 +85,13 @@ gpu.module @test_round_robin_assignment {
     // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+    %cst0 = arith.constant 0 : index
+    %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<256x128xf16>
       -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
       : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf16>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x256xf16>
       -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
     %load_b =  xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -91,7 +108,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
     // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -101,7 +119,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: broadcast
   // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
   gpu.func @broadcast(%src: memref<128x1xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<128x1xf32>
       -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
@@ -122,8 +141,8 @@ gpu.module @test_round_robin_assignment {
     %c0 = arith.constant 0 : index
     %c256 = arith.constant 256 : index
     %c1024 = arith.constant 1024 : index
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     // CHECK-LABEL: scf.for
     // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
@@ -143,9 +162,10 @@ gpu.module @test_round_robin_assignment {
     %c1_i32 = arith.constant 1 : i32
     %c10_i32 = arith.constant 10 : i32
     %c0_i32 = arith.constant 0 : i32
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
     %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
       %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
@@ -164,10 +184,11 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %cst0 = arith.constant 0 : index
     %c10 = arith.constant 10 : index
     %0 = gpu.subgroup_id : index
-    %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %3 = arith.cmpi eq, %0, %c10 : index
     // CHECK-LABEL: scf.if
     //  CHECK-SAME: (vector<16xf32>, vector<16xf32>)
@@ -189,20 +210,20 @@ gpu.module @test_round_robin_assignment {
   gpu.func @scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c10 = arith.constant 10 : index
     %id = gpu.subgroup_id : index
-
-    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %t = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
 
     %0 = arith.cmpi eq, %id, %c10 : index
     // CHECK-LABEL: scf.if
     //  CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
-      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %2 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     } else {
-      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %3 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -212,7 +233,8 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
-    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0, %cst0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
     //CHECK-2: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
     //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b00372bc193..a0352169f2380 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -6,32 +6,42 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-  // CHECK: %[[SGID:.*]] = gpu.subgroup_id
-  // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
-  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
-  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
-  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
-  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[C256:.*]] = arith.constant 256 : index
-  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
-  // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
-  // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
-  // CHECK: %[[C128:.*]] = arith.constant 128 : index
-  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
-  // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-  gpu.return
+    // CHECK: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+    // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+    // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+    // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
+    // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+    // CHECK: %[[C256:.*]] = arith.constant 256 : index
+    // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+    // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
+    // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+    // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0]]
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: gpu.return
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref...
[truncated]

Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Garra1980
Copy link

what makes "no offset" case special?

"initialization.");
Type srcTy = source.getType();
assert((isa<IntegerType, MemRefType>(srcTy)) &&
"Source has to be either int or memref.");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is (or should be) there a test for each of 2 cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

@Jianhui-Li
Copy link
Contributor

what makes "no offset" case special?

The op without offset doesn't need to compute multiple offsets (in case of round-robin distribution) for each distributed op.

@Garra1980
Copy link

what makes "no offset" case special?

The op without offset doesn't need to compute multiple offsets (in case of round-robin distribution) for each distributed op.

Thanks! I think this can be added to PR description for better context

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why create separate files and not put tests into existing ones?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because when we transition to this version we can just delete the old file.

@nbpatel
Copy link
Contributor Author

nbpatel commented Aug 11, 2025

what makes "no offset" case special?

The op without offset doesn't need to compute multiple offsets (in case of round-robin distribution) for each distributed op.

Thanks! I think this can be added to PR description for better context

added to the description

@nbpatel nbpatel merged commit 88ff0f9 into llvm:main Aug 12, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants