Commit c3e5986

Add pattern for reduction
1 parent 58bf9ac commit c3e5986

2 files changed (+219, -5 lines)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 206 additions & 5 deletions
@@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
-// This pattern distributes arith.constant op into subgroup-level constants
 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
 
@@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
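For a splat workgroup constant that carries a slice layout, such as the one in the new test below, this branch creates the per-subgroup constant and keeps the propagated layout only while the slice still describes a subgroup distribution. A rough before/after sketch (hand-written for illustration, not the pass's verbatim output):

  // Workgroup level: vector<128xf32>, sliced from sg_layout = [8, 4], sg_data = [32, 32] along dim 0.
  %cst = arith.constant
    {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>}
    dense<1.0> : vector<128xf32>

  // Subgroup level: 128 elements over the 4 surviving subgroup columns -> 32 per subgroup;
  // the sg_layout/sg_data fields are dropped from the layout that is propagated onward.
  %cst_sg = arith.constant dense<1.0> : vector<32xf32>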
@@ -815,6 +821,191 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// Pattern to distribute vector.multi_dim_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Only support reductions with a layout and a single reduction dimension for now.
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
+    VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
+    Type elemTy = srcType.getElementType();
+    if (!srcType || !accType || !resType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resType.getShape();
+    // Handle both LayoutAttr and SliceAttr for the op result.
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (!sliceAttr || sliceAttr.getRank() != 1)
+      return failure();
+
+    SmallVector<int64_t> dims =
+        llvm::to_vector(sliceAttr.getDims().asArrayRef());
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, sliceAttr).first;
+
+    int64_t reduceDim = dims[0];
+
+    // Step 1: Subgroup-level reduction.
+    // Each subgroup reduces its local tile.
+    SmallVector<Value> newReductions;
+    VectorType newType = VectorType::get(sgShape, srcType.getElementType());
+    SmallVector<int64_t> shapeCastShape = sgShape;
+    if (reduceDim == 0)
+      shapeCastShape.insert(shapeCastShape.begin(), 1);
+    else
+      shapeCastShape.push_back(1);
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto sgReduce = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      // Compute the shape for the shape cast: set reduceDim to 1, keep the
+      // other dims as sgShape.
+      auto shapeCastTy =
+          VectorType::get(shapeCastShape, srcType.getElementType());
+      auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), shapeCastTy, sgReduce.getResult());
+      // TODO: Change it to shapeCast
+      newReductions.push_back(shapeCast.getResult());
+    }
+
+    rewriter.setInsertionPoint(op);
+
+    // Get the subgroup layout of the parent (source) layout.
+    SmallVector<int64_t> sgLayoutParent =
+        sliceAttr.getParent().getSgLayoutAsInt();
+
+    // Allocate SLM.
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto flattenFactor = bitWidth / 8;
+    auto slmSize =
+        resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor;
+    auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
+    auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
+
+    // Create a view for the SLM buffer using xegpu.create_mem_desc.
+    SmallVector<int64_t> viewShape;
+    auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    ArrayRef<int64_t> srcShape =
+        srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        // For the reduced dimension, use sgLayoutParent[i].
+        viewShape.push_back(sgLayoutParent[i]);
+      } else {
+        // For other dimensions, multiply sgLayoutParent[i] by srcShape[i].
+        viewShape.push_back(sgLayoutParent[i] * srcShape[i]);
+      }
+    }
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
+
+    // Step 2: Store subgroup results to SLM (shared local memory).
+    // SLM layout: sgLayout same as srcLayout, sgData is shapeCastShape.
+    SmallVector<int64_t> slmSgData = shapeCastShape;
+
+    // Get the subgroup id and delinearize it.
+    auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
+                                                   nullptr);
+
+    SmallVector<Value> srcSgLayoutDim(sgLayoutParent.size());
+
+    for (size_t i = 0; i < sgLayoutParent.size(); i++) {
+      srcSgLayoutDim[i] =
+          arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]);
+    }
+
+    auto sgIdVec =
+        affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim);
+    if (failed(sgIdVec))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdVec;
+
+    // Calculate offsets for store_matrix.
+    SmallVector<OpFoldResult> slmStoreOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
+      slmStoreOffsets.push_back(offset);
+    }
+
+    // Store the subgroup result to SLM.
+    rewriter.create<xegpu::StoreMatrixOp>(
+        loc, newReductions[0], memDesc.getResult(),
+        ArrayRef<OpFoldResult>(slmStoreOffsets),
+        /*layout=*/nullptr);
+
+    // Barrier to synchronize subgroups.
+    rewriter.create<gpu::BarrierOp>(loc);
+
+    // Step 3: Load from SLM for the second reduction.
+    SmallVector<int64_t> slmLoadShape;
+
+    for (size_t i = 0; i < viewShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        slmLoadShape.push_back(viewShape[i]);
+      } else {
+        int64_t divisor = computeProduct(sgLayoutParent);
+        slmLoadShape.push_back(viewShape[i] / divisor);
+      }
+    }
+
+    // Calculate offsets for load_matrix.
+    SmallVector<OpFoldResult> slmLoadOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
+      slmLoadOffsets.push_back(offset);
+    }
+
+    auto load = rewriter.create<xegpu::LoadMatrixOp>(
+        loc, VectorType::get(slmLoadShape, elemTy), memDesc,
+        llvm::ArrayRef<OpFoldResult>({slmLoadOffsets}),
+        /*layout=*/nullptr);
+
+    // Step 4: Create a constant accumulator for the second reduction,
+    // with the same splat value as adaptor.getAcc()[0] and its shape set to
+    // the non-reduced dimensions of the SLM load.
+    auto accShape = load.getType().getShape();
+    SmallVector<int64_t> accShapeWithoutReduceDim;
+    for (size_t i = 0; i < accShape.size(); ++i) {
+      if (static_cast<int64_t>(i) != reduceDim)
+        accShapeWithoutReduceDim.push_back(accShape[i]);
+    }
+    auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy);
+    auto accConstOp = adaptor.getAcc()[0].getDefiningOp<arith::ConstantOp>();
+    Attribute accSplatValue;
+    if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(
+            accConstOp ? accConstOp.getValue() : nullptr)) {
+      accSplatValue =
+          denseAttr.isSplat() ? denseAttr.getSplatValue<Attribute>() : nullptr;
+    }
+    if (!accSplatValue)
+      return failure();
+    auto accValue = rewriter.create<arith::ConstantOp>(
+        loc, accTy, DenseElementsAttr::get(accTy, accSplatValue));
+    // Step 5: Perform the second reduction.
+    VectorType secondReduceVecType =
+        VectorType::get(accShapeWithoutReduceDim, srcType.getElementType());
+    auto secondReduce = rewriter.create<vector::MultiDimReductionOp>(
+        loc, secondReduceVecType, op.getKind(), load, accValue,
+        op.getReductionDims());
+    rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
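Working the shapes through for the @vector_reduce test case added below (sg_layout = [8, 4], sg_data = [32, 32], reduction along dim 0): each of the 32 subgroups first reduces its 32x32 tile to vector<32xf32> and shape-casts it to vector<1x32xf32>; the partials land in a 128 * 8 * 4 B = 4096-byte SLM buffer viewed as an 8x128 mem_desc; after the barrier each subgroup loads an 8x4 tile of partials and reduces it again to vector<4xf32>, so 32 subgroups times 4 elements cover the 128-element workgroup result. A hand-written per-subgroup sketch of that sequence (SSA names, offsets, and exact op syntax are illustrative, not the pass's verbatim output):

  // Stage 1: reduce this subgroup's 32x32 tile and make it a 1x32 row of partials.
  %part = vector.multi_reduction <add>, %sg_src, %sg_acc [0] : vector<32x32xf32> to vector<32xf32>
  %row  = vector.shape_cast %part : vector<32xf32> to vector<1x32xf32>
  // Shared local memory holding all partial results, viewed as an 8x128 matrix.
  %slm   = memref.alloca() : memref<4096xi8, 3>
  %mdesc = xegpu.create_mem_desc %slm : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
  xegpu.store_matrix %row, %mdesc[%r0, %c0] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>
  gpu.barrier
  // Stage 2: reload an 8x4 tile and reduce across the 8 partial contributions.
  %tile = xegpu.load_matrix %mdesc[%r1, %c1] : !xegpu.mem_desc<8x128xf32> -> vector<8x4xf32>
  %acc2 = arith.constant dense<1.0> : vector<4xf32>
  %red  = vector.multi_reduction <add>, %tile, %acc2 [0] : vector<8x4xf32> to vector<4xf32>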
@@ -826,8 +1017,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
       WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-      patterns.getContext());
+      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+      WgToSgMultiDimReductionOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -987,6 +1178,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        // Only allow MultiDimReductionOp with a single reduction dimension.
+        if (op.getReductionDims().size() != 1)
+          return true;
+
+        // Check if the layout is legal.
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
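With this legality rule, a reduction over more than one dimension is reported legal and therefore left untouched by the conversion; only single-dimension reductions with an acceptable layout are rewritten. An illustrative op (not taken from the test suite) that would pass through unchanged:

  // Two reduction dims, so this op is not rewritten by the pass.
  %full = vector.multi_reduction <add>, %v, %acc [0, 1] : vector<256x128xf32> to f32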

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 13 additions & 0 deletions
@@ -321,4 +321,17 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  //CHECK-LABEL: vector_reduce
+  gpu.func @vector_reduce(%src: memref<256x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+      : vector<256x128xf32> to vector<128xf32>
+    gpu.return
+  }
 }
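The new @vector_reduce function only asserts its label for now (there are no CHECK lines for the expanded body yet) and relies on the RUN line already at the top of xegpu-wg-to-sg-unify-ops.mlir, which is not part of this diff but presumably resembles:

// RUN: mlir-opt --xegpu-wg-to-sg-distribute %s | FileCheck %s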
