Skip to content

Commit a1b35a4

Browse files
committed
Add more tests
1 parent 35bdf57 commit a1b35a4

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
296296
}
297297
};
298298

299-
// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
300-
// data.
301-
// Use a template parameter for the adaptor type
302299
template <typename OpTy, typename AdaptorTy, typename CreateFn>
303300
LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
304301
ConversionPatternRewriter &rewriter,
@@ -359,12 +356,12 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
359356
for (auto v : op.getOffsets())
360357
oldOffsets.push_back(v);
361358

362-
// Delegate to the operation-specific creation function
363359
return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
364360
rewriter, op);
365361
}
366362

367-
// Usage for LoadNdOp
363+
// This pattern transforms the LoadNdOp with explicit offsets to load
364+
// subgroup data.
368365
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
369366
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
370367
LogicalResult matchAndRewrite(

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,33 @@ gpu.module @test_distribution {
5353
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
5454
gpu.return
5555
}
56+
57+
// CHECK-LABEL: @dpas
58+
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
59+
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
60+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
61+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
62+
// CHECK-NOT: xegpu.create_nd_tdesc
63+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
64+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
65+
// CHECK-NOT: xegpu.create_nd_tdesc
66+
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
67+
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
68+
// CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
69+
// CHECK-NOT: xegpu.dpas
70+
%tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
71+
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
72+
%load_a = xegpu.load_nd %tdesc_a[0, 0]
73+
: !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
74+
-> vector<256x128xf16>
75+
%tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
76+
-> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
77+
%load_b = xegpu.load_nd %tdesc_b[0, 0]
78+
: !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
79+
-> vector<128x256xf16>
80+
%dpas = xegpu.dpas %load_a, %load_b
81+
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
82+
: vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
83+
gpu.return
84+
}
5685
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,4 +242,70 @@ gpu.module @test_distribution {
242242
xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
243243
gpu.return
244244
}
245+
246+
// CHECK-LABEL: @subgroup_id_range
247+
gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
248+
%sg_id = gpu.subgroup_id : index
249+
%c0 = arith.constant 0 : index
250+
%c1 = arith.constant 1 : index
251+
%c2 = arith.constant 2 : index
252+
%c31 = arith.constant 31 : index
253+
%c3 = arith.constant 3 : index
254+
%cond1 = arith.cmpi sge, %sg_id, %c0 : index
255+
%cond2 = arith.cmpi slt, %sg_id, %c1 : index
256+
%cond = arith.andi %cond1, %cond2 : i1
257+
scf.if %cond {
258+
// CHECK-NOT: index.sub
259+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
260+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
261+
%load = xegpu.load_nd %tdesc[0, 0]
262+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
263+
-> vector<256x128xf32>
264+
} {sg_id_range = #xegpu.range<[0, 32]>}
265+
%cond3 = arith.cmpi sge, %sg_id, %c2 : index
266+
%cond4 = arith.cmpi slt, %sg_id, %c31 : index
267+
%cond5 = arith.andi %cond3, %cond4 : i1
268+
scf.if %cond5 {
269+
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
270+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
271+
// CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
272+
%tdesc = xegpu.create_nd_tdesc %src2 : memref<128x64xf32>
273+
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
274+
%load = xegpu.load_nd %tdesc[0, 0]
275+
: !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
276+
-> vector<128x64xf32>
277+
%exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
278+
} {sg_id_range = #xegpu.range<[2, 18]>}
279+
gpu.return
280+
}
281+
282+
// CHECK-LABEL: @subgroup_id_range_nested_if
283+
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
284+
%sg_id = gpu.subgroup_id : index
285+
%c1 = arith.constant 1 : i1
286+
%c3 = arith.constant 3 : index
287+
%c32 = arith.constant 32 : index
288+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
289+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
290+
%load = xegpu.load_nd %tdesc[0, 0]
291+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
292+
-> vector<256x128xf32>
293+
%cond1 = arith.cmpi sge, %sg_id, %c3 : index
294+
%cond2 = arith.cmpi slt, %sg_id, %c32 : index
295+
%cond = arith.andi %cond1, %cond2 : i1
296+
scf.if %c1 {
297+
scf.if %cond {
298+
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
299+
// CHECK: %[[C3:.*]] = arith.constant 3 : index
300+
// CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
301+
%td = xegpu.create_nd_tdesc %src1 : memref<128x64xf32>
302+
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
303+
%ld = xegpu.load_nd %td[0, 0]
304+
: !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
305+
-> vector<128x64xf32>
306+
%exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
307+
}
308+
} {sg_id_range = #xegpu.range<[3, 19]>}
309+
gpu.return
310+
}
245311
}

0 commit comments

Comments
 (0)