Skip to content

Commit b86fef8

Browse files
authored
[mlir][xegpu] Create a test pass for subgroup distribution. (#161592)
Current subgroup distribution tests employ the entire `xegpu-subgroup-distribute` pass, which includes multiple steps like layout propagation, moving the func body into a warp op, and distributing to work items. This makes it harder to isolate the testing of the xegpu subgroup distribution logic, because certain corner cases may not yet be supported by the other steps mentioned above. This PR introduces a test pass for the subgroup distribution logic and isolates the testing of that logic. We plan to add more corner cases (that were not possible before) covering non-xegpu ops (like vector) in upcoming PRs. This PR also includes: 1. minor bug fixes in gather/scatter distribution; 2. a bug fix in vector multi-reduction lowering, where it failed to retain some layouts.
1 parent 99b0aaf commit b86fef8

File tree

5 files changed

+757
-489
lines changed

5 files changed

+757
-489
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
2727
}];
2828
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
2929
"vector::VectorDialect"];
30-
let options = [Option<
31-
"enableSGReductions", "enable-sg-reductions", "bool",
32-
/*default=*/"true",
33-
"Enable subgroup reductions using subgroup shuffles.">];
3430
}
3531

3632
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
875875
storeScatterOp,
876876
"Some vector operands have no layouts, using defaults instead.");
877877
}
878-
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
879-
VectorType expectedPayloadTy = VectorType::get(
880-
{distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
878+
// Distributed store payload type according to the lane layout.
879+
VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
880+
// Expected distributed payload type is always 1D.
881+
VectorType expectedPayloadTy =
882+
VectorType::get({distPayloadTyByWarpOp.getNumElements()},
883+
distPayloadTyByWarpOp.getElementType());
881884

882885
SmallVector<size_t> newRetIndices;
883886
SmallVector<Value> operands = storeScatterOp->getOperands();
884887
SmallVector<Type> operandTypesToYield = {
885-
expectedPayloadTy, operands[1].getType(),
888+
distPayloadTyByWarpOp, operands[1].getType(),
886889
distOffsetsByWarpOpOrFailure.value(),
887890
distMaskByWarpOpOrFailure.value()};
888891

889892
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
890893
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
891894
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
892895
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
893-
896+
// The payload operand may need type adjustment due to mismatch between warp
897+
// distributed type and expected SIMT type.
894898
rewriter.setInsertionPointAfter(newWarpOp);
899+
newStoreScatterOpOperands[0] = resolveDistributedTy(
900+
newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
895901
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
896902
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
897903
storeScatterOp->getAttrs());
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
976982
distMaskByWarpOpOrFailure.value()};
977983

978984
const unsigned operandIdx = producedByLastLoad->getOperandNumber();
979-
VectorType loadVecTy =
985+
VectorType distResultTy =
980986
cast<VectorType>(warpOp.getResult(operandIdx).getType());
987+
// Distributed load op will always be 1D.
988+
VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
989+
distResultTy.getElementType());
981990

982991
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
983992
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,13 +1000,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
9911000
loadGatherOp->getAttrs());
9921001
xegpu::removeLayoutAttrs(newOp);
9931002
Value distributedVal = newWarpOp.getResult(operandIdx);
994-
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
1003+
// Resolve the output type and replace all uses.
1004+
rewriter.replaceAllUsesWith(
1005+
distributedVal,
1006+
resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
9951007
return success();
9961008
}
9971009
};
9981010

