3434#include " llvm/ADT/ArrayRef.h"
3535#include " llvm/ADT/STLExtras.h"
3636#include " llvm/ADT/SmallVector.h"
37+ #include " llvm/Support/LogicalResult.h"
3738
3839namespace mlir {
3940namespace xegpu {
@@ -72,27 +73,43 @@ namespace {
7273// / | 32x16 | [2, 8] | 16x2 |
7374// / | 2x32x16 | [1, 16] | 2x32x1 |
7475static FailureOr<VectorType>
75- getDistVecTypeBasedOnLaneLayout (xegpu::LayoutAttr layout,
76+ getDistVecTypeBasedOnLaneLayout (xegpu::DistributeLayoutAttr layout,
7677 VectorType originalType) {
// Computes the per-lane (distributed) vector type for `originalType` from the
// lane layout carried by `layout`. Yields failure() when `layout` is null or
// when a distributed dimension is not evenly divisible by its lane count.
// NOTE(review): this text is a diff rendering. Lines marked `-` appear to be
// the pre-change version and lines marked `+` their replacement; confirm
// against the actual source file before relying on either variant.
7778 if (!layout)
7879 return failure ();
80+ assert ((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
81+ " Expecting a valid layout." );
82+ SmallVector<int64_t > effectiveLaneLayout;
83+ // If the layout is a slice, we need to get effective lane layout by removing
84+ // sliced dims.
85+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
86+ ArrayRef<int64_t > slicedDims = sliceAttr.flatten ().getDims ().asArrayRef ();
// Set of sliced dimension indices, used below to filter the parent's lane
// layout entries by position.
87+ llvm::DenseSet<int64_t > lookUp (slicedDims.begin (), slicedDims.end ());
88+ for (auto [i, dim] :
89+ llvm::enumerate (sliceAttr.getParent ().getLaneLayoutAsInt ())) {
// Keep only the lane-layout entries whose index is not a sliced dimension.
90+ if (!lookUp.contains (i))
91+ effectiveLaneLayout.push_back (dim);
92+ }
93+ } else {
// Plain (non-slice) layout: its lane layout is used unchanged.
94+ effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt ();
95+ }
7996
80- auto laneLayout = layout.getLaneLayout ().asArrayRef ();
81- assert (originalType.getShape ().size () >= laneLayout.size () &&
97+ assert (originalType.getShape ().size () >= effectiveLaneLayout.size () &&
8298 " Rank of the original vector type should be greater or equal to the "
8399 " size of the lane layout to distribute the vector type." );
// Start from the full original shape; only the trailing dimensions are
// overwritten in the loop below.
84100 SmallVector<int64_t > distributedShape (originalType.getShape ());
85101 // Only distribute the last `effectiveLaneLayout.size()` dimensions. The
86102 // remaining dimensions are not distributed.
// NOTE(review): the result of the subtraction is stored in an `unsigned`; it
// relies on the assert above holding. Confirm the behaviour in builds where
// asserts are compiled out and the layout rank exceeds the vector rank.
87- unsigned distributionStart = originalType.getRank () - laneLayout.size ();
103+ unsigned distributionStart =
104+ originalType.getRank () - effectiveLaneLayout.size ();
88105 for (auto [i, dim] : llvm::enumerate (originalType.getShape ())) {
// Leading dimensions (before distributionStart) keep their original size.
89106 if (i < distributionStart)
90107 continue ;
91108
92109 // Check if the dimension can be distributed evenly.
93- if (dim % laneLayout [i - distributionStart] != 0 )
110+ if (dim % effectiveLaneLayout [i - distributionStart] != 0 )
94111 return failure ();
// Per-lane extent of this dimension: original size divided by lane count.
95- distributedShape[i] = dim / laneLayout [i - distributionStart];
112+ distributedShape[i] = dim / effectiveLaneLayout [i - distributionStart];
96113 }
// The element type is preserved; only the shape changes.
97114 return VectorType::get (distributedShape, originalType.getElementType ());
98115}
@@ -858,7 +875,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
858875// / gpu.yield %1 : vector<2xf32>
859876// / }
860877struct VectorMultiReductionDistribution : public gpu ::WarpDistributionPattern {
861- using Base::Base ;
878+ using gpu::WarpDistributionPattern::WarpDistributionPattern ;
862879 LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op warpOp,
863880 PatternRewriter &rewriter) const override {
864881 OpOperand *yieldOperand =
@@ -869,83 +886,108 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
869886 cast<vector::MultiDimReductionOp>(yieldOperand->get ().getDefiningOp ());
870887 unsigned operandNumber = yieldOperand->getOperandNumber ();
871888 VectorType sourceType = reductionOp.getSourceVectorType ();
872-
873889 // Only 2D vectors are supported.
874890 if (sourceType.getRank () != 2 )
875891 return rewriter.notifyMatchFailure (warpOp,
876892 " Only 2D reductions are supported." );
877893 ArrayRef<int64_t > reductionDims = reductionOp.getReductionDims ();
878- // Only 1 reduction dimension supported. This also ensures that result is
879- // also vector type.
894+ // Only 1 reduction dimension supported. This also ensures that the result
895+ // is vector type.
880896 if (reductionDims.size () != 1 )
881897 return rewriter.notifyMatchFailure (
882898 warpOp, " Only 1 reduction dimension is supported." );
883899 int64_t reductionDim = reductionDims[0 ];
884- auto resultType = cast<VectorType>(reductionOp.getType ());
885- auto distributedResultType =
900+ VectorType distributedResultType =
886901 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
902+ VectorType resultType = cast<VectorType>(reductionOp.getType ());
887903 Type elementType = distributedResultType.getElementType ();
904+ xegpu::DistributeLayoutAttr sourceLayout =
905+ xegpu::getDistributeLayoutAttr (reductionOp.getSource ());
888906
889- // Currently we make the following assumptions.
890- // 1. The source vector is distributed in the column dimension. Each lane
891- // owns complete column(s) of the source vector.
892- // 2. If the reduction dim == 0, its a lane-local col reduction. In this
893- // case each lane owns its portion of the result (i.e. result is also
894- // distributed).
895- // 3. If reduction dim == 1, its a row reduction that require cross lanes
896- // shuffles. In this case, the reduction result is not distributed across
897- // lanes. Instead each lane owns a complete copy of the result
898- // (broadcasted).
899- // TODO: These assumptions are fairly restrictive. For example, source
900- // vector can have row distributed layout. Improve support for such cases.
901- if (sourceType.getShape ()[1 ] % warpOp.getWarpSize () != 0 )
907+ FailureOr<VectorType> sourceDistTypeOrFailure =
908+ getDistVecTypeBasedOnLaneLayout (sourceLayout, sourceType);
909+ if (failed (sourceDistTypeOrFailure))
902910 return rewriter.notifyMatchFailure (
903- warpOp, " Source vector dimension must be divisible by warp size." );
904- bool isResultDistributed =
911+ warpOp, " Failed to distribute the source vector type." );
912+ VectorType sourceDistType = sourceDistTypeOrFailure.value ();
913+ // Only single dimension distribution is supported.
914+ bool dim0Distributed =
915+ sourceDistType.getShape ()[0 ] != sourceType.getShape ()[0 ];
916+ bool dim1Distributed =
917+ sourceDistType.getShape ()[1 ] != sourceType.getShape ()[1 ];
918+ if (dim0Distributed && dim1Distributed)
919+ return rewriter.notifyMatchFailure (
920+ warpOp, " Expecting source to be distributed in a single dimension." );
921+ int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1 );
922+ if (sourceDistDim == -1 )
923+ return rewriter.notifyMatchFailure (
924+ warpOp, " Expecting a distributed source vector." );
925+ bool resultDistributed =
905926 distributedResultType.getNumElements () < resultType.getNumElements ();
906- if (reductionDim == 0 && !isResultDistributed)
927+ // If the lane owns all the data required for reduction (i.e. reduction is
928+ // fully parallel across lanes), then each lane owns part of the result
929+ // (i.e. result is distributed). If the reduction requires cross-lane
930+ // shuffling, then the result is shared among all lanes (broadcasted).
931+ // Therefore we expect following cases:
932+ //
933+ // | Source vector | Reduction dim | Result vector |
934+ // |----------------------|----------------|----------------|
935+ // | dim-0 distributed | 0 | broadcasted |
936+ // | dim-0 distributed | 1 | distributed |
937+ // | dim-1 distributed | 0 | distributed |
938+ // | dim-1 distributed | 1 | broadcasted |
939+
940+ bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1 ) ||
941+ (sourceDistDim == 1 && reductionDim == 0 );
942+ if (isReductionLaneLocal && !resultDistributed)
907943 return rewriter.notifyMatchFailure (
908- warpOp,
909- " Expecting result vector to be distributed in a col reduction. " );
910- if (reductionDim == 1 && isResultDistributed )
944+ warpOp, " Expecting a distributed result for lane-local reduction. " );
945+
946+ if (!isReductionLaneLocal && resultDistributed )
911947 return rewriter.notifyMatchFailure (
912948 warpOp,
913- " Expecting result vector to be broadcasted in a row reduction." );
949+ " Expecting a broadcasted result for non-lane-local reduction." );
914950
915951 // Create a constant vector to store the result of the reduction per lane.
952+ rewriter.setInsertionPoint (warpOp);
916953 TypedAttr zeroAttr =
917954 rewriter.getZeroAttr (distributedResultType.getElementType ());
918955 Value result = arith::ConstantOp::create (
919956 rewriter, reductionOp->getLoc (), distributedResultType,
920957 DenseElementsAttr::get (distributedResultType, zeroAttr));
921- // Col reduction.
922- if (reductionDim == 0 ) {
923- // Compute source distributed type assuming each lane owns cols.
924- SmallVector<int64_t > shape (sourceType.getShape ());
925- shape[1 ] = shape[1 ] / warpOp.getWarpSize ();
926- auto sourceDistributedType = VectorType::get (shape, elementType);
927958
959+ // Handle lane-local reduction case. In this case we fully distribute the
960+ // reduction.
961+ if (isReductionLaneLocal) {
928962 // Yield the source and acc vectors from the WarpOp.
929963 SmallVector<size_t > newRetIndices;
930964 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
931965 rewriter, warpOp, {reductionOp.getSource (), reductionOp.getAcc ()},
932- {sourceDistributedType , distributedResultType}, newRetIndices);
966+ {sourceDistType , distributedResultType}, newRetIndices);
933967 rewriter.setInsertionPointAfter (newWarpOp);
934968
935- int nCols = sourceDistributedType .getShape ()[1 ];
969+ int nSlices = sourceDistType .getShape ()[sourceDistDim ];
936970 Value source = newWarpOp.getResult (newRetIndices[0 ]);
937971 Value acc = newWarpOp.getResult (newRetIndices[1 ]);
938- // For each column owned by a lane, extract the column (of size nRows x
939- // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
940- // result back to the result vector.
941- for (int i = 0 ; i < nCols; ++i) {
972+ // For each slice owned by a lane, extract the slice, shape cast to 1D, do
973+ // a vector.reduction and, insert the result back to the result vector.
974+ for (int i = 0 ; i < nSlices; ++i) {
975+ SmallVector<int64_t , 2 > sliceOffsets, sliceSizes;
976+ if (sourceDistDim == 0 ) {
977+ sliceOffsets = {i, 0 };
978+ sliceSizes = {1 , sourceDistType.getShape ()[1 ]};
979+ } else {
980+ sliceOffsets = {0 , i};
981+ sliceSizes = {sourceDistType.getShape ()[0 ], 1 };
982+ }
942983 Value col = vector::ExtractStridedSliceOp::create (
943- rewriter, reductionOp.getLoc (), source, {0 , i},
944- {sourceDistributedType.getShape ()[0 ], 1 }, {1 , 1 });
984+ rewriter, reductionOp.getLoc (), source, sliceOffsets, sliceSizes,
985+ {1 , 1 });
986+ int64_t col1DSize =
987+ sourceDistType.getShape ()[sourceDistDim == 1 ? 0 : 1 ];
945988 col = vector::ShapeCastOp::create (
946989 rewriter, reductionOp.getLoc (),
947- VectorType::get ({sourceDistributedType.getShape ()[0 ]}, elementType),
948- col);
990+ VectorType::get ({col1DSize}, elementType), col);
949991 Value accCol =
950992 vector::ExtractOp::create (rewriter, reductionOp.getLoc (), acc, i);
951993 Value colReduce = vector::ReductionOp::create (
@@ -957,26 +999,79 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
957999 rewriter.replaceAllUsesWith (newWarpOp.getResult (operandNumber), result);
9581000 return success ();
9591001 }
960- // For row reductions , we simply rewrite the MultiReductionOp in terms of
961- // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
962- // pattern.
1002+ // For non-lane-local case , we simply rewrite the MultiReductionOp in terms
1003+ // of multiple ReductionOps. Actual distribution is done by the
1004+ // WarpOpReduction pattern.
9631005 rewriter.setInsertionPointAfter (reductionOp);
964- int nRows = sourceType.getShape ()[0 ];
965- // For each row of the source, extract the row vector, do a reduction and,
966- // insert the result back to the result.
967- for (int i = 0 ; i < nRows; ++i) {
968- Value source = vector::ExtractOp::create (rewriter, reductionOp.getLoc (),
969- reductionOp.getSource (), i);
970- Value acc = vector::ExtractOp::create (rewriter, reductionOp.getLoc (),
971- reductionOp.getAcc (), i);
972- Value rowReduce = vector::ReductionOp::create (
973- rewriter, reductionOp.getLoc (), reductionOp.getKind (), source, acc);
1006+ int nSlices = sourceType.getShape ()[sourceDistDim == 0 ? 1 : 0 ];
1007+ // For each slice of the source, extract the slice vector, do a reduction
1008+ // and, insert the result back to the result.
1009+ for (int i = 0 ; i < nSlices; ++i) {
1010+ SmallVector<int64_t , 2 > sliceOffsets, sliceSizes;
1011+ if (sourceDistDim == 1 ) {
1012+ sliceOffsets = {i, 0 };
1013+ sliceSizes = {1 , sourceType.getShape ()[1 ]};
1014+ } else {
1015+ sliceOffsets = {0 , i};
1016+ sliceSizes = {sourceType.getShape ()[0 ], 1 };
1017+ }
1018+ Value col = vector::ExtractStridedSliceOp::create (
1019+ rewriter, reductionOp.getLoc (), reductionOp.getSource (), sliceOffsets,
1020+ sliceSizes, {1 , 1 });
1021+ int64_t col1DSize = sourceType.getShape ()[sourceDistDim];
1022+ col = vector::ShapeCastOp::create (
1023+ rewriter, reductionOp.getLoc (),
1024+ VectorType::get ({col1DSize}, elementType), col);
1025+ Value accCol = vector::ExtractOp::create (rewriter, reductionOp.getLoc (),
1026+ reductionOp.getAcc (), i);
1027+ Value colReduce = vector::ReductionOp::create (
1028+ rewriter, reductionOp.getLoc (), reductionOp.getKind (), col, accCol);
9741029 result = vector::InsertOp::create (rewriter, reductionOp.getLoc (),
975- rowReduce , result, i);
1030+ colReduce , result, i);
9761031 }
9771032 // Replace the warp op result with the final result.
9781033 rewriter.replaceAllUsesWith (reductionOp.getResult (), result);
1034+ return success ();
1035+ }
1036+ };
9791037
// Distribution pattern for a vector::ShapeCastOp whose result is yielded by a
// gpu::WarpExecuteOnLane0Op. The cast's source is yielded from a new warp op
// with its lane-distributed type, and a new shape_cast to the warp op's
// existing distributed result type is created after the new warp op.
// Match failure is reported when the source has no distribution layout or its
// distributed type cannot be computed.
1038+ struct VectorShapeCastDistribution : public gpu ::WarpDistributionPattern {
1039+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
1040+ LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op warpOp,
1041+ PatternRewriter &rewriter) const override {
// Look for a warp-op yield operand accepted by the ShapeCastOp predicate.
1042+ OpOperand *yieldOperand =
1043+ getWarpResult (warpOp, llvm::IsaPred<vector::ShapeCastOp>);
// Presumably null means no yielded value comes from a shape_cast -- see
// getWarpResult for the exact contract.
1044+ if (!yieldOperand)
1045+ return failure ();
1046+ auto shapeCastOp =
1047+ cast<vector::ShapeCastOp>(yieldOperand->get ().getDefiningOp ());
1048+ unsigned operandNumber = yieldOperand->getOperandNumber ();
// The distributed result type is taken from the existing warp op result.
// NOTE(review): no check is visible here that this type is consistent with
// the distributed source type computed below; verify that the ShapeCastOp
// verifier or an earlier pass guarantees it.
1049+ auto resultDistTy =
1050+ cast<VectorType>(warpOp.getResult (operandNumber).getType ());
1051+ xegpu::DistributeLayoutAttr sourceLayout =
1052+ xegpu::getDistributeLayoutAttr (shapeCastOp.getSource ());
1053+ if (!sourceLayout)
1054+ return rewriter.notifyMatchFailure (
1055+ warpOp, " the source of shape_cast op lacks distribution layout" );
// Derive the per-lane type of the cast's source from its layout.
1056+ FailureOr<VectorType> sourceDistTypeOrFailure =
1057+ getDistVecTypeBasedOnLaneLayout (sourceLayout,
1058+ shapeCastOp.getSourceVectorType ());
1059+ if (failed (sourceDistTypeOrFailure))
1060+ return rewriter.notifyMatchFailure (
1061+ warpOp, " failed to get distributed vector type for source" );
1062+ VectorType sourceDistType = sourceDistTypeOrFailure.value ();
1063+ // Create a new warp op that yields the source of the shape_cast op.
1064+ SmallVector<size_t > newRetIndices;
1065+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1066+ rewriter, warpOp, {shapeCastOp.getSource ()}, {sourceDistType},
1067+ newRetIndices);
1068+ rewriter.setInsertionPointAfter (newWarpOp);
// newRetIndices[0] is the index of the appended result carrying the source.
1069+ Value source = newWarpOp.getResult (newRetIndices[0 ]);
1070+ // Create a new shape_cast op outside the warp op.
1071+ Value newShapeCast = vector::ShapeCastOp::create (
1072+ rewriter, shapeCastOp.getLoc (), resultDistTy, source);
// Redirect users of the original yielded result to the new shape_cast.
1073+ rewriter.replaceAllUsesWith (newWarpOp.getResult (operandNumber),
1074+ newShapeCast);
9801075 return success ();
9811076 }
9821077};
@@ -998,6 +1093,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
9981093 DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
9991094 GpuBarrierDistribution, VectorMultiReductionDistribution>(
10001095 patterns.getContext ());
1096+ patterns.add <VectorShapeCastDistribution>(patterns.getContext (),
1097+ /* benefit=*/ 2 );
10011098}
10021099
10031100void XeGPUSubgroupDistributePass::runOnOperation () {
@@ -1012,8 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
10121109 if (!isa<VectorType>(operand.get ().getType ()))
10131110 continue ;
10141111
1015- auto layout =
1016- xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
1112+ auto layout = xegpu::getDistributeLayoutAttr (operand.get ());
10171113 if (!layout) {
10181114 op->emitError (" Could not find layout attribute for operand " )
10191115 << operand.getOperandNumber () << " of operation " << op->getName ();
@@ -1074,6 +1170,25 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
10741170 // TODO: shuffleFn is not used.
10751171 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
10761172 int64_t warpSz) { return Value (); };
1173+
// Reduction callback passed to vector::populateDistributeReduction below.
// It reduces `input` to a single value with a vector::ReductionOp, then
// combines that value with XOR-shuffled copies of itself and returns the
// result.
// NOTE(review): this lambda uses `builder.create<Op>(...)`, while other code
// in this file uses the `Op::create(rewriter, ...)` form; consider aligning.
1174+ auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
1175+ vector::CombiningKind kind, uint32_t size) {
1176+ // First reduce on a single thread to get per lane reduction value.
1177+ Value laneVal = builder.create <vector::ReductionOp>(loc, kind, input);
1178+ // Parallel reduction using butterfly shuffles.
// The shuffle offset doubles each iteration (1, 2, 4, ...) while it is below
// `size`.
// NOTE(review): presumably `size` is expected to be a power of two so that
// the XOR pattern reaches every lane -- confirm against callers.
1179+ for (uint64_t i = 1 ; i < size; i <<= 1 ) {
1180+ Value shuffled =
1181+ builder
1182+ .create <gpu::ShuffleOp>(loc, laneVal, i,
1183+ /* width=*/ size,
1184+ /* mode=*/ gpu::ShuffleMode::XOR)
1185+ .getShuffleResult ();
// Fold the shuffled value into the running value using the same combining
// kind as the initial reduction.
1186+ laneVal = makeArithReduction (builder, loc, kind, laneVal, shuffled);
1187+ }
1188+ return laneVal;
1189+ };
1190+
1191+ vector::populateDistributeReduction (patterns, warpReduction);
10771192 vector::populatePropagateWarpVectorDistributionPatterns (
10781193 patterns, distributionFn, shuffleFn);
10791194 if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
0 commit comments