@@ -58,6 +58,24 @@ namespace {
5858// ===----------------------------------------------------------------------===//
5959// SIMT Distribution Patterns
6060// ===----------------------------------------------------------------------===//
61+ static SmallVector<int64_t >
62+ computeEffectiveLaneLayout (const xegpu::DistributeLayoutAttr layout) {
63+ SmallVector<int64_t > effectiveLaneLayout;
64+ // If the layout is a slice, we need to get effective lane layout by removing
65+ // sliced dims.
66+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
67+ ArrayRef<int64_t > slicedDims = sliceAttr.flatten ().getDims ().asArrayRef ();
68+ llvm::DenseSet<int64_t > lookUp (slicedDims.begin (), slicedDims.end ());
69+ for (auto [i, dim] :
70+ llvm::enumerate (sliceAttr.getParent ().getLaneLayoutAsInt ())) {
71+ if (!lookUp.contains (i))
72+ effectiveLaneLayout.push_back (dim);
73+ }
74+ } else {
75+ effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt ();
76+ }
77+ return effectiveLaneLayout;
78+ }
6179
6280// / Helper function to get distributed vector type for a source vector type
6381// / according to the lane_layout. We simply divide each dimension of tensor
@@ -79,20 +97,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
7997 return failure ();
8098 assert ((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
8199 " Expecting a valid layout." );
82- SmallVector<int64_t > effectiveLaneLayout;
83- // If the layout is a slice, we need to get effective lane layout by removing
84- // sliced dims.
85- if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
86- ArrayRef<int64_t > slicedDims = sliceAttr.flatten ().getDims ().asArrayRef ();
87- llvm::DenseSet<int64_t > lookUp (slicedDims.begin (), slicedDims.end ());
88- for (auto [i, dim] :
89- llvm::enumerate (sliceAttr.getParent ().getLaneLayoutAsInt ())) {
90- if (!lookUp.contains (i))
91- effectiveLaneLayout.push_back (dim);
92- }
93- } else {
94- effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt ();
95- }
100+ SmallVector<int64_t > effectiveLaneLayout = computeEffectiveLaneLayout (layout);
96101
97102 assert (originalType.getShape ().size () >= effectiveLaneLayout.size () &&
98103 " Rank of the original vector type should be greater or equal to the "
@@ -824,13 +829,64 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
824829 }
825830};
826831
832+ // / Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
833+ // / VectorReductionOps.
834+ static Value lowerToVectorReductions (TypedValue<VectorType> src,
835+ TypedValue<VectorType> acc,
836+ vector::CombiningKind kind,
837+ int64_t reductionDim, Location loc,
838+ PatternRewriter &rewriter) {
839+ // Expecting a 2D source vector.
840+ assert (src.getType ().getRank () == 2 && " expected a 2D source vector" );
841+ VectorType sourceType = src.getType ();
842+ int64_t sourceH = sourceType.getShape ()[0 ];
843+ int64_t sourceW = sourceType.getShape ()[1 ];
844+ int nSlices = (reductionDim == 0 ) ? sourceW : sourceH;
845+ // Create a constant vector to hold the result of the reduction.
846+ TypedAttr zeroAttr = rewriter.getZeroAttr (sourceType.getElementType ());
847+ Value reductionResult = arith::ConstantOp::create (
848+ rewriter, loc, acc.getType (),
849+ DenseElementsAttr::get (acc.getType (), zeroAttr));
850+ // For each slice of the source, extract the slice vector, do a reduction
851+ // and, insert the reduced value back to the result vector.
852+ for (int i = 0 ; i < nSlices; ++i) {
853+ SmallVector<int64_t , 2 > sliceOffsets, sliceSizes;
854+ if (reductionDim == 1 ) {
855+ sliceOffsets = {i, 0 };
856+ sliceSizes = {1 , sourceW};
857+ } else {
858+ sliceOffsets = {0 , i};
859+ sliceSizes = {sourceH, 1 };
860+ }
861+ vector::ExtractStridedSliceOp extractOp =
862+ vector::ExtractStridedSliceOp::create (rewriter, loc, src, sliceOffsets,
863+ sliceSizes, {1 , 1 });
864+ int64_t nSliceElements = extractOp.getResult ().getType ().getNumElements ();
865+ Value slice = vector::ShapeCastOp::create (
866+ rewriter, loc,
867+ VectorType::get ({nSliceElements}, sourceType.getElementType ()),
868+ extractOp.getResult ());
869+ Value accExtract = vector::ExtractOp::create (rewriter, loc, acc, i);
870+ Value reduction =
871+ vector::ReductionOp::create (rewriter, loc, kind, slice, accExtract);
872+ reductionResult =
873+ vector::InsertOp::create (rewriter, loc, reduction, reductionResult, i);
874+ }
875+ return reductionResult;
876+ }
877+
827878// / This pattern distributes the `vector.multi_reduction` operation across
828- // / lanes in a warp. Currently only 2D to 1D reductions are supported and
829- // / assumes that source vector is distributed in column dimension (i.e. Each
830- // / lane owns complete column(s) of the source vector).
831- // / TODO: Add support for the case where source rows are distributed across
832- // / lanes. Requires `DistributionMapFn` to express the data distribution.
833- // / Example 1 (Col reduction):
879+ // / lanes in a warp. Currently only 2D to 1D reductions are supported. Given
880+ // / layouts for the source and accumulator vectors,
881+ // / * If the reduction dimension is distributed across lanes, the reduction is
882+ // / non-lane-local and the reduction is done using warp shuffles. Here we
883+ // / simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
884+ // / the warp op body.
885+ // / * If the reduction dimension is not distributed across lanes, the reduction
886+ // / is lane-local. In this case, we yield the source and accumulator vectors
887+ // / from the warp op and perform the lane-local reduction outside the warp op
888+ // / using a sequence of ReductionOps.
889+ // / Example 1 (Reduction is lane-local):
834890// / ```
835891// / %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
836892// / %0 = "some_def"() : () -> (vector<16x32xf32>)
@@ -852,7 +908,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
852908// / %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
853909// / %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
854910// / ```
855- // / Example 2 (Row reduction ):
911+ // / Example 2 (Reduction is non-lane-local ):
856912// / ```
857913// / %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
858914// / %0 = "some_def"() : () -> (vector<2x32xf32>)
@@ -900,7 +956,6 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
900956 VectorType distributedResultType =
901957 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
902958 VectorType resultType = cast<VectorType>(reductionOp.getType ());
903- Type elementType = distributedResultType.getElementType ();
904959 xegpu::DistributeLayoutAttr sourceLayout =
905960 xegpu::getDistributeLayoutAttr (reductionOp.getSource ());
906961
@@ -948,87 +1003,31 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
9481003 warpOp,
9491004 " Expecting a broadcasted result for non-lane-local reduction." );
9501005
951- // Create a constant vector to store the result of the reduction per lane.
952- rewriter.setInsertionPoint (warpOp);
953- TypedAttr zeroAttr =
954- rewriter.getZeroAttr (distributedResultType.getElementType ());
955- Value result = arith::ConstantOp::create (
956- rewriter, reductionOp->getLoc (), distributedResultType,
957- DenseElementsAttr::get (distributedResultType, zeroAttr));
958-
9591006 // Handle lane-local reduction case. In this case we fully distribute the
960- // reduction.
1007+ // reduction result .
9611008 if (isReductionLaneLocal) {
9621009 // Yield the source and acc vectors from the WarpOp.
9631010 SmallVector<size_t > newRetIndices;
9641011 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
9651012 rewriter, warpOp, {reductionOp.getSource (), reductionOp.getAcc ()},
9661013 {sourceDistType, distributedResultType}, newRetIndices);
9671014 rewriter.setInsertionPointAfter (newWarpOp);
968-
969- int nSlices = sourceDistType.getShape ()[sourceDistDim];
970- Value source = newWarpOp.getResult (newRetIndices[0 ]);
971- Value acc = newWarpOp.getResult (newRetIndices[1 ]);
972- // For each slice owned by a lane, extract the slice, shape cast to 1D, do
973- // a vector.reduction and, insert the result back to the result vector.
974- for (int i = 0 ; i < nSlices; ++i) {
975- SmallVector<int64_t , 2 > sliceOffsets, sliceSizes;
976- if (sourceDistDim == 0 ) {
977- sliceOffsets = {i, 0 };
978- sliceSizes = {1 , sourceDistType.getShape ()[1 ]};
979- } else {
980- sliceOffsets = {0 , i};
981- sliceSizes = {sourceDistType.getShape ()[0 ], 1 };
982- }
983- Value col = vector::ExtractStridedSliceOp::create (
984- rewriter, reductionOp.getLoc (), source, sliceOffsets, sliceSizes,
985- {1 , 1 });
986- int64_t col1DSize =
987- sourceDistType.getShape ()[sourceDistDim == 1 ? 0 : 1 ];
988- col = vector::ShapeCastOp::create (
989- rewriter, reductionOp.getLoc (),
990- VectorType::get ({col1DSize}, elementType), col);
991- Value accCol =
992- vector::ExtractOp::create (rewriter, reductionOp.getLoc (), acc, i);
993- Value colReduce = vector::ReductionOp::create (
994- rewriter, reductionOp.getLoc (), reductionOp.getKind (), col, accCol);
995- result = vector::InsertOp::create (rewriter, reductionOp.getLoc (),
996- colReduce, result, i);
997- }
998- // Replace the warp op result with the new reduction op.
999- rewriter.replaceAllUsesWith (newWarpOp.getResult (operandNumber), result);
1015+ Value result = lowerToVectorReductions (
1016+ cast<TypedValue<VectorType>>(newWarpOp->getResult (newRetIndices[0 ])),
1017+ cast<TypedValue<VectorType>>(newWarpOp->getResult (newRetIndices[1 ])),
1018+ reductionOp.getKind (), reductionDim, reductionOp.getLoc (), rewriter);
1019+ // Replace the warp op result with the final result.
1020+ rewriter.replaceAllUsesWith (reductionOp.getResult (), result);
10001021 return success ();
10011022 }
10021023 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
10031024 // of multiple ReductionOps. Actual distribution is done by the
10041025 // WarpOpReduction pattern.
10051026 rewriter.setInsertionPointAfter (reductionOp);
1006- int nSlices = sourceType.getShape ()[sourceDistDim == 0 ? 1 : 0 ];
1007- // For each slice of the source, extract the slice vector, do a reduction
1008- // and, insert the result back to the result.
1009- for (int i = 0 ; i < nSlices; ++i) {
1010- SmallVector<int64_t , 2 > sliceOffsets, sliceSizes;
1011- if (sourceDistDim == 1 ) {
1012- sliceOffsets = {i, 0 };
1013- sliceSizes = {1 , sourceType.getShape ()[1 ]};
1014- } else {
1015- sliceOffsets = {0 , i};
1016- sliceSizes = {sourceType.getShape ()[0 ], 1 };
1017- }
1018- Value col = vector::ExtractStridedSliceOp::create (
1019- rewriter, reductionOp.getLoc (), reductionOp.getSource (), sliceOffsets,
1020- sliceSizes, {1 , 1 });
1021- int64_t col1DSize = sourceType.getShape ()[sourceDistDim];
1022- col = vector::ShapeCastOp::create (
1023- rewriter, reductionOp.getLoc (),
1024- VectorType::get ({col1DSize}, elementType), col);
1025- Value accCol = vector::ExtractOp::create (rewriter, reductionOp.getLoc (),
1026- reductionOp.getAcc (), i);
1027- Value colReduce = vector::ReductionOp::create (
1028- rewriter, reductionOp.getLoc (), reductionOp.getKind (), col, accCol);
1029- result = vector::InsertOp::create (rewriter, reductionOp.getLoc (),
1030- colReduce, result, i);
1031- }
1027+ Value result = lowerToVectorReductions (
1028+ cast<TypedValue<VectorType>>(reductionOp.getSource ()),
1029+ cast<TypedValue<VectorType>>(reductionOp.getAcc ()),
1030+ reductionOp.getKind (), reductionDim, reductionOp.getLoc (), rewriter);
10321031 // Replace the warp op result with the final result.
10331032 rewriter.replaceAllUsesWith (reductionOp.getResult (), result);
10341033 return success ();
@@ -1082,6 +1081,11 @@ namespace {
/// Pass that wires up SIMT distribution for XeGPU: implementation of
/// runOnOperation() lives below.
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  // NOTE(review): the explicitly-defaulted copy constructor and the
  // options-taking constructor appear to exist so the pass can be cloned by
  // the pass infrastructure and constructed with pass options — confirm
  // against the generated XeGPUSubgroupDistributeBase/Options declarations.
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};
10871091} // namespace
@@ -1150,16 +1154,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
11501154 if (vecRank == 0 )
11511155 return AffineMap::get (val.getContext ());
11521156 // Get the layout of the vector type.
1153- // TODO: support more layout types
1154- auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
1157+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr (val);
11551158 // If no layout is specified, assume the inner most dimension is distributed
11561159 // for now.
11571160 if (!layout)
11581161 return AffineMap::getMultiDimMapWithTargets (
11591162 vecRank, {static_cast <unsigned int >(vecRank - 1 )}, val.getContext ());
11601163 SmallVector<unsigned int > distributedDims;
11611164 // Get the distributed dimensions based on the layout.
1162- ArrayRef< int > laneLayout = layout. getLaneLayout (). asArrayRef ( );
1165+ SmallVector< int64_t > laneLayout = computeEffectiveLaneLayout (layout );
11631166 for (unsigned i = 0 ; i < laneLayout.size (); ++i) {
11641167 if (laneLayout[i] > 1 )
11651168 distributedDims.push_back (i);
@@ -1188,7 +1191,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
11881191 return laneVal;
11891192 };
11901193
1191- vector::populateDistributeReduction (patterns, warpReduction);
1194+ if (enableSGReductions)
1195+ vector::populateDistributeReduction (patterns, warpReduction);
1196+
11921197 vector::populatePropagateWarpVectorDistributionPatterns (
11931198 patterns, distributionFn, shuffleFn);
11941199 if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
0 commit comments