9991011
/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1000-
/// VectorReductionOps.
1012+
/// VectorReductionOps. We also insert layouts for the newly created ops.
10011013
static Value lowerToVectorReductions(TypedValue<VectorType> src,
10021014
TypedValue<VectorType> acc,
10031015
vector::CombiningKind kind,
@@ -1014,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10141026
Value reductionResult = arith::ConstantOp::create(
10151027
rewriter, loc, acc.getType(),
10161028
DenseElementsAttr::get(acc.getType(), zeroAttr));
1029+
// Reduction result should have the same layout as the accumulator.
1030+
xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
1031+
xegpu::getDistributeLayoutAttr(acc));
10171032
// For each slice of the source, extract the slice vector, do a reduction
10181033
// and, insert the reduced value back to the result vector.
10191034
for (int i = 0; i < nSlices; ++i) {
@@ -1029,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10291044
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
10301045
sliceSizes, {1, 1});
10311046
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1032-
Value slice = vector::ShapeCastOp::create(
1047+
vector::ShapeCastOp slice = vector::ShapeCastOp::create(
10331048
rewriter, loc,
10341049
VectorType::get({nSliceElements}, sourceType.getElementType()),
10351050
extractOp.getResult());
1051+
// Shape cast is currently handled in xegpu side. So layouts must be
1052+
// retained during lowering. Shape cast output has the same layout as the
1053+
// accumulator. Shape cast source has the same layout as the original
1054+
// reduction source.
1055+
// TODO: other ops generated here may also need layout attributes.
1056+
xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
1057+
xegpu::getDistributeLayoutAttr(src));
1058+
xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
1059+
xegpu::getDistributeLayoutAttr(acc));
1060+
// Extract and reduction results in scalars, so no result layout is needed.
10361061
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1037-
Value reduction =
1038-
vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
1062+
Value reduction = vector::ReductionOp::create(
1063+
rewriter, loc, kind, slice.getResult(), accExtract);
10391064
reductionResult =
10401065
vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
10411066
}
@@ -1107,7 +1132,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11071132
return failure();
11081133
auto reductionOp =
11091134
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1110-
unsigned operandNumber = yieldOperand->getOperandNumber();
1135+
unsigned operandIdx = yieldOperand->getOperandNumber();
11111136
VectorType sourceType = reductionOp.getSourceVectorType();
11121137
// Only 2D vectors are supported.
11131138
if (sourceType.getRank() != 2)
@@ -1121,7 +1146,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11211146
warpOp, "Only 1 reduction dimension is supported.");
11221147
int64_t reductionDim = reductionDims[0];
11231148
VectorType distributedResultType =
1124-
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1149+
cast<VectorType>(warpOp.getResult(operandIdx).getType());
11251150
VectorType resultType = cast<VectorType>(reductionOp.getType());
11261151
xegpu::DistributeLayoutAttr sourceLayout =
11271152
xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1209,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11841209
cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
11851210
reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
11861211
// Replace the warp op result with the final result.
1187-
rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1212+
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
11881213
return success();
11891214
}
11901215
// For non-lane-local case, we simply rewrite the MultiReductionOp in terms
@@ -1217,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
12171242
auto resultDistTy =
12181243
cast<VectorType>(warpOp.getResult(operandNumber).getType());
12191244
xegpu::DistributeLayoutAttr sourceLayout =
1220-
xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
1245+
xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
12211246
xegpu::DistributeLayoutAttr resultLayout =
12221247
xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
12231248
if (!sourceLayout || !resultLayout)
@@ -1403,11 +1428,6 @@ namespace {
14031428
struct XeGPUSubgroupDistributePass final
14041429
: public xegpu::impl::XeGPUSubgroupDistributeBase<
14051430
XeGPUSubgroupDistributePass> {
1406-
XeGPUSubgroupDistributePass() = default;
1407-
XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
1408-
default;
1409-
XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
1410-
: XeGPUSubgroupDistributeBase(options) {}
14111431
void runOnOperation() override;
14121432
};
14131433
} // namespace
@@ -1515,10 +1535,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
15151535
return laneVal;
15161536
};
15171537

1518-
if (enableSGReductions)
1519-
vector::populateDistributeReduction(
1520-
patterns, warpReduction,
1521-
/*pattern benefit=*/regularPatternBenefit);
1538+
vector::populateDistributeReduction(
1539+
patterns, warpReduction,
1540+
/*pattern benefit=*/regularPatternBenefit);
15221541

15231542
vector::populatePropagateWarpVectorDistributionPatterns(
15241543
patterns, distributionFn, shuffleFn,

0 commit comments

Comments
 (0)