 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>

@@ -52,6 +55,25 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }

+/// Given a sequential and distributed vector type, returns the distributed
+/// dimension. This function expects that only a single dimension is
+/// distributed.
+static int getDistributedDim(VectorType sequentialType,
+                             VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {

 /// Helper struct to create the load / store operations that permit transit
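
For intuition, here is a minimal standalone sketch of the helper's contract (an editorial illustration, not part of the commit; the `f32` element type, the shapes, and the 32-lane warp are assumptions):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

// A sequential vector<64x32xf32> distributed across a 32-lane warp as
// vector<64x1xf32> differs from the sequential type only in dim 1, so
// getDistributedDim (assumed visible here; it is file-static) returns 1.
void distributedDimExample(mlir::MLIRContext &ctx) {
  auto f32 = mlir::Float32Type::get(&ctx);
  auto sequentialType = mlir::VectorType::get({64, 32}, f32);
  auto distributedType = mlir::VectorType::get({64, 1}, f32);
  assert(getDistributedDim(sequentialType, distributedType) == 1);
}
```
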
@@ -1076,6 +1098,196 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };

+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
+///   gpu.yield %insert : vector<8x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both the src and dest vectors are
+/// distributed to lanes and that sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "distributed dimension must be in the last k dims of dest vector");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    VectorType newDestTy = distributedType;
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts distributed source
+    // into distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
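
To make the guard above concrete, here is an editorial walk-through of the index arithmetic using the shapes from the pattern's doc comment (a self-contained sketch, not part of the commit):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Shapes from the doc comment: src vector<4x32xf32> inserted into dest
  // vector<8x32xf32>; the warp result is distributed as vector<8x1xf32>,
  // i.e. dest dim 1 (size 32) is spread across 32 lanes.
  int64_t srcRank = 2, destRank = 2;
  int64_t destDistributedDim = 1; // the dest dim where 32 -> 1
  // The distributed dim must lie within the last srcRank dims of the dest.
  int64_t sourceDistributedDim = destDistributedDim - (destRank - srcRank);
  assert(sourceDistributedDim >= 0); // == 1, so the pattern applies
  // Dim 1 is fully inserted (size 32 in both src and dest), so the source
  // distributes to vector<4x1xf32> with no cross-lane communication.
  return 0;
}
```
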
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+///     strides = [1] : vector<64x32xf32> to vector<16x32xf32>
+///   gpu.yield %extract : vector<16x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   gpu.yield %src : vector<64x32xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (i.e. it does not require cross-lane
+/// communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+
+    int64_t numOfExtractedDims =
+        static_cast<int64_t>(extractOp.getSizes().size());
+    // If the distributed dim is included in the extracted dims, we make sure
+    // the distributed dim is fully extracted. If the distributed dim is not
+    // included in the extracted dims, it is guaranteed to be fully extracted
+    // (i.e. the distributed dim comes after all the extracted dims).
+    // TODO: Partial extraction from the distributed dimension requires
+    // cross-lane communication.
+    if (distributedDim < numOfExtractedDims) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
+    return success();
+  }
+};
+
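
Likewise, an editorial walk-through of the extract pattern's guard with the doc comment's shapes (a sketch, not part of the commit):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Shapes from the doc comment: extract offsets = [0], sizes = [16] from
  // vector<64x32xf32>; the warp result is distributed as vector<16x1xf32>,
  // i.e. dim 1 (size 32) is spread across 32 lanes.
  int64_t distributedDim = 1;     // yielded 16x32 vs. distributed 16x1
  int64_t numOfExtractedDims = 1; // sizes = [16] only touches dim 0
  // The extract never touches dim 1, so the distributed dim is trivially
  // fully extracted and no offset/size check is needed.
  assert(distributedDim >= numOfExtractedDims);
  // The source distributes to vector<64x1xf32>; the sunk extract then
  // produces the expected vector<16x1xf32> per lane.
  return 0;
}
```
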
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
@@ -1122,15 +1334,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;

@@ -1764,7 +1968,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
                WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
       patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
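
Since the new patterns are registered inside `populatePropagateWarpVectorDistributionPatterns`, existing callers pick them up automatically. A minimal driver sketch (editorial, not part of the commit; the `op`, `distributionMapFn`, and `warpShuffleFromIdxFn` values and the surrounding pass are assumptions, and the populate function's exact parameter list may differ by revision):

```cpp
// Editorial sketch: run warp-distribution propagation, including the new
// insert/extract strided-slice patterns, over an op.
RewritePatternSet patterns(op->getContext());
vector::populatePropagateWarpVectorDistributionPatterns(
    patterns, distributionMapFn, warpShuffleFromIdxFn);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
  signalPassFailure();
```
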