Skip to content

Commit feb4def

Browse files
committed
[mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops
Signed-off-by: dchigarev <[email protected]>
1 parent a321ce3 commit feb4def

File tree

5 files changed

+72
-16
lines changed

5 files changed

+72
-16
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
843843
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
844844
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
845845
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
846-
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
846+
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
847+
OptionalAttr<DistributeLayoutAttr>:$layout);
847848
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
848849

849850
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
895896
"IntegerAttr": $chunk_size,
896897
"xegpu::CachePolicyAttr": $l1_hint,
897898
"xegpu::CachePolicyAttr": $l2_hint,
898-
"xegpu::CachePolicyAttr": $l3_hint)>
899+
"xegpu::CachePolicyAttr": $l3_hint)>,
900+
OpBuilder<(ins "Type": $value, "Value": $source,
901+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
902+
"IntegerAttr": $chunk_size,
903+
"xegpu::CachePolicyAttr": $l1_hint,
904+
"xegpu::CachePolicyAttr": $l2_hint,
905+
"xegpu::CachePolicyAttr": $l3_hint,
906+
"xegpu::DistributeLayoutAttr": $layout)>
899907
];
900908

901909
let hasVerifier = 1;
@@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
979987
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
980988
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
981989
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
982-
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
990+
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
991+
OptionalAttr<DistributeLayoutAttr>:$layout);
983992

984993
let extraClassDeclaration = extraBaseClassDeclaration#[{
985994
Type getDestType() {
@@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
10301039
"IntegerAttr": $chunk_size,
10311040
"xegpu::CachePolicyAttr": $l1_hint,
10321041
"xegpu::CachePolicyAttr": $l2_hint,
1033-
"xegpu::CachePolicyAttr": $l3_hint)>
1042+
"xegpu::CachePolicyAttr": $l3_hint)>,
1043+
OpBuilder<(ins "Value": $value, "Value": $dest,
1044+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
1045+
"IntegerAttr": $chunk_size,
1046+
"xegpu::CachePolicyAttr": $l1_hint,
1047+
"xegpu::CachePolicyAttr": $l2_hint,
1048+
"xegpu::CachePolicyAttr": $l3_hint,
1049+
"xegpu::DistributeLayoutAttr": $layout)>
10341050
];
10351051

10361052
let hasVerifier = 1;

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
435435
/*chunk_size=*/IntegerAttr{},
436436
/*l1_hint=*/xegpu::CachePolicyAttr{},
437437
/*l2_hint=*/xegpu::CachePolicyAttr{},
438-
/*l3_hint=*/xegpu::CachePolicyAttr{});
438+
/*l3_hint=*/xegpu::CachePolicyAttr{},
439+
/*layout=*/nullptr);
439440

440441
rewriter.replaceOp(readOp, gatherOp.getResult());
441442
return success();
@@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
469470
/*chunk_size=*/IntegerAttr{},
470471
/*l1_hint=*/xegpu::CachePolicyAttr{},
471472
/*l2_hint=*/xegpu::CachePolicyAttr{},
472-
/*l3_hint=*/xegpu::CachePolicyAttr{});
473+
/*l3_hint=*/xegpu::CachePolicyAttr{},
474+
/*layout=*/nullptr);
473475
rewriter.eraseOp(writeOp);
474476
return success();
475477
}
@@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
621623
/*chunk_size=*/IntegerAttr{},
622624
/*l1_hint=*/xegpu::CachePolicyAttr{},
623625
/*l2_hint=*/xegpu::CachePolicyAttr{},
624-
/*l3_hint=*/xegpu::CachePolicyAttr{});
626+
/*l3_hint=*/xegpu::CachePolicyAttr{},
627+
/*layout=*/nullptr);
625628

626629
auto selectOp =
627630
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
655658
/*chunk_size=*/IntegerAttr{},
656659
/*l1_hint=*/xegpu::CachePolicyAttr{},
657660
/*l2_hint=*/xegpu::CachePolicyAttr{},
658-
/*l3_hint=*/xegpu::CachePolicyAttr{});
661+
/*l3_hint=*/xegpu::CachePolicyAttr{},
662+
/*layout=*/nullptr);
659663
rewriter.eraseOp(scatterOp);
660664
return success();
661665
}

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
859859
xegpu::CachePolicyAttr l2_hint,
860860
xegpu::CachePolicyAttr l3_hint) {
861861
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
862-
l1_hint, l2_hint, l3_hint);
862+
l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
863863
}
864864

865865
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +875,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
875875
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
876876

