
Commit 3c4a5aa

Address feedback
1 parent b4f5a4d commit 3c4a5aa

6 files changed, +163 -91 lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 4 additions & 4 deletions
@@ -171,7 +171,7 @@ def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
 def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
 def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
 def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
-    "The enumeration for the scope of fence operation.",
+    "Specify target level for offsets distribution utility.",
     [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
@@ -243,7 +243,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
       represents the higher-level problem size. Each `level` may access
      multiple blocks according to round-robin distribution rules.}],
     "FailureOr<SmallVector<SmallVector<Value>>>",
-    "computeDistributedCoords",
+    "computeDistributedOffsets",
     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
       to some other layout according to given permutation of (0...n-1).}],
@@ -496,7 +496,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);

     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -661,7 +661,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// multiple blocks according to round-robin distribution rules.

     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);

     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
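
For orientation, a hypothetical caller of the renamed hook could look like the sketch below (illustrative only; it assumes the usual MLIR includes and mirrors the call sites touched later in this commit):

static LogicalResult useDistributedOffsets(OpBuilder &builder, Location loc,
                                           xegpu::DistributeLayoutAttr layout,
                                           Value linearId,
                                           ArrayRef<int64_t> shape) {
  // Pick the target granularity: SG distributes a workgroup-level shape
  // across subgroups; WI distributes a subgroup-level shape across lanes.
  FailureOr<SmallVector<SmallVector<Value>>> maybeOffsets =
      layout.computeDistributedOffsets(builder, loc, linearId, shape,
                                       xegpu::DistributionLevel::SG);
  if (failed(maybeOffsets))
    return failure();
  // One offset vector per sub-shape this id accesses under the
  // round-robin distribution rules.
  for (const SmallVector<Value> &offsets : *maybeOffsets)
    (void)offsets; // materialize loads/stores at these offsets here
  return success();
}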

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 13 additions & 8 deletions
@@ -41,6 +41,11 @@ void XeGPUDialect::initialize() {
 // A `srcShape` consists of N distribution units, each being `subShapesLayout` x
 // `subShape`. A `delinearizedId` is used to identify a particular `subShape`
 // within each distribution unit.
+// Example:
+//   WG data is 128x256. SG data is 16x32, in a 4x2 layout; this gives a
+//   distribution unit of shape 64x64, and we have 2x4 such distribution units.
+//   `delinearizedId` identifies the 16x32 tile of a subgroup in each
+//   distribution unit.
 static SmallVector<SmallVector<Value>>
 genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
            ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
@@ -294,13 +299,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }

-/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+/// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
-                                     Value linearId, ArrayRef<int64_t> shape,
-                                     xegpu::DistributionLevel targetLevel) {
+LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+                                      Value linearId, ArrayRef<int64_t> shape,
+                                      xegpu::DistributionLevel targetLevel) {
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
   if (targetLevel == DistributionLevel::SG) {
@@ -386,13 +391,13 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
   return parent.delinearizeId(builder, loc, linearId, level);
 }

-// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
 // instructions for computing multi-dimensional offsets when distributed by
 // LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
-                                    Value linearId, ArrayRef<int64_t> shape,
-                                    xegpu::DistributionLevel targetLevel) {
+SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+                                     Value linearId, ArrayRef<int64_t> shape,
+                                     xegpu::DistributionLevel targetLevel) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
   if (!isForWorkgroup())
     return failure();
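
A quick standalone check of the arithmetic in the new comment (plain C++, not the IR-building code above; the subgroup id below is an arbitrary illustration): a 128x256 WG shape over a 4x2 layout of 16x32 SG tiles gives 64x64 distribution units, 2x4 of them, so each subgroup receives 8 tiles round-robin.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t srcShape[2] = {128, 256}; // workgroup-level data
  const int64_t subShape[2] = {16, 32};   // per-subgroup tile
  const int64_t sgLayout[2] = {4, 2};     // layout of subgroups
  // One distribution unit spans sgLayout * subShape elements: 64x64.
  const int64_t unit[2] = {sgLayout[0] * subShape[0],
                           sgLayout[1] * subShape[1]};
  // Number of distribution units per dimension: 2x4 here.
  const int64_t units[2] = {srcShape[0] / unit[0], srcShape[1] / unit[1]};
  // Delinearized id of one subgroup within the 4x2 layout.
  const int64_t id[2] = {3, 1};
  // Round-robin: this subgroup touches one 16x32 tile in every unit.
  for (int64_t u = 0; u < units[0]; ++u)
    for (int64_t v = 0; v < units[1]; ++v)
      std::printf("offset = (%lld, %lld)\n",
                  (long long)(u * unit[0] + id[0] * subShape[0]),
                  (long long)(v * unit[1] + id[1] * subShape[1]));
  return 0;
}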

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 133 additions & 70 deletions
@@ -907,34 +907,48 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };

-template <class MatrixOp>
-struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newOffsets;
+  ;
+  auto maybeDescOffsets = layout.computeDistributedOffsets(
+      rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+  if (failed(maybeDescOffsets))
+    return {};
+  assert(maybeDescOffsets.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  newOffsets = llvm::to_vector(llvm::map_range(
+      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+  return newOffsets;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
-    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
     if (!matrixOp)
       return failure();
-    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
-    int operandIdx{-1};
-
-    VectorType sgPayloadTy;
-    VectorType warpResultTy;
-    if constexpr (isLoad) {
-      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
-      });
-      if (!producedByLastLoad)
-        return rewriter.notifyMatchFailure(
-            warpOp, "The last op is not xegpu::LoadMatrixOp");
-      operandIdx = producedByLastLoad->getOperandNumber();
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
-      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
-    } else {
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
-    }
+
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadMatrixOp");
+    const int operandIdx = producedByLastLoad->getOperandNumber();
+
+    VectorType sgPayloadTy =
+        dyn_cast<VectorType>(matrixOp.getResult().getType());
+    VectorType warpResultTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     if (!sgPayloadTy)
       return rewriter.notifyMatchFailure(
           matrixOp, "the matrix op payload must be a vector type");
@@ -956,21 +970,14 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
-          matrixOp,
-          "The matrix op payload has no layouts, using defaults instead.");
-
-    SmallVector<Value> operands;
-    if constexpr (isLoad)
-      operands = {matrixOp.getMemDesc()};
-    else
-      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getMemDesc()};
     const unsigned offsetsStartIdx = operands.size();
     operands.append(offsetsAsValues);

     SmallVector<Type> operandTypes = llvm::to_vector(
         llvm::map_range(operands, [](Value v) { return v.getType(); }));
-    if constexpr (!isLoad)
-      operandTypes[0] = *distPayloadByWarpOpOrFailure;

     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -986,40 +993,97 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     ValueRange currentOffsets =
         ValueRange(newOperands).drop_front(offsetsStartIdx);

-    rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     if (!matrixOp.getSubgroupBlockIoAttr()) {
-      auto maybeDescOffsets = layout.computeDistributedCoords(
-          rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
-          xegpu::DistributionLevel::WI);
-      if (failed(maybeDescOffsets))
-        return failure();
-      assert(maybeDescOffsets.value().size() == 1 &&
-             "Expected same number of offset sets as number of accessed "
-             "sub-tensors or sub-memory descriptors.");
-      SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
-          rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
-          offsets);
-      newOffsets = llvm::to_vector(llvm::map_range(
-          ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+        newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        newWarpOp.getResult(operandIdx),
+        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    return success();
+  }
+};

-    if constexpr (isLoad) {
-      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
-          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
-          newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      // Resolve the output type and replace all uses.
-      rewriter.replaceAllUsesWith(
-          newWarpOp.getResult(operandIdx),
-          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
-    } else {
-      xegpu::StoreMatrixOp::create(
-          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
-          ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      rewriter.eraseOp(matrixOp);
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the store op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+        ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(matrixOp);
     return success();
   }
 };
@@ -1551,16 +1615,15 @@ struct XeGPUSubgroupDistributePass final

 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
-           VectorMultiReductionDistribution, LoadDistribution,
-           StoreDistribution, VectorTransposeDistribution,
-           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
-           MatrixOpDistribution<xegpu::StoreMatrixOp>,
-           MemrefExtractAlignedPointerAsIndexDistribution>(
-          patterns.getContext(),
-          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               GpuBarrierDistribution, VectorMultiReductionDistribution,
+               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+               VectorBitcastDistribution, LoadMatrixDistribution,
+               StoreMatrixDistribution,
+               MemrefExtractAlignedPointerAsIndexDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
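
Note on the helper introduced above: computeDistributedOffsetsForMatrixOp folds the per-lane offsets into the op's original mixed offsets via xegpu::addWithRightAligned. A simplified scalar model of that right-aligned addition (assumption: plain integers standing in for MLIR Values, with illustrative numbers):

#include <cstdint>
#include <cstdio>
#include <vector>

// Add the shorter offset vector into the trailing dimensions of the
// longer one, the way rank-2 per-lane offsets combine with higher-rank
// original offsets.
static std::vector<int64_t> addRightAligned(std::vector<int64_t> base,
                                            const std::vector<int64_t> &delta) {
  const size_t start = base.size() - delta.size();
  for (size_t i = 0; i < delta.size(); ++i)
    base[start + i] += delta[i];
  return base;
}

int main() {
  std::vector<int64_t> orig = {2, 0, 64}; // original rank-3 offsets
  std::vector<int64_t> lane = {8, 16};    // per-lane rank-2 offsets
  for (int64_t o : addRightAligned(orig, lane))
    std::printf("%lld ", (long long)o); // prints: 2 8 80
  std::printf("\n");
  return 0;
}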

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 3 additions & 3 deletions
@@ -114,7 +114,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
-  auto maybeDescOffsets = layout.computeDistributedCoords(
+  auto maybeDescOffsets = layout.computeDistributedOffsets(
       rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
   if (failed(maybeDescOffsets))
     return failure();
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     // Get subgroup id
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.computeDistributedCoords(
+    auto sgOffsets = layout.computeDistributedOffsets(
        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(sgOffsets))
       return failure();
@@ -1053,7 +1053,7 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {

     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.computeDistributedCoords(
+    auto sgOffsets = layout.computeDistributedOffsets(
        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(sgOffsets))
       return failure();
