Skip to content

Commit 8c91b39

Browse files
committed
address comments
1 parent 3544073 commit 8c91b39

File tree

3 files changed

+53
-18
lines changed

3 files changed

+53
-18
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,29 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
335335
// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
336336
// it could be either 1:N or N:1 cast. In both cases, the pattern
337337
// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
338+
// For example, the following scf::forOp
339+
// ```
340+
// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
341+
// %n = use(%arg1): vector<128x128xf16>
342+
// scf.yield %n : vector<128x128xf16>
343+
// }
344+
// ```
345+
// Could be converted to:
346+
// ```
347+
// %1 = unrealized_conversion_cast %0
348+
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
349+
// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
350+
// -> (vector<16x16xf16>, vector<16x16xf16>) {
351+
// %m = unrealized_conversion_cast %arg1, %arg2
352+
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
353+
// %n = use(%m): vector<128x128xf16>
354+
// %b = unrealized_conversion_cast %n
355+
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
356+
// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
357+
// }
358+
// %cast = unrealized_conversion_cast %for#0, %for#1
359+
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
360+
// ```
338361
// TODO: remove it when context-aware type converter is ready.
339362
struct UnrealizedConversionCastOpPattern
340363
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
@@ -353,8 +376,9 @@ struct UnrealizedConversionCastOpPattern
353376
!llvm::all_equal(ValueRange(inputs).getTypes()))
354377
return failure();
355378

356-
// Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
357-
// the input values provided by the adaptor should already be distributed,
379+
// Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
380+
// It is generated by source materialization (e.g., inits to scf forOp).
381+
// The input values provided by the adaptor should already be distributed,
358382
// and their types should correspond exactly to the result types of the
359383
// operation.
360384
if (op.getNumOperands() == 1 &&
@@ -363,11 +387,13 @@ struct UnrealizedConversionCastOpPattern
363387
return success();
364388
}
365389

366-
// Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
367-
// All input values must have the same vector type, and their shape must be
368-
// evenly divisible by the output vector's shape.
390+
// Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
391+
// It is generated by target materialization (e.g., arguments/results
392+
// of scf forOp). All input values must have the same vector type, and
393+
// their shape must be evenly divisible by the output vector's shape
394+
// (determined by the nature of the workgroup to subgroup distribution).
369395
// TODO: it is not safe to do such forward, since such N:1 cast could be
370-
// from others
396+
// from others.
371397
if (op.getNumResults() == 1 &&
372398
computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
373399
rewriter.replaceOpWithMultiple(op, {inputs});
@@ -510,7 +536,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
510536
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
511537
op->removeAttr(name);
512538
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
513-
op->setAttr(name, layout.dropInstData());
539+
op->setAttr(name, layout.dropSgLayoutAndData());
514540
}
515541
}
516542
});

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ gpu.module @test_round_robin_assignment {
119119
xegpu.store_nd %3, %arg3 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
120120
%4 = xegpu.update_nd_offset %arg3, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
121121
%5 = xegpu.update_nd_offset %arg4, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
122+
// CHECK-LABEL: scf.yield
123+
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
122124
scf.yield %4, %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
123125
}
124126
gpu.return

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,22 +186,29 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
186186
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
187187
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
188188

189-
//CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
190-
//CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
191-
//CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
192-
//CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
193-
//CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
194-
//CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
195-
//CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
196-
%6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
189+
// CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
190+
// CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
191+
// CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
192+
// CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
193+
// CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
194+
// CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
195+
// CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
196+
// CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
197+
// CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
198+
%6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
199+
-> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
200+
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
197201
%8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
198202
%9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
199-
%10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
203+
%10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
204+
: vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
200205
%11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
201206
%12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
202-
scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
207+
scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
208+
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
203209
}
204-
%7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
210+
%7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
211+
-> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
205212
xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
206213
gpu.return
207214
}

0 commit comments

Comments
 (0)