Commit 3c06f28

save work

1 parent 3dea80c commit 3c06f28

File tree

2 files changed: +179 -147 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 179 additions & 64 deletions
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -72,27 +73,43 @@ namespace {
 /// | 32x16   | [2, 8]      | 16x2    |
 /// | 2x32x16 | [1, 16]     | 2x32x1  |
 static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                 VectorType originalType) {
   if (!layout)
     return failure();
+  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
+         "Expecting a valid layout.");
+  SmallVector<int64_t> effectiveLaneLayout;
+  // If the layout is a slice, we need to get the effective lane layout by
+  // removing the sliced dims.
+  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
+    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
+    for (auto [i, dim] :
+         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
+      if (!lookUp.contains(i))
+        effectiveLaneLayout.push_back(dim);
+    }
+  } else {
+    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
+  }
 
-  auto laneLayout = layout.getLaneLayout().asArrayRef();
-  assert(originalType.getShape().size() >= laneLayout.size() &&
+  assert(originalType.getShape().size() >= effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
   SmallVector<int64_t> distributedShape(originalType.getShape());
   // Only distribute the last `laneLayout.size()` dimensions. The remaining
   // dimensions are not distributed.
-  unsigned distributionStart = originalType.getRank() - laneLayout.size();
+  unsigned distributionStart =
+      originalType.getRank() - effectiveLaneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
 
     // Check if the dimension can be distributed evenly.
-    if (dim % laneLayout[i - distributionStart] != 0)
+    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
-    distributedShape[i] = dim / laneLayout[i - distributionStart];
+    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
   }
   return VectorType::get(distributedShape, originalType.getElementType());
 }
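
As a side note, the shape arithmetic above can be tried in isolation. Below is a minimal standalone C++ sketch of the same logic, using plain STL containers in place of the MLIR/LLVM types (the names here are illustrative and not part of the patch):

// Sketch: drop the sliced dims from the parent lane layout, then divide the
// trailing dims of the vector shape by the surviving lane layout.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

std::vector<int64_t> distributeShape(std::vector<int64_t> shape,
                                     const std::vector<int64_t> &parentLaneLayout,
                                     const std::set<int64_t> &slicedDims) {
  std::vector<int64_t> laneLayout;
  for (int64_t i = 0; i < static_cast<int64_t>(parentLaneLayout.size()); ++i)
    if (!slicedDims.count(i))
      laneLayout.push_back(parentLaneLayout[i]);
  assert(shape.size() >= laneLayout.size() && "rank too small to distribute");
  size_t start = shape.size() - laneLayout.size();
  for (size_t i = start; i < shape.size(); ++i) {
    assert(shape[i] % laneLayout[i - start] == 0 && "uneven distribution");
    shape[i] /= laneLayout[i - start];
  }
  return shape;
}

int main() {
  // Mirrors the table in the comment above: 32x16 with lane layout [2, 8]
  // distributes to 16x2 per lane.
  auto dist = distributeShape({32, 16}, {2, 8}, {});
  std::cout << dist[0] << "x" << dist[1] << "\n"; // prints 16x2
  // With a slice layout that drops dim 1, only lane layout [2] survives:
  auto sliced = distributeShape({32}, {2, 8}, {1});
  std::cout << sliced[0] << "\n"; // prints 16
}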
@@ -858,7 +875,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
 ///     gpu.yield %1 : vector<2xf32>
 ///   }
 struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
-  using Base::Base;
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -869,83 +886,108 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandNumber = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Only 2D reductions are supported.");
     ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported. This also ensures that result is
-    // also vector type.
+    // Only 1 reduction dimension supported. This also ensures that the result
+    // is a vector type.
     if (reductionDims.size() != 1)
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
     int64_t reductionDim = reductionDims[0];
-    auto resultType = cast<VectorType>(reductionOp.getType());
-    auto distributedResultType =
+    VectorType distributedResultType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    VectorType resultType = cast<VectorType>(reductionOp.getType());
     Type elementType = distributedResultType.getElementType();
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(reductionOp.getSource());
 
-    // Currently we make the following assumptions.
-    // 1. The source vector is distributed in the column dimension. Each lane
-    // owns complete column(s) of the source vector.
-    // 2. If the reduction dim == 0, its a lane-local col reduction. In this
-    // case each lane owns its portion of the result (i.e. result is also
-    // distributed).
-    // 3. If reduction dim == 1, its a row reduction that require cross lanes
-    // shuffles. In this case, the reduction result is not distributed across
-    // lanes. Instead each lane owns a complete copy of the result
-    // (broadcasted).
-    // TODO: These assumptions are fairly restrictive. For example, source
-    // vector can have row distributed layout. Improve support for such cases.
-    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
       return rewriter.notifyMatchFailure(
-          warpOp, "Source vector dimension must be divisible by warp size.");
-    bool isResultDistributed =
+          warpOp, "Failed to distribute the source vector type.");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Only single dimension distribution is supported.
+    bool dim0Distributed =
+        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+    bool dim1Distributed =
+        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+    if (dim0Distributed && dim1Distributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    if (sourceDistDim == -1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed source vector.");
+    bool resultDistributed =
         distributedResultType.getNumElements() < resultType.getNumElements();
-    if (reductionDim == 0 && !isResultDistributed)
+    // If the lane owns all the data required for reduction (i.e. reduction is
+    // fully parallel across lanes), then each lane owns part of the result
+    // (i.e. result is distributed). If the reduction requires cross-lane
+    // shuffling, then the result is shared among all lanes (broadcasted).
+    // Therefore we expect the following cases:
+    //
+    // | Source vector        | Reduction dim  | Result vector  |
+    // |----------------------|----------------|----------------|
+    // | dim-0 distributed    | 0              | broadcasted    |
+    // | dim-0 distributed    | 1              | distributed    |
+    // | dim-1 distributed    | 0              | distributed    |
+    // | dim-1 distributed    | 1              | broadcasted    |
+
+    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+                                (sourceDistDim == 1 && reductionDim == 0);
+    if (isReductionLaneLocal && !resultDistributed)
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Expecting result vector to be distributed in a col reduction.");
-    if (reductionDim == 1 && isResultDistributed)
+          warpOp, "Expecting a distributed result for lane-local reduction.");
+
+    if (!isReductionLaneLocal && resultDistributed)
       return rewriter.notifyMatchFailure(
           warpOp,
-          "Expecting result vector to be broadcasted in a row reduction.");
+          "Expecting a broadcasted result for non-lane-local reduction.");
 
     // Create a constant vector to store the result of the reduction per lane.
+    rewriter.setInsertionPoint(warpOp);
     TypedAttr zeroAttr =
         rewriter.getZeroAttr(distributedResultType.getElementType());
     Value result = arith::ConstantOp::create(
         rewriter, reductionOp->getLoc(), distributedResultType,
         DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Col reduction.
-    if (reductionDim == 0) {
-      // Compute source distributed type assuming each lane owns cols.
-      SmallVector<int64_t> shape(sourceType.getShape());
-      shape[1] = shape[1] / warpOp.getWarpSize();
-      auto sourceDistributedType = VectorType::get(shape, elementType);
 
+    // Handle the lane-local reduction case. In this case we fully distribute
+    // the reduction.
+    if (isReductionLaneLocal) {
      // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
-          {sourceDistributedType, distributedResultType}, newRetIndices);
+          {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
 
-      int nCols = sourceDistributedType.getShape()[1];
+      int nSlices = sourceDistType.getShape()[sourceDistDim];
       Value source = newWarpOp.getResult(newRetIndices[0]);
       Value acc = newWarpOp.getResult(newRetIndices[1]);
-      // For each column owned by a lane, extract the column (of size nRows x
-      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
-      // result back to the result vector.
-      for (int i = 0; i < nCols; ++i) {
+      // For each slice owned by a lane, extract the slice, shape cast to 1D,
+      // do a vector.reduction and insert the result back to the result vector.
+      for (int i = 0; i < nSlices; ++i) {
+        SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+        if (sourceDistDim == 0) {
+          sliceOffsets = {i, 0};
+          sliceSizes = {1, sourceDistType.getShape()[1]};
+        } else {
+          sliceOffsets = {0, i};
+          sliceSizes = {sourceDistType.getShape()[0], 1};
+        }
         Value col = vector::ExtractStridedSliceOp::create(
-            rewriter, reductionOp.getLoc(), source, {0, i},
-            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+            rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes,
+            {1, 1});
+        int64_t col1DSize =
+            sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1];
         col = vector::ShapeCastOp::create(
             rewriter, reductionOp.getLoc(),
-            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
-            col);
+            VectorType::get({col1DSize}, elementType), col);
         Value accCol =
             vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
         Value colReduce = vector::ReductionOp::create(
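
For the 2-D case handled here, the four rows of the case table above collapse to a single predicate: the reduction is lane-local exactly when the distributed dim differs from the reduced dim, which is what isReductionLaneLocal spells out. An illustrative standalone check (not part of the patch):

#include <cassert>
#include <cstdint>

// Lane-local iff the dim each lane owns is not the dim being reduced.
bool isLaneLocal(int64_t sourceDistDim, int64_t reductionDim) {
  return sourceDistDim != reductionDim;
}

int main() {
  assert(!isLaneLocal(0, 0)); // dim-0 distributed, reduce dim 0 -> broadcasted
  assert(isLaneLocal(0, 1));  // dim-0 distributed, reduce dim 1 -> distributed
  assert(isLaneLocal(1, 0));  // dim-1 distributed, reduce dim 0 -> distributed
  assert(!isLaneLocal(1, 1)); // dim-1 distributed, reduce dim 1 -> broadcasted
}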
@@ -957,26 +999,79 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
       return success();
     }
-    // For row reductions, we simply rewrite the MultiReductionOp in terms of
-    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
-    // pattern.
+    // For the non-lane-local case, we simply rewrite the MultiReductionOp in
+    // terms of multiple ReductionOps. Actual distribution is done by the
+    // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    int nRows = sourceType.getShape()[0];
-    // For each row of the source, extract the row vector, do a reduction and,
-    // insert the result back to the result.
-    for (int i = 0; i < nRows; ++i) {
-      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                               reductionOp.getSource(), i);
-      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                            reductionOp.getAcc(), i);
-      Value rowReduce = vector::ReductionOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+    int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0];
+    // For each slice of the source, extract the slice vector, do a reduction
+    // and insert the result back to the result.
+    for (int i = 0; i < nSlices; ++i) {
+      SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+      if (sourceDistDim == 1) {
+        sliceOffsets = {i, 0};
+        sliceSizes = {1, sourceType.getShape()[1]};
+      } else {
+        sliceOffsets = {0, i};
+        sliceSizes = {sourceType.getShape()[0], 1};
+      }
+      Value col = vector::ExtractStridedSliceOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets,
+          sliceSizes, {1, 1});
+      int64_t col1DSize = sourceType.getShape()[sourceDistDim];
+      col = vector::ShapeCastOp::create(
+          rewriter, reductionOp.getLoc(),
+          VectorType::get({col1DSize}, elementType), col);
+      Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getAcc(), i);
+      Value colReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
       result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                        rowReduce, result, i);
+                                        colReduce, result, i);
     }
     // Replace the warp op result with the final result.
     rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    return success();
+  }
+};
 
+struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto shapeCastOp =
+        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    auto resultDistTy =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+    if (!sourceLayout)
+      return rewriter.notifyMatchFailure(
+          warpOp, "the source of shape_cast op lacks distribution layout");
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        shapeCastOp.getSourceVectorType());
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "failed to get distributed vector type for source");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Create a new warp op that yields the source of the shape_cast op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new shape_cast op outside the warp op.
+    Value newShapeCast = vector::ShapeCastOp::create(
+        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newShapeCast);
     return success();
   }
 };
@@ -998,6 +1093,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
       DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
       GpuBarrierDistribution, VectorMultiReductionDistribution>(
       patterns.getContext());
+  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
+                                            /*benefit=*/2);
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -1012,8 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
        continue;
 
-      auto layout =
-          xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
@@ -1074,6 +1170,25 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
+
+  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
+                          vector::CombiningKind kind, uint32_t size) {
+    // First reduce on a single thread to get the per-lane reduction value.
+    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+    // Parallel reduction using butterfly shuffles.
+    for (uint64_t i = 1; i < size; i <<= 1) {
+      Value shuffled =
+          builder
+              .create<gpu::ShuffleOp>(loc, laneVal, i,
+                                      /*width=*/size,
+                                      /*mode=*/gpu::ShuffleMode::XOR)
+              .getShuffleResult();
+      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+    }
+    return laneVal;
+  };
+
+  vector::populateDistributeReduction(patterns, warpReduction);
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
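
The warpReduction lambda added above composes a lane-local vector.reduction with a log2(size) XOR-butterfly, after which every lane holds the full result. Below is a host-side C++ simulation of the butterfly stage, for illustration only; the real code emits gpu.shuffle xor ops on the device:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const uint32_t size = 8;                              // warp width, assumed a power of two
  std::vector<int64_t> lane = {1, 2, 3, 4, 5, 6, 7, 8}; // per-lane partial sums
  for (uint32_t i = 1; i < size; i <<= 1) {
    std::vector<int64_t> shuffled(size);
    for (uint32_t l = 0; l < size; ++l)
      shuffled[l] = lane[l ^ i]; // gpu.shuffle xor: read the partner lane's value
    for (uint32_t l = 0; l < size; ++l)
      lane[l] += shuffled[l];    // makeArithReduction for CombiningKind::ADD
  }
  // After log2(size) = 3 steps, every lane holds the full sum: 36.
  std::cout << lane[0] << " " << lane[size - 1] << "\n";
}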
