Skip to content

Commit c999c20

Browse files
committed
add comments and test
1 parent 311d6d7 commit c999c20

File tree

2 files changed

+279
-15
lines changed

2 files changed

+279
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 199 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1616
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
1717
#include "mlir/IR/AffineExpr.h"
18+
#include "mlir/IR/Attributes.h"
19+
#include "mlir/IR/BuiltinTypes.h"
1820
#include "mlir/Interfaces/SideEffectInterfaces.h"
1921
#include "mlir/Transforms/RegionUtils.h"
2022
#include "llvm/ADT/SetVector.h"
23+
#include "llvm/ADT/SmallVectorExtras.h"
2124
#include "llvm/Support/FormatVariadic.h"
2225
#include <utility>
2326

@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
5255
return map;
5356
}
5457

58+
/// Given a sequential vector type and its distributed counterpart (same rank),
/// return the single dimension along which their sizes differ, i.e. the
/// dimension that is distributed across lanes. Returns -1 when the shapes are
/// identical (nothing is distributed).
///
/// Returns int64_t (not int) so the result is not silently narrowed: dim
/// indices and the callers' local variables are int64_t.
static int64_t getDistributedDim(VectorType origType,
                                 VectorType distributedType) {
  assert(origType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < origType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
      // support distributing multiple dimensions in the future.
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}
72+
5573
namespace {
5674

5775
/// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,123 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
10761094
}
10771095
};
10781096

