Skip to content

Commit a25c40d

Browse files
committed
Feedback
1 parent e3c02a6 commit a25c40d

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -780,7 +780,7 @@ struct WgToSgLoadGatherOpWithOffset
780780
ArrayRef<int64_t> wgShape = resultType.getShape();
781781

782782
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
783-
if (!layout || !layout.getSgLayout())
783+
if (!layout || !layout.isForWorkgroup())
784784
return failure();
785785

786786
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -820,9 +820,8 @@ struct WgToSgStoreScatterOpWithOffset
820820
if (!valueType)
821821
return failure();
822822

823-
ArrayRef<int64_t> wgShape = valueType.getShape();
824823
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
825-
if (!layout || !layout.getSgLayout())
824+
if (!layout || !layout.isForWorkgroup())
826825
return failure();
827826

828827
auto chunkSizeOpt = op.getChunkSize();
@@ -833,12 +832,9 @@ struct WgToSgStoreScatterOpWithOffset
833832
rewriter.create<xegpu::StoreScatterOp>(
834833
loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
835834
op.getL2HintAttr(), op.getL3HintAttr());
836-
// Update the layout_result_0 attribute to drop sg_layout and sg_data.
837-
if (auto layoutAttr =
838-
op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
839-
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
840-
op->setAttr("layout_result_0", newLayout);
841-
}
835+
// Update the layout attribute to drop sg_layout and sg_data.
836+
if (auto newLayout = layout.dropSgLayoutAndData())
837+
op->setAttr("layout", newLayout);
842838
}
843839
rewriter.eraseOp(op);
844840
return success();
@@ -1042,7 +1038,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
10421038
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
10431039
[=](xegpu::StoreScatterOp op) -> bool {
10441040
// Check if the layout attribute is present on the op (a store has no result).
1045-
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
1041+
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
10461042
if (!layout)
10471043
return true;
10481044
return isLegal(layout);

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,8 @@ gpu.module @test_distribution {
269269
gpu.func @load_gather(%src : memref<?xf16>) {
270270
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
271271
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
272-
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
272+
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
273+
// CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
273274
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
274275
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
275276
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
@@ -283,21 +284,23 @@ gpu.module @test_distribution {
283284
// CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
284285
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
285286
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
286-
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
287+
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
288+
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
287289
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
288290
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
289291
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
290-
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
292+
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
291293
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
292294
gpu.return
293295
}
294296

295-
// CHECK-LABEL: @load_with_chunk_size
297+
// CHECK-LABEL: @load_with_non_unit_chunk_size
296298
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
297-
gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
299+
gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
298300
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
299301
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
300-
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
302+
// CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
303+
// CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
301304
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
302305
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
303306
%load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}

0 commit comments

Comments
 (0)