
Commit fe75a08

Add create_nd_desc pattern without offset
1 parent 1e815ce commit fe75a08
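
This commit adds a second pattern, WgToSgCreateNdOpNoOffset, that distributes xegpu.create_nd_tdesc ops carrying no offsets from the workgroup level to the subgroup level. A minimal before/after sketch of the intended effect, assembled from the round-robin test below (the shapes and layouts are taken from the test; the printed form of the distributed ops follows its CHECK lines, not verified output):

    // Before: one workgroup-level descriptor, created without offsets.
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>

    // After workgroup-to-subgroup distribution (round-robin): each subgroup
    // gets four 16x16 descriptors, with sg_layout/sg_data dropped:
    //   !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>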

3 files changed: +175, -74 lines


mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 64 additions & 5 deletions
@@ -161,6 +161,18 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Ensure that the op has explicit offsets specified (either dynamic or
+    // constant).
+    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+    if (offsetSize == 0) {
+      auto constOffsetsAttr = op.getConstOffsetsAttr();
+      if (!constOffsetsAttr || constOffsetsAttr.empty() ||
+          llvm::all_of(constOffsetsAttr.asArrayRef(),
+                       [](auto v) { return v == 0; }))
+        return failure();
+    }
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
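
The guard's effect, as I read the diff: WgToSgCreateNdOp now only fires when the descriptor has a dynamic offset or a nonzero constant offset; the absent-offset and all-zero cases fall through to the new WgToSgCreateNdOpNoOffset pattern below. This is presumably also why the tests replace literal [0, 0] offsets with SSA values, so this pattern stays exercised. A sketch of the two cases (types elided, assuming literal offsets fold into const_offsets):

    // Rejected by the guard above (folds to const_offsets = [0, 0]);
    // handled by WgToSgCreateNdOpNoOffset instead.
    %t0 = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> ...

    // Still handled by WgToSgCreateNdOp: the offsets are dynamic SSA
    // values, even though their runtime value happens to be zero.
    %c0 = arith.constant 0 : index
    %t1 = xegpu.create_nd_tdesc %src[%c0, %c0] : memref<256x128xf32> -> ...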
@@ -250,6 +262,52 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+    if (offsetSize != 0 || (op.getConstOffsetsAttr() &&
+                            llvm::any_of(op.getConstOffsetsAttr().asArrayRef(),
+                                         [](auto v) { return v != 0; })))
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+
+    SmallVector<Value> newCreateNdOps;
+    for (int i = 0; i < count; ++i) {
+      auto newOp = xegpu::CreateNdDescOp::create(
+          rewriter, loc, newTdescTy, op.getSource(), ValueRange(), ValueRange(),
+          ValueRange(), DenseI64ArrayAttr(), DenseI64ArrayAttr(),
+          DenseI64ArrayAttr());
+      newCreateNdOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -654,11 +712,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
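
For reference, the tests in the next file exercise these patterns through the workgroup-to-subgroup distribution pass; a typical driver would be a RUN line along these lines (the exact flag name is an assumption based on the pass name, as the RUN line is not part of this diff):

    // RUN: mlir-opt --xegpu-wg-to-sg-distribute %s | FileCheck %s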

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 41 additions & 19 deletions
@@ -7,15 +7,29 @@ gpu.module @test_round_robin_assignment {
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.create_nd_tdesc
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -30,7 +44,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-LABEL: store_nd
     // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
     gpu.func @store_nd(%src: memref<256x128xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
     // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -46,7 +61,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-LABEL: update_nd
     // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
     gpu.func @update_nd(%src: memref<256x128xf32>){
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -69,12 +85,13 @@ gpu.module @test_round_robin_assignment {
     // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+    %cst0 = arith.constant 0 : index
+    %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<256x128xf16>
       -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a = xegpu.load_nd %tdesc_a
       : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf16>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x256xf16>
       -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
     %load_b = xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -91,7 +108,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
     // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -101,7 +119,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-LABEL: broadcast
     // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
     gpu.func @broadcast(%src: memref<128x1xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<128x1xf32>
       -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load = xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
@@ -122,8 +141,8 @@ gpu.module @test_round_robin_assignment {
     %c0 = arith.constant 0 : index
     %c256 = arith.constant 256 : index
     %c1024 = arith.constant 1024 : index
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     // CHECK-LABEL: scf.for
     // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
@@ -143,9 +162,10 @@ gpu.module @test_round_robin_assignment {
     %c1_i32 = arith.constant 1 : i32
     %c10_i32 = arith.constant 10 : i32
     %c0_i32 = arith.constant 0 : i32
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
     %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
       %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
@@ -164,10 +184,11 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %cst0 = arith.constant 0 : index
     %c10 = arith.constant 10 : index
     %0 = gpu.subgroup_id : index
-    %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %3 = arith.cmpi eq, %0, %c10 : index
     // CHECK-LABEL: scf.if
     // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
@@ -189,20 +210,20 @@ gpu.module @test_round_robin_assignment {
   gpu.func @scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c10 = arith.constant 10 : index
     %id = gpu.subgroup_id : index
-
-    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %t = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
 
     %0 = arith.cmpi eq, %id, %c10 : index
     // CHECK-LABEL: scf.if
     // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
-      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %2 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     } else {
-      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %3 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -212,7 +233,8 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
-    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0, %cst0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
     //CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
     //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
     %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
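
A note on where the CHECK-COUNT-4 values in this file come from: under round-robin assignment the workgroup shape is tiled by sg_data, and the tiles are dealt out across the sg_layout grid, so each subgroup ends up with tiles/subgroups descriptors. A worked count for the 256x128 cases above (my arithmetic, not part of the diff):

    // tiles      = (256 / 16) x (128 / 16) = 16 x 8 = 128   (wgShape / sg_data)
    // subgroups  = 8 x 4                   = 32             (sg_layout)
    // per sg     = 128 / 32                = 4              -> CHECK-COUNT-4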
