 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
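+/// Return the index of the single dimension along which `origType` differs
+/// from `distributedType` (i.e. the distributed dimension), or -1 if the two
+/// shapes are identical.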
+static int getDistributedDim(VectorType origType, VectorType distributedType) {
+  assert(origType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < origType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {
 
 /// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,123 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   %insert = vector.insert_strided_slice %src, %dest
+///     {offsets = [0, 0], strides = [1, 1]}
+///     : vector<4x16xf32> into vector<8x16xf32>
+///   gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0:2 = gpu.warp_execute_on_lane_0(%arg0)
+///     -> (vector<4x1xf32>, vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1
+///     {offsets = [0, 0], strides = [1, 1]}
+///     : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both src and dest vectors are
+/// distributed to lanes and that sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+    (void)destDistributedDim;
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k dims");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
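+    // Compute the per-lane shapes: shrink the distributed dim of both source
+    // and dest to its distributed size; all other dims stay sequential.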
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape()),
+        newDestDistShape(insertOp.getDestVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    newDestDistShape[destDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    auto newDestTy =
+        VectorType::get(newDestDistShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts distributed source
+    // into distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   %extract = vector.extract_strided_slice %src
+///     {offsets = [0], sizes = [16], strides = [1]}
+///     : vector<32x16xf32> to vector<16x16xf32>
+///   gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0
+///     {offsets = [0], sizes = [16], strides = [1]}
+///     : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (it does not require cross-lane communication).
 struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -1087,6 +1222,63 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp =
         operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+    (void)distributedDim;
+
+    // Distributed dimension must be fully extracted.
+    // TODO: Partial extraction from the distributed dimension requires
+    // cross-lane communication.
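+    // Offsets and sizes may be specified for only the leading dims of the
+    // source vector; if the distributed dim lies beyond them, it is
+    // implicitly fully extracted.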
+    if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
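+    // The new distributed source shape keeps all sequential dims of the
+    // source and shrinks only the distributed dim to its per-lane size.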
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
     return success();
   }
 };
@@ -1137,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;
 
@@ -1776,12 +1960,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
-           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
-          patterns.getContext(), benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,