Skip to content

Commit 51b2537

Browse files
dchigarev authored and github-actions[bot] committed
Automerge: [mlir][XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops (#167850)
The PR changes the layout attribute type for `xegpu::LoadGatherOp/StoreScatterOp` from `LayoutAttr` to `DistributeLayoutAttr` to also support `xegpu.slice` layouts. Initially we [wanted to restrict slice layouts](llvm/llvm-project#163414 (comment)) from the attribute, but now it turns out there are actually valid use cases for that: ```mlir gpu.func @distribute_load_slice_attr() { %2 = memref.alloca() {alignment = 1024} : memref<4096xf32> %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1> %3 = xegpu.load %2[%offset], %mask <{chunk_size = 1, layout = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>>} { layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32> %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32> gpu.return } ``` Signed-off-by: dchigarev <[email protected]>
2 parents 4b85697 + cd5d5b3 commit 51b2537

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
844844
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
845845
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
846846
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
847-
OptionalAttr<XeGPU_LayoutAttr>:$layout);
847+
OptionalAttr<DistributeLayoutAttr>:$layout);
848848
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
849849

850850
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -903,7 +903,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
903903
"xegpu::CachePolicyAttr": $l1_hint,
904904
"xegpu::CachePolicyAttr": $l2_hint,
905905
"xegpu::CachePolicyAttr": $l3_hint,
906-
"xegpu::LayoutAttr": $layout)>
906+
"xegpu::DistributeLayoutAttr": $layout)>
907907
];
908908

909909
let hasVerifier = 1;
@@ -988,7 +988,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
988988
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
989989
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
990990
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
991-
OptionalAttr<XeGPU_LayoutAttr>:$layout);
991+
OptionalAttr<DistributeLayoutAttr>:$layout);
992992

993993
let extraClassDeclaration = extraBaseClassDeclaration#[{
994994
Type getDestType() {
@@ -1046,7 +1046,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
10461046
"xegpu::CachePolicyAttr": $l1_hint,
10471047
"xegpu::CachePolicyAttr": $l2_hint,
10481048
"xegpu::CachePolicyAttr": $l3_hint,
1049-
"xegpu::LayoutAttr": $layout)>
1049+
"xegpu::DistributeLayoutAttr": $layout)>
10501050
];
10511051

10521052
let hasVerifier = 1;

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
901901
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
902902
xegpu::CachePolicyAttr l2_hint,
903903
xegpu::CachePolicyAttr l3_hint,
904-
xegpu::LayoutAttr layout) {
904+
DistributeLayoutAttr layout) {
905905
auto loc = source.getLoc();
906906
int64_t size = static_cast<int64_t>(offsets.size());
907907
auto type = VectorType::get(size, builder.getIndexType());
@@ -985,7 +985,7 @@ void StoreScatterOp::build(
985985
OpBuilder &builder, OperationState &state, Value value, Value dest,
986986
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
987987
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
988-
xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
988+
xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
989989
auto loc = dest.getLoc();
990990
int64_t size = static_cast<int64_t>(offsets.size());
991991
auto type = VectorType::get(size, builder.getIndexType());

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ struct UnrollLoadGatherOpWithOffset
678678
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
679679
}
680680

681-
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
681+
auto layout = op.getLayoutAttr();
682682
if (layout)
683683
layout = layout.dropInstData();
684684

@@ -778,7 +778,7 @@ struct UnrollStoreScatterOpWithOffsets
778778
SmallVector<Value> convertedValues =
779779
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
780780

781-
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
781+
auto layout = op.getLayoutAttr();
782782
if (layout)
783783
layout = layout.dropInstData();
784784

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset
889889
return failure();
890890
ArrayRef<int64_t> wgShape = resultType.getShape();
891891

892-
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
893-
xegpu::getDistributeLayoutAttr(op.getResult()));
892+
xegpu::DistributeLayoutAttr layout =
893+
xegpu::getDistributeLayoutAttr(op.getResult());
894894
if (!layout || !layout.isForWorkgroup())
895895
return failure();
896896

@@ -913,10 +913,12 @@ struct WgToSgLoadGatherOpWithOffset
913913
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
914914
for (auto [offsets, mask] :
915915
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
916+
auto newLayout = layout.dropSgLayoutAndData();
916917
auto newLoadOp = xegpu::LoadGatherOp::create(
917918
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
918919
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
919-
layout.dropSgLayoutAndData());
920+
newLayout);
921+
xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
920922
newLoadOps.push_back(newLoadOp);
921923
}
922924
rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -941,8 +943,8 @@ struct WgToSgStoreScatterOpWithOffset
941943
if (!valueType)
942944
return failure();
943945

944-
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
945-
xegpu::getDistributeLayoutAttr(op.getOperand(0)));
946+
xegpu::DistributeLayoutAttr layout =
947+
xegpu::getDistributeLayoutAttr(op.getOperand(0));
946948
if (!layout || !layout.isForWorkgroup())
947949
return failure();
948950

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,4 +547,21 @@ gpu.module @test_distribution {
547547
%broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
548548
gpu.return
549549
}
550+
551+
// CHECK-LABEL: distribute_load_slice_attr
552+
gpu.func @distribute_load_slice_attr() {
553+
%2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
554+
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
555+
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>
556+
557+
// CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}>
558+
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} :
559+
// CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
560+
%3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>
561+
562+
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32>
563+
%4 = vector.broadcast %3 {layout_result_0 =
564+
#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
565+
gpu.return
566+
}
550567
}

0 commit comments

Comments (0)