877877
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
878-
l2_hint, l3_hint);
878+
l2_hint, l3_hint, /*layout=*/nullptr);
879+
}
880+
881+
// Builder overload for LoadGatherOp that accepts the offsets as a list of
// OpFoldResult and an additional `layout` attribute.
//
// The offsets are materialized as index values, packed into a single 1-D
// vector of index type (one element per entry in `offsets`), and the call is
// then forwarded to the builder overload that takes one offset Value.
// `chunk_size`, the three cache hints and `layout` are passed through
// unchanged.
//
// NOTE(review): apart from forwarding `layout`, this body matches the
// layout-less ArrayRef<OpFoldResult> overload directly above it; that overload
// could presumably delegate here with a null layout to avoid the duplication.
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
882+
Type valueType, Value source,
883+
ArrayRef<OpFoldResult> offsets, Value mask,
884+
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
885+
xegpu::CachePolicyAttr l2_hint,
886+
xegpu::CachePolicyAttr l3_hint,
887+
DistributeLayoutAttr layout) {
888+
// Helper ops created below reuse the location of the `source` value.
auto loc = source.getLoc();
889+
int64_t size = static_cast<int64_t>(offsets.size());
890+
// 1-D vector type with `size` elements of index type.
auto type = VectorType::get(size, builder.getIndexType());
891+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
892+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
893+
894+
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
895+
l2_hint, l3_hint, layout);
879896
}
880897

881898
//===----------------------------------------------------------------------===//
@@ -926,7 +943,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
926943
xegpu::CachePolicyAttr l2_hint,
927944
xegpu::CachePolicyAttr l3_hint) {
928945
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
929-
l2_hint, l3_hint);
946+
l2_hint, l3_hint, /*layout=*/nullptr);
930947
}
931948

932949
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +961,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
944961

945962
// Call the correct builder overload that does not expect result types.
946963
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
947-
l3_hint);
964+
l3_hint, /*layout=*/nullptr);
965+
}
966+
967+
// Builder overload for StoreScatterOp that accepts the offsets as a list of
// OpFoldResult and an additional `layout` attribute.
//
// The offsets are materialized as index values, packed into a single 1-D
// vector of index type (one element per entry in `offsets`), and the call is
// then forwarded to the builder overload that takes one offset Value.
// `chunk_size`, the three cache hints and `layout` are passed through
// unchanged.
//
// NOTE(review): apart from forwarding `layout`, this body matches the
// layout-less ArrayRef<OpFoldResult> overload directly above it; that overload
// could presumably delegate here with a null layout to avoid the duplication.
void StoreScatterOp::build(
968+
OpBuilder &builder, OperationState &state, Value value, Value dest,
969+
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
970+
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
971+
xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
972+
// Helper ops created below reuse the location of the `dest` value.
auto loc = dest.getLoc();
973+
int64_t size = static_cast<int64_t>(offsets.size());
974+
// 1-D vector type with `size` elements of index type.
auto type = VectorType::get(size, builder.getIndexType());
975+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
976+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
977+
978+
// Call the correct builder overload that does not expect result types.
979+
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
980+
l3_hint, layout);
948981
}
949982

950983
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset
687687
auto newOp = xegpu::LoadGatherOp::create(
688688
rewriter, loc, newValueTy, op.getSource(), o, m,
689689
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
690-
op.getL2HintAttr(), op.getL3HintAttr());
690+
op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/ nullptr);
691691
newOps.push_back(newOp);
692692
}
693693

@@ -783,7 +783,7 @@ struct UnrollStoreScatterOpWithOffsets
783783
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
784784
rewriter.getI64IntegerAttr(chunkSize),
785785
op.getL1HintAttr(), op.getL2HintAttr(),
786-
op.getL3HintAttr());
786+
op.getL3HintAttr(), /*layout*/ nullptr);
787787
}
788788

789789
rewriter.eraseOp(op);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,8 @@ struct WgToSgLoadGatherOpWithOffset
914914
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
915915
auto newLoadOp = xegpu::LoadGatherOp::create(
916916
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
917-
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
917+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
918+
/*layout*/ nullptr);
918919
xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
919920
layout.dropSgLayoutAndData());
920921
newLoadOps.push_back(newLoadOp);
@@ -962,9 +963,11 @@ struct WgToSgStoreScatterOpWithOffset
962963
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
963964
for (auto [val, offs, mask] : llvm::zip(
964965
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
966+
965967
auto store = xegpu::StoreScatterOp::create(
966968
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
967-
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
969+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
970+
/*layout*/ nullptr);
968971
// Update the layout attribute to drop sg_layout and sg_data.
969972
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
970973
!layout.getEffectiveInstDataAsInt().empty()) {

0 commit comments

Comments
 (0)