Skip to content

Commit 6c563dc

Browse files
authored
[mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops (#163414)
As [suggested here](#163071 (comment)), this PR adds an optional layout attribute for the `LoadGather` and `StoreScatter` ops. For the load op, the attribute describes the layout of the result (e.g. `layout_result_0`), and for the store op it describes the layout of the vector-to-store operand (e.g. `layout_operand_0`). The PR also reworks the `propagate-layout` pass to consider permanent layout attributes and back-propagate them accordingly. The helper utility function `getDistributeLayoutAttr` is reworked to return either `layout_operand/result_0` or `layout` for load/store ops (depending on which one is set). After an offline discussion, it was decided that the overall layout utilities API is confusing, since it tries to mix permanent and temporary layouts; it will need to be changed in the future. --------- Signed-off-by: dchigarev <[email protected]>
1 parent 71022d1 commit 6c563dc

File tree

10 files changed

+200
-33
lines changed

10 files changed

+200
-33
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
843843
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
844844
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
845845
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
846-
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
846+
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
847+
OptionalAttr<XeGPU_LayoutAttr>:$layout);
847848
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
848849

849850
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
895896
"IntegerAttr": $chunk_size,
896897
"xegpu::CachePolicyAttr": $l1_hint,
897898
"xegpu::CachePolicyAttr": $l2_hint,
898-
"xegpu::CachePolicyAttr": $l3_hint)>
899+
"xegpu::CachePolicyAttr": $l3_hint)>,
900+
OpBuilder<(ins "Type": $value, "Value": $source,
901+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
902+
"IntegerAttr": $chunk_size,
903+
"xegpu::CachePolicyAttr": $l1_hint,
904+
"xegpu::CachePolicyAttr": $l2_hint,
905+
"xegpu::CachePolicyAttr": $l3_hint,
906+
"xegpu::LayoutAttr": $layout)>
899907
];
900908

901909
let hasVerifier = 1;
@@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
979987
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
980988
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
981989
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
982-
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
990+
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
991+
OptionalAttr<XeGPU_LayoutAttr>:$layout);
983992

984993
let extraClassDeclaration = extraBaseClassDeclaration#[{
985994
Type getDestType() {
@@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
10301039
"IntegerAttr": $chunk_size,
10311040
"xegpu::CachePolicyAttr": $l1_hint,
10321041
"xegpu::CachePolicyAttr": $l2_hint,
1033-
"xegpu::CachePolicyAttr": $l3_hint)>
1042+
"xegpu::CachePolicyAttr": $l3_hint)>,
1043+
OpBuilder<(ins "Value": $value, "Value": $dest,
1044+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
1045+
"IntegerAttr": $chunk_size,
1046+
"xegpu::CachePolicyAttr": $l1_hint,
1047+
"xegpu::CachePolicyAttr": $l2_hint,
1048+
"xegpu::CachePolicyAttr": $l3_hint,
1049+
"xegpu::LayoutAttr": $layout)>
10341050
];
10351051

10361052
let hasVerifier = 1;

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,15 @@ void removeLayoutAttrs(Operation *op);
104104

105105
/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
106106
/// it to the owner's dictionary attributes
107+
/// If `respectPermLayout` is true, the existing permanent layout
108+
/// attribute will be kept and assigned to the attribute dict instead
109+
/// of the provided layout.
107110
template <typename T,
108111
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
109112
std::is_same_v<T, OpResult>>>
110113
void setDistributeLayoutAttr(const T &operandOrResult,
111-
const DistributeLayoutAttr layout);
114+
const DistributeLayoutAttr layout,
115+
bool respectPermLayout = false);
112116

113117
/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
114118
/// operation. If the operation contains regions, it is also applied recursively

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
457457
/*chunk_size=*/IntegerAttr{},
458458
/*l1_hint=*/xegpu::CachePolicyAttr{},
459459
/*l2_hint=*/xegpu::CachePolicyAttr{},
460-
/*l3_hint=*/xegpu::CachePolicyAttr{});
460+
/*l3_hint=*/xegpu::CachePolicyAttr{},
461+
/*layout=*/nullptr);
461462

462463
rewriter.replaceOp(readOp, gatherOp.getResult());
463464
return success();
@@ -491,7 +492,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
491492
/*chunk_size=*/IntegerAttr{},
492493
/*l1_hint=*/xegpu::CachePolicyAttr{},
493494
/*l2_hint=*/xegpu::CachePolicyAttr{},
494-
/*l3_hint=*/xegpu::CachePolicyAttr{});
495+
/*l3_hint=*/xegpu::CachePolicyAttr{},
496+
/*layout=*/nullptr);
495497
rewriter.eraseOp(writeOp);
496498
return success();
497499
}
@@ -646,7 +648,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
646648
/*chunk_size=*/IntegerAttr{},
647649
/*l1_hint=*/xegpu::CachePolicyAttr{},
648650
/*l2_hint=*/xegpu::CachePolicyAttr{},
649-
/*l3_hint=*/xegpu::CachePolicyAttr{});
651+
/*l3_hint=*/xegpu::CachePolicyAttr{},
652+
/*layout=*/nullptr);
650653

