Commit 8728eee

save work

1 parent 3c06f28 commit 8728eee
File tree

3 files changed (+215, -93 lines)

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,10 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
+  let options = [Option<
+      "enableSGReductions", "enable-sg-reductions", "bool",
+      /*default=*/"true",
+      "Enable subgroup reductions using subgroup shuffles.">];
 }

 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
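Usage note (not part of the diff): the new option should surface through MLIR's standard pass-option syntax, so subgroup reductions can presumably be toggled from the command line along the lines of

    mlir-opt --xegpu-subgroup-distribute=enable-sg-reductions=false input.mlir

with the default (true) keeping the shuffle-based reduction path enabled.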

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 98 additions & 93 deletions
@@ -58,6 +58,24 @@ namespace {
 //===----------------------------------------------------------------------===//
 // SIMT Distribution Patterns
 //===----------------------------------------------------------------------===//
+static SmallVector<int64_t>
+computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) {
+  SmallVector<int64_t> effectiveLaneLayout;
+  // If the layout is a slice, we need to get effective lane layout by removing
+  // sliced dims.
+  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
+    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
+    for (auto [i, dim] :
+         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
+      if (!lookUp.contains(i))
+        effectiveLaneLayout.push_back(dim);
+    }
+  } else {
+    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
+  }
+  return effectiveLaneLayout;
+}

 /// Helper function to get distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
@@ -79,20 +97,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
     return failure();
   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
          "Expecting a valid layout.");
-  SmallVector<int64_t> effectiveLaneLayout;
-  // If the layout is a slice, we need to get effective lane layout by removing
-  // sliced dims.
-  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
-    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
-    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
-    for (auto [i, dim] :
-         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
-      if (!lookUp.contains(i))
-        effectiveLaneLayout.push_back(dim);
-    }
-  } else {
-    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
-  }
+  SmallVector<int64_t> effectiveLaneLayout = computeEffectiveLaneLayout(layout);

   assert(originalType.getShape().size() >= effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
@@ -824,13 +829,64 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };

+/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
+/// VectorReductionOps.
+static Value lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Expecting a 2D source vector.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Create a constant vector to hold the result of the reduction.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // For each slice of the source, extract the slice vector, do a reduction
+  // and insert the reduced value back into the result vector.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) {
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else {
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+    Value slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction =
+        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+  }
+  return reductionResult;
+}
+
 /// This pattern distributes the `vector.multi_reduction` operation across
-/// lanes in a warp. Currently only 2D to 1D reductions are supported and
-/// assumes that source vector is distributed in column dimension (i.e. Each
-/// lane owns complete column(s) of the source vector).
-/// TODO: Add support for the case where source rows are distributed across
-/// lanes. Requires `DistributionMapFn` to express the data distribution.
-/// Example 1 (Col reduction):
+/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
+/// layouts for the source and accumulator vectors,
+/// * If the reduction dimension is distributed across lanes, the reduction is
+///   non-lane-local and is done using warp shuffles. Here we simply rewrite
+///   the MultiDimReductionOp to a sequence of ReductionOps in the warp op
+///   body.
+/// * If the reduction dimension is not distributed across lanes, the reduction
+///   is lane-local. In this case, we yield the source and accumulator vectors
+///   from the warp op and perform the lane-local reduction outside the warp op
+///   using a sequence of ReductionOps.
+/// Example 1 (Reduction is lane-local):
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
 ///   %0 = "some_def"() : () -> (vector<16x32xf32>)
@@ -852,7 +908,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
 ///   %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
 ///   %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
 /// ```
-/// Example 2 (Row reduction):
+/// Example 2 (Reduction is non-lane-local):
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
 ///   %0 = "some_def"() : () -> (vector<2x32xf32>)
@@ -900,7 +956,6 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
     VectorType distributedResultType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     VectorType resultType = cast<VectorType>(reductionOp.getType());
-    Type elementType = distributedResultType.getElementType();
     xegpu::DistributeLayoutAttr sourceLayout =
         xegpu::getDistributeLayoutAttr(reductionOp.getSource());

@@ -948,87 +1003,31 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         warpOp,
         "Expecting a broadcasted result for non-lane-local reduction.");

-    // Create a constant vector to store the result of the reduction per lane.
-    rewriter.setInsertionPoint(warpOp);
-    TypedAttr zeroAttr =
-        rewriter.getZeroAttr(distributedResultType.getElementType());
-    Value result = arith::ConstantOp::create(
-        rewriter, reductionOp->getLoc(), distributedResultType,
-        DenseElementsAttr::get(distributedResultType, zeroAttr));
-
     // Handle lane-local reduction case. In this case we fully distribute the
-    // reduction.
+    // reduction result.
     if (isReductionLaneLocal) {
       // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-
-      int nSlices = sourceDistType.getShape()[sourceDistDim];
-      Value source = newWarpOp.getResult(newRetIndices[0]);
-      Value acc = newWarpOp.getResult(newRetIndices[1]);
-      // For each slice owned by a lane, extract the slice, shape cast to 1D, do
-      // a vector.reduction and, insert the result back to the result vector.
-      for (int i = 0; i < nSlices; ++i) {
-        SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-        if (sourceDistDim == 0) {
-          sliceOffsets = {i, 0};
-          sliceSizes = {1, sourceDistType.getShape()[1]};
-        } else {
-          sliceOffsets = {0, i};
-          sliceSizes = {sourceDistType.getShape()[0], 1};
-        }
-        Value col = vector::ExtractStridedSliceOp::create(
-            rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes,
-            {1, 1});
-        int64_t col1DSize =
-            sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1];
-        col = vector::ShapeCastOp::create(
-            rewriter, reductionOp.getLoc(),
-            VectorType::get({col1DSize}, elementType), col);
-        Value accCol =
-            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
-        Value colReduce = vector::ReductionOp::create(
-            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                          colReduce, result, i);
-      }
-      // Replace the warp op result with the new reduction op.
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      Value result = lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
+          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+      // Replace the warp op result with the final result.
+      rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
       return success();
     }
     // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
     // of multiple ReductionOps. Actual distribution is done by the
     // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0];
-    // For each slice of the source, extract the slice vector, do a reduction
-    // and, insert the result back to the result.
-    for (int i = 0; i < nSlices; ++i) {
-      SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-      if (sourceDistDim == 1) {
-        sliceOffsets = {i, 0};
-        sliceSizes = {1, sourceType.getShape()[1]};
-      } else {
-        sliceOffsets = {0, i};
-        sliceSizes = {sourceType.getShape()[0], 1};
-      }
-      Value col = vector::ExtractStridedSliceOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets,
-          sliceSizes, {1, 1});
-      int64_t col1DSize = sourceType.getShape()[sourceDistDim];
-      col = vector::ShapeCastOp::create(
-          rewriter, reductionOp.getLoc(),
-          VectorType::get({col1DSize}, elementType), col);
-      Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                               reductionOp.getAcc(), i);
-      Value colReduce = vector::ReductionOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                        colReduce, result, i);
-    }
+    Value result = lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(reductionOp.getSource()),
+        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
+        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
     // Replace the warp op result with the final result.
     rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
     return success();
@@ -1082,6 +1081,11 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
+  XeGPUSubgroupDistributePass() = default;
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
+      default;
+  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
+      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
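The added constructors make the pass constructible with programmatic options. A hedged sketch of C++ usage (the options struct name is taken from the diff, but the create-function name is inferred from MLIR's generated-pass conventions and may differ):

    #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    void addXeGPUDistribution(mlir::PassManager &pm) {
      // Assumed generated API; equivalent to enable-sg-reductions=false on
      // the command line.
      mlir::xegpu::XeGPUSubgroupDistributeOptions options;
      options.enableSGReductions = false;
      pm.addPass(mlir::xegpu::createXeGPUSubgroupDistribute(options));
    }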
@@ -1150,16 +1154,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
-    // TODO: support more layout types
-    auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
     // If no layout is specified, assume the inner most dimension is distributed
     // for now.
     if (!layout)
       return AffineMap::getMultiDimMapWithTargets(
           vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
     SmallVector<unsigned int> distributedDims;
     // Get the distributed dimensions based on the layout.
-    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    SmallVector<int64_t> laneLayout = computeEffectiveLaneLayout(layout);
     for (unsigned i = 0; i < laneLayout.size(); ++i) {
       if (laneLayout[i] > 1)
         distributedDims.push_back(i);
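With this change the distribution map also accounts for slice layouts, not just plain LayoutAttr. For instance, assuming an effective lane layout of [1, 16] on a rank-2 vector, only dim 1 has more than one lane, so distributedDims = {1} and the returned map is (d0, d1) -> (d1).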
@@ -1188,7 +1191,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return laneVal;
   };

-  vector::populateDistributeReduction(patterns, warpReduction);
+  if (enableSGReductions)
+    vector::populateDistributeReduction(patterns, warpReduction);
+
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