1097+
/// Sink out insert_strided_slice op feeding into a warp op yield.
1098+
/// ```
1099+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1100+
/// ...
1101+
/// %src = ... : vector<4x16xf32>
1102+
/// %dest = ... : vector<8x16xf32>
1103+
/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1104+
/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
1105+
/// gpu.yield %insert : vector<8x16xf32>
1106+
/// }
1107+
/// ```
1108+
/// To
1109+
/// ```
1110+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1111+
/// vector<8x1xf32>) {
1112+
/// ...
1113+
/// %src = ... : vector<4x16xf32>
1114+
/// %dest = ... : vector<8x16xf32>
1115+
/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
1116+
/// }
1117+
/// %insert = vector.insert_strided_slice %0#0, %0#1,
1118+
/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1119+
/// ```
1120+
/// NOTE: Current support assumes that both src and dest vectors are distributed
1121+
/// to lanes and sinking the insert op does not require any cross lane
1122+
/// communication.
1123+
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Look for a warp-op result whose yielded value is produced by an
    // insert_strided_slice inside the region.
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    // Position of the matched value in the warp op's yield/result list.
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");
    // destDistributedDim is only consumed by the assert above in builds where
    // asserts are enabled; cast to void to silence unused-variable warnings.
    (void)destDistributedDim;
    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // Currently we require that both source (kD) and dest (nD) vectors are
    // distributed. This requires that distributedDim (d) is contained in the
    // last k dims of the dest vector (d >= n - k).
    // TODO: Add support for case where source vector is not distributed.
    // Map the dest's distributed dim onto the (possibly lower-rank) source:
    // the source's dims align with the trailing dims of the dest.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be in the last k dims");
    // Distributed dimension must be fully inserted.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    // Compute the per-lane (distributed) shapes for source and dest: shrink
    // only the distributed dimension to its distributed size.
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape()),
        newDestDistShape(insertOp.getDestVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    newDestDistShape[destDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    auto newDestTy =
        VectorType::get(newDestDistShape, distributedType.getElementType());
    // Rebuild the warp op so that it additionally yields the insert's source
    // and dest values with the distributed types computed above.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
    auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts distributed source into
    // distributed dest.
    // NOTE(review): offsets/strides are reused unchanged. Along the distributed
    // dim this is valid because full insertion (checked above) forces a zero
    // offset on that dim for the original op to verify — confirm if the
    // full-insertion check is ever relaxed.
    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
        insertOp.getLoc(), distributedDest.getType(), distributedSource,
        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
    return success();
  }
};
1191+
1192+
/// Sink out extract_strided_slice op feeding into a warp op yield.
1193+
/// ```
1194+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1195+
/// ...
1196+
/// %src = ... : vector<32x16xf32>
1197+
/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1198+
/// strides = [1] : vector<32x16xf32> to vector<16x16xf32>
1199+
/// gpu.yield %extract : vector<16x16xf32>
1200+
/// }
1201+
/// ```
1202+
/// To
1203+
/// ```
1204+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
1205+
/// ...
1206+
/// %src = ... : vector<32x16xf32>
1207+
/// gpu.yield %src : vector<32x16xf32>
1208+
/// }
1209+
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1210+
/// strides = [1] : vector<32x1xf32> to vector<16x1xf32>
1211+
/// ```
1212+
/// NOTE: Current support assumes that the extraction happens only on
/// non-distributed dimensions (does not require cross lane communication).
10791214
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
10801215
using Base::Base;
10811216
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -1087,6 +1222,63 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
10871222
unsigned int operandNumber = operand->getOperandNumber();
10881223
auto extractOp =
10891224
operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1225+
auto distributedType =
1226+
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1227+
// Distributed type must be 2D or higher.
1228+
// TODO: Support 1D distributed types.
1229+
if (distributedType.getRank() < 2)
1230+
return rewriter.notifyMatchFailure(
1231+
extractOp, "result vector type must be 2D or higher");
1232+
1233+
// Find the distributed dimension. There should be exactly one.
1234+
auto yieldedType = cast<VectorType>(operand->get().getType());
1235+
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1236+
assert(distributedDim != -1 && "could not find distributed dimension");
1237+
(void)distributedDim;
1238+
1239+
// Distributed dimension must be fully extracted.
1240+
// TODO: Partial extraction from distributed dimension require cross lane
1241+
// communication.
1242+
if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
1243+
int64_t distributedDimOffset =
1244+
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1245+
.getInt();
1246+
int64_t distributedDimSize =
1247+
llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1248+
.getInt();
1249+
if (distributedDimOffset != 0 ||
1250+
distributedDimSize != yieldedType.getDimSize(distributedDim))
1251+
return rewriter.notifyMatchFailure(
1252+
extractOp, "distributed dimension must be fully extracted");
1253+
}
1254+
SmallVector<int64_t> newDistributedShape(
1255+
extractOp.getSourceVectorType().getShape());
1256+
newDistributedShape[distributedDim] =
1257+
distributedType.getDimSize(distributedDim);
1258+
auto newDistributedType =
1259+
VectorType::get(newDistributedShape, distributedType.getElementType());
1260+
SmallVector<size_t> newRetIndices;
1261+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1262+
rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1263+
newRetIndices);
1264+
rewriter.setInsertionPointAfter(newWarpOp);
1265+
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1266+
extractOp.getSizes(), [](Attribute attr) { return attr; });
1267+
// Update the distributed sizes to match the distributed type.
1268+
if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1269+
distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1270+
distributedType.getDimSize(distributedDim));
1271+
1272+
// Create a new extract strided slice op that extracts from the
1273+
// distributed vector.
1274+
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1275+
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
1276+
extractOp.getLoc(), distributedType, distributedVec,
1277+
extractOp.getOffsets(),
1278+
ArrayAttr::get(rewriter.getContext(), distributedSizes),
1279+
extractOp.getStrides());
1280+
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1281+
newExtract);
10901282
return success();
10911283
}
10921284
};
@@ -1137,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
11371329
auto distributedType =
11381330
cast<VectorType>(warpOp.getResult(operandNumber).getType());
11391331
auto yieldedType = cast<VectorType>(operand->get().getType());
1140-
int64_t distributedDim = -1;
1141-
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1142-
if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1143-
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1144-
// support distributing multiple dimensions in the future.
1145-
assert(distributedDim == -1 && "found multiple distributed dims");
1146-
distributedDim = i;
1147-
}
1148-
}
1332+
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
11491333
assert(distributedDim != -1 && "could not find distributed dimension");
11501334
(void)distributedDim;
11511335