651654
auto selectOp =
652655
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -680,7 +683,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
680683
/*chunk_size=*/IntegerAttr{},
681684
/*l1_hint=*/xegpu::CachePolicyAttr{},
682685
/*l2_hint=*/xegpu::CachePolicyAttr{},
683-
/*l3_hint=*/xegpu::CachePolicyAttr{});
686+
/*l3_hint=*/xegpu::CachePolicyAttr{},
687+
/*layout=*/nullptr);
684688
rewriter.eraseOp(scatterOp);
685689
return success();
686690
}

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
876876
xegpu::CachePolicyAttr l2_hint,
877877
xegpu::CachePolicyAttr l3_hint) {
878878
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
879-
l1_hint, l2_hint, l3_hint);
879+
l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
880880
}
881881

882882
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -892,7 +892,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
892892
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
893893

894894
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
895-
l2_hint, l3_hint);
895+
l2_hint, l3_hint, /*layout=*/nullptr);
896+
}
897+
898+
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
899+
Type valueType, Value source,
900+
ArrayRef<OpFoldResult> offsets, Value mask,
901+
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
902+
xegpu::CachePolicyAttr l2_hint,
903+
xegpu::CachePolicyAttr l3_hint,
904+
xegpu::LayoutAttr layout) {
905+
auto loc = source.getLoc();
906+
int64_t size = static_cast<int64_t>(offsets.size());
907+
auto type = VectorType::get(size, builder.getIndexType());
908+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
909+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
910+
911+
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
912+
l2_hint, l3_hint, layout);
896913
}
897914

898915
//===----------------------------------------------------------------------===//
@@ -943,7 +960,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
943960
xegpu::CachePolicyAttr l2_hint,
944961
xegpu::CachePolicyAttr l3_hint) {
945962
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
946-
l2_hint, l3_hint);
963+
l2_hint, l3_hint, /*layout=*/nullptr);
947964
}
948965

949966
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -961,7 +978,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
961978

962979
// Call the correct builder overload that does not expect result types.
963980
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
964-
l3_hint);
981+
l3_hint, /*layout=*/nullptr);
982+
}
983+
984+
void StoreScatterOp::build(
985+
OpBuilder &builder, OperationState &state, Value value, Value dest,
986+
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
987+
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
988+
xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
989+
auto loc = dest.getLoc();
990+
int64_t size = static_cast<int64_t>(offsets.size());
991+
auto type = VectorType::get(size, builder.getIndexType());
992+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
993+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
994+
995+
// Call the correct builder overload that does not expect result types.
996+
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
997+
l3_hint, layout);
965998
}
966999

9671000
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -904,9 +904,16 @@ void LayoutInfoPropagation::visitStoreScatterOp(
904904
if (dstTdescTy.getChunkSizeAsInt() > 1)
905905
instData.push_back(chunkSize);
906906
}
907-
LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
908-
payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
909-
/*scattered=*/true);
907+
908+
LayoutInfo payloadLayout;
909+
910+
if (auto layout = storeScatter.getLayoutAttr()) {
911+
payloadLayout = LayoutInfo(layout);
912+
} else {
913+
payloadLayout = getDefaultSIMTLayoutInfo(
914+
payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
915+
/*scattered=*/true);
916+
}
910917

911918
LayoutInfo maskLayout =
912919
getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
@@ -1041,7 +1048,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
10411048
}
10421049
// If the result is a vector type, add a temporary layout attribute to the
10431050
// op.
1044-
xegpu::setDistributeLayoutAttr(result, layout);
1051+
xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
10451052
}
10461053
return success();
10471054
}

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,16 @@ struct UnrollLoadGatherOpWithOffset
678678
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
679679
}
680680

681+
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
682+
if (layout)
683+
layout = layout.dropInstData();
684+
681685
SmallVector<Value> newOps;
682686
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
683687
auto newOp = xegpu::LoadGatherOp::create(
684688
rewriter, loc, newValueTy, op.getSource(), o, m,
685689
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
686-
op.getL2HintAttr(), op.getL3HintAttr());
690+
op.getL2HintAttr(), op.getL3HintAttr(), layout);
687691
newOps.push_back(newOp);
688692
}
689693

@@ -774,12 +778,16 @@ struct UnrollStoreScatterOpWithOffsets
774778
SmallVector<Value> convertedValues =
775779
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
776780

781+
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
782+
if (layout)
783+
layout = layout.dropInstData();
784+
777785
for (auto [v, o, m] :
778786
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
779787
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
780788
rewriter.getI64IntegerAttr(chunkSize),
781789
op.getL1HintAttr(), op.getL2HintAttr(),
782-
op.getL3HintAttr());
790+
op.getL3HintAttr(), layout);
783791
}
784792

