[MLIR][XeGPU] Distribute create_nd_desc op without offset from Wg to Sg #152351
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes: This PR adds a pattern to distribute the create_nd_desc op without offsets from workgroup (Wg) IR to subgroup (Sg) IR.

Patch is 34.62 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/152351.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..b2eaa436ac76e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -161,6 +161,18 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
+ // Ensure that the op has explicit offsets specified (either dynamic or
+ // constant).
+ int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+ if (offsetSize == 0) {
+ auto constOffsetsAttr = op.getConstOffsetsAttr();
+ if (!constOffsetsAttr || constOffsetsAttr.empty() ||
+ llvm::all_of(constOffsetsAttr.asArrayRef(),
+ [](auto v) { return v == 0; }))
+ return failure();
+ }
+
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
@@ -250,6 +262,52 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
}
};
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+ if (offsetSize != 0 || (op.getConstOffsetsAttr() &&
+ llvm::any_of(op.getConstOffsetsAttr().asArrayRef(),
+ [](auto v) { return v != 0; })))
+ return failure();
+
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout)
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+ xegpu::TensorDescType newTdescTy =
+ xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
+
+ SmallVector<Value> newCreateNdOps;
+ for (int i = 0; i < count; ++i) {
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), ValueRange(), ValueRange(),
+ ValueRange(), DenseI64ArrayAttr(), DenseI64ArrayAttr(),
+ DenseI64ArrayAttr());
+ newCreateNdOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+ return success();
+ }
+};
+
/// This pattern transforms the LoadNdOp to load subgroup data.
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -654,11 +712,12 @@ struct UnrealizedConversionCastOpPattern
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ patterns
+ .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+ WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a4857d1253..f1b68c0decdc2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -7,7 +7,20 @@ gpu.module @test_round_robin_assignment {
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: create_nd_tdesc_no_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.create_nd_tdesc
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
@@ -15,7 +28,8 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: load_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-COUNT-4: xegpu.load_nd %{{.*}}
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -30,7 +44,8 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: store_nd
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @store_nd(%src: memref<256x128xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
// CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -46,7 +61,8 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: update_nd
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @update_nd(%src: memref<256x128xf32>){
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
@@ -69,12 +85,13 @@ gpu.module @test_round_robin_assignment {
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+ %cst0 = arith.constant 0 : index
+ %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf16>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x256xf16>
-> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -91,7 +108,8 @@ gpu.module @test_round_robin_assignment {
// CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
// CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -101,7 +119,8 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
gpu.func @broadcast(%src: memref<128x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<128x1xf32>
-> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
@@ -122,8 +141,8 @@ gpu.module @test_round_robin_assignment {
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c1024 = arith.constant 1024 : index
- %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
- %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.for
// CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
%2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
@@ -143,9 +162,10 @@ gpu.module @test_round_robin_assignment {
%c1_i32 = arith.constant 1 : i32
%c10_i32 = arith.constant 10 : i32
%c0_i32 = arith.constant 0 : i32
- %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %cst0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
- %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
//CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
%3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
%4 = arith.cmpi slt, %arg3, %c10_i32 : i32
@@ -164,10 +184,11 @@ gpu.module @test_round_robin_assignment {
}
gpu.func @scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+ %cst0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%0 = gpu.subgroup_id : index
- %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
- %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %1 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%3 = arith.cmpi eq, %0, %c10 : index
// CHECK-LABEL: scf.if
// CHECK-SAME: (vector<16xf32>, vector<16xf32>)
@@ -189,20 +210,20 @@ gpu.module @test_round_robin_assignment {
gpu.func @scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c10 = arith.constant 10 : index
%id = gpu.subgroup_id : index
-
- %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %cst0 = arith.constant 0 : index
+ %t = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
%0 = arith.cmpi eq, %id, %c10 : index
// CHECK-LABEL: scf.if
// CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
%1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
- %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %2 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
} else {
- %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+ %3 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -212,7 +233,8 @@ gpu.module @test_round_robin_assignment {
}
gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
- %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
+ %cst0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%cst0, %cst0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
//CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
//CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b00372bc193..a0352169f2380 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -6,32 +6,42 @@ gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
- // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
- // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
- // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C256:.*]] = arith.constant 256 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
- // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
- // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
- // CHECK: %[[C128:.*]] = arith.constant 128 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
- // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
- -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+ // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+ // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
+ // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+ // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: gpu.return
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: create_nd_tdesc_no_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+ // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref...
[truncated]
LGTM
What makes the "no offset" case special?
"initialization."); | ||
Type srcTy = source.getType(); | ||
assert((isa<IntegerType, MemRefType>(srcTy)) && | ||
"Source has to be either int or memref."); |
Is there (or should there be) a test for each of the two cases?
added
The op without offsets doesn't need to compute multiple offsets (in the round-robin distribution case) for each distributed op.
Thanks! I think this can be added to the PR description for better context.
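To illustrate that point, here is an editor's sketch of the offset arithmetic the with-offset path has to emit, condensed from the CHECK lines of the @create_nd_tdesc test in xegpu-wg-to-sg.mlir (the value names and the affine-map bodies are assumptions for readability, not taken verbatim from the test):

  %sgid = gpu.subgroup_id : index
  // Assumed map bodies for sg_layout = [8, 4]: row = sgid / 4, col = sgid % 4.
  %row  = affine.apply #map()[%sgid]
  %col  = affine.apply #map1()[%sgid]
  // Scale by sg_data = [32, 32] and wrap around the 256x128 workgroup shape.
  %y    = index.mul %row, %c32
  %x    = index.mul %col, %c32
  %off0 = index.remu %y, %c256
  %off1 = index.remu %x, %c128
  %tdesc = xegpu.create_nd_tdesc %src[%off0, %off1] : memref<256x128xf32>
      -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>

With no offsets on the workgroup op, none of this arithmetic is needed at descriptor-creation time; the new pattern only emits `count` subgroup-sized descriptors.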
why create separate files and not put tests into existing ones?
because when we transition to this version we can just delete the old file.
added to the description
This PR adds a pattern to distribute the create_nd_desc op without offsets from workgroup (Wg) IR to subgroup (Sg) IR.
The round-robin distribution logic (which involves offset calculation) will now happen in the load/store/prefetch nd ops instead of create_nd.
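Roughly, based on the @create_nd_tdesc_no_offset round-robin test added in this patch (xegpu-wg-to-sg-rr.mlir), the new WgToSgCreateNdOpNoOffset pattern rewrites a workgroup-level descriptor that carries no offsets, e.g.

  %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>

into four subgroup-sized descriptors per subgroup (8x4 subgroups tiling the 256x128 shape in 16x16 pieces), with no per-subgroup offset computation:

  // Repeated 4 times (round-robin); mirrors the CHECK-COUNT-4 lines in the test.
  %sg = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
      -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>

The per-subgroup offsets are then expected to be supplied when the consuming load/store/prefetch nd ops are distributed, which is outside the scope of this diff.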