@@ -1776,12 +1960,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
17761960
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
17771961
PatternBenefit readBenefit) {
17781962
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1779-
patterns
1780-
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1781-
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1782-
WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
1783-
WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
1784-
patterns.getContext(), benefit);
1963+
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1964+
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1965+
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1966+
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
1967+
WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
1968+
patterns.getContext(), benefit);
17851969
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
17861970
benefit);
17871971
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
12961296
return %r : vector<4x96xf32>
12971297
}
12981298

1299+
// -----
1300+
// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
1301+
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
1302+
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
1303+
// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
1304+
// CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
1305+
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
1306+
// CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
1307+
// CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
1308+
// Extraction along the outer (non-distributed) dimension: the inner dim (32)
// is distributed across the 32 lanes (32 -> 1 per lane), so the
// extract_strided_slice can sink out of the warp op unchanged.
func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<24x1xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
    %0 = "some_def"() : () -> (vector<64x32xf32>)
    // Slice rows [8, 32) of the 64x32 vector; columns untouched.
    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
      : vector<64x32xf32> to vector<24x32xf32>
    gpu.yield %1 : vector<24x32xf32>
  }
  return %r : vector<24x1xf32>
}
1317+
1318+
// -----
1319+
// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
1320+
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
1321+
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
1322+
// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
1323+
// CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
1324+
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
1325+
// CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
1326+
// CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
1327+
// Extraction along the inner dimension while the outer dim (32) is distributed
// across the 32 lanes (32 -> 1 per lane). The distributed dim is fully
// extracted (sizes [32, 8] covers all 32 rows), so after sinking the sizes
// become [1, 8] to match the per-lane shape.
func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<1x8xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
    %0 = "some_def"() : () -> (vector<32x64xf32>)
    %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
      : vector<32x64xf32> to vector<32x8xf32>
    gpu.yield %1 : vector<32x8xf32>
  }
  return %r : vector<1x8xf32>
}
1336+
1337+
// -----
1338+
// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
1339+
// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
1340+
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
1341+
// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
1342+
// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
1343+
// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
1344+
// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
1345+
// CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
1346+
// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
1347+
// Insertion of a 1-D source into a 2-D dest. The trailing dim (32) is
// distributed across the 32 lanes, so after sinking both source (32 -> 1)
// and dest (64x32 -> 64x1) shrink along it.
func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
    %0 = "some_def"() : () -> (vector<32xf32>)
    %1 = "some_def"() : () -> (vector<64x32xf32>)
    // Insert the 1-D source as row 18 of the dest.
    %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
      : vector<32xf32> into vector<64x32xf32>
    gpu.yield %2 : vector<64x32xf32>
  }
  return %r : vector<64x1xf32>
}
1357+
1358+
// -----
1359+
// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
1360+
// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
1361+
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
1362+
// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
1363+
// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
1364+
// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
1365+
// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
1366+
// CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
1367+
// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
1368+
// Insertion of a 2-D source into a 2-D dest of equal rank. The trailing dim
// (32) is distributed across the 32 lanes, so after sinking the source
// becomes 16x1 and the dest 64x1.
func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
    %0 = "some_def"() : () -> (vector<16x32xf32>)
    %1 = "some_def"() : () -> (vector<64x32xf32>)
    // Insert the 16-row source starting at row 36 of the dest.
    %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0], strides = [1, 1]}
      : vector<16x32xf32> into vector<64x32xf32>
    gpu.yield %2 : vector<64x32xf32>
  }
  return %r : vector<64x1xf32>
}
1378+
12991379
// -----
13001380

13011381
// Make sure that all operands of the transfer_read op are properly propagated.

0 commit comments

Comments
 (0)