785793
rewriter.eraseOp(op);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset
889889
return failure();
890890
ArrayRef<int64_t> wgShape = resultType.getShape();
891891

892-
xegpu::DistributeLayoutAttr layout =
893-
xegpu::getDistributeLayoutAttr(op.getResult());
892+
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
893+
xegpu::getDistributeLayoutAttr(op.getResult()));
894894
if (!layout || !layout.isForWorkgroup())
895895
return failure();
896896

@@ -915,9 +915,8 @@ struct WgToSgLoadGatherOpWithOffset
915915
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
916916
auto newLoadOp = xegpu::LoadGatherOp::create(
917917
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
918-
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
919-
xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
920-
layout.dropSgLayoutAndData());
918+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
919+
layout.dropSgLayoutAndData());
921920
newLoadOps.push_back(newLoadOp);
922921
}
923922
rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -942,8 +941,8 @@ struct WgToSgStoreScatterOpWithOffset
942941
if (!valueType)
943942
return failure();
944943

945-
xegpu::DistributeLayoutAttr layout =
946-
xegpu::getDistributeLayoutAttr(op.getOperand(0));
944+
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
945+
xegpu::getDistributeLayoutAttr(op.getOperand(0)));
947946
if (!layout || !layout.isForWorkgroup())
948947
return failure();
949948

@@ -965,7 +964,8 @@ struct WgToSgStoreScatterOpWithOffset
965964
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
966965
auto store = xegpu::StoreScatterOp::create(
967966
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
968-
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
967+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
968+
layout.dropSgLayoutAndData());
969969
// Update the layout attribute to drop sg_layout and sg_data.
970970
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
971971
!layout.getEffectiveInstDataAsInt().empty()) {

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
144144
std::string layoutName = getLayoutName(result);
145145
if (defOp->hasAttr(layoutName))
146146
return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
147+
148+
// Check for the "permanent" layout only after the "temporary" layout name lookup
149+
// for backward compatibility
150+
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
151+
return loadGatherOp.getLayoutAttr();
147152
}
148153

149154
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,27 +176,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
171176
std::string layoutName = xegpu::getLayoutName(opr);
172177
if (op->hasAttr(layoutName))
173178
return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
179+
180+
// Check for the "permanent" layout only after the "temporary" layout name lookup.
181+
if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
182+
if (auto layout = storeScatterOp.getLayoutAttr())
183+
return layout;
184+
174185
return getDistributeLayoutAttr(opr.get());
175186
}
176187

188+
// Returns the permanent layout attribute for the given result if it's
189+
// available on the defining op. Otherwise returns the provided layout.
190+
xegpu::DistributeLayoutAttr
191+
maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
192+
const OpResult &result, mlir::Operation *owner,
193+
const std::string &name) {
194+
xegpu::DistributeLayoutAttr candidate = layout;
195+
196+
if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
197+
if (auto perm = loadOp.getLayoutAttr())
198+
candidate = perm;
199+
}
200+
201+
return candidate;
202+
}
203+
204+
// Returns the permanent layout attribute for the given operand if it's
205+
// available on the owner op. Otherwise returns the provided layout.
206+
xegpu::DistributeLayoutAttr
207+
maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
208+
const OpOperand &operand, mlir::Operation *owner,
209+
const std::string &name) {
210+
xegpu::DistributeLayoutAttr candidate = layout;
211+
unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
212+
213+
if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
214+
if (idx == 0) {
215+
if (auto perm = storeOp.getLayoutAttr())
216+
candidate = perm;
217+
}
218+
}
219+
220+
return candidate;
221+
}
222+
177223
template <typename T, typename>
178224
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
179-
const DistributeLayoutAttr layout) {
225+
const DistributeLayoutAttr layout,
226+
bool respectPermLayout) {
180227
Operation *owner = operandOrResult.getOwner();
181228
std::string name = xegpu::getLayoutName(operandOrResult);
182-
if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
183-
owner->setAttr(name, layout);
229+
230+
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
231+
return;
232+
233+
DistributeLayoutAttr candidate = layout;
234+
if (respectPermLayout)
235+
candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name);
236+
237+
if (candidate)
238+
owner->setAttr(name, candidate);
184239
}
185240

186241
// Explicit instantiation for OpResult
187242
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
188243
const mlir::OpResult &result,
189-
const mlir::xegpu::DistributeLayoutAttr layout);
244+
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
190245

191246
// Explicit instantiation for OpOperand
192247
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
193248
const mlir::OpOperand &operand,
194-
const mlir::xegpu::DistributeLayoutAttr layout);
249+
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
195250

196251
void xegpu::setDistributeLayoutAttrs(
197252
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {

0 commit comments

Comments
 (0)