Skip to content

Commit 8d6a016

Browse files
nbpatelaadeshps-mcw
authored andcommitted
[MLIR][Vector] Add unroll pattern for vector.shape_cast (llvm#167738)
This PR adds pattern for unrolling shape_cast given a targetShape. This PR is a follow up of llvm#164010 which was very general and was using inserts and extracts on each element (which is also LowerVectorShapeCast.cpp is doing). After doing some more research on use cases, we (me and @Jianhui-Li ) realized that the previous version in llvm#164010 is unnecessarily generic and doesn't fit our performance needs. Our use case requires that targetShape is contiguous in both source and result vector. This pattern only applies when contiguous slices can be extracted from the source vector and inserted into the result vector such that each slice remains in vector form with targetShape (and not decompose to scalars). In these cases, the unrolling proceeds as: vector.extract_strided_slice -> vector.shape_cast (on the slice unrolled) -> vector.insert_strided_slice
1 parent b75a5fa commit 8d6a016

File tree

5 files changed

+297
-2
lines changed

5 files changed

+297
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2424,6 +2424,7 @@ def Vector_CompressStoreOp :
24242424

24252425
def Vector_ShapeCastOp :
24262426
Vector_Op<"shape_cast", [Pure,
2427+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
24272428
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
24282429
]>,
24292430
Arguments<(ins AnyVectorOfAnyRank:$source)>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6243,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
62436243
setResultRanges(getResult(), argRanges.front());
62446244
}
62456245

6246+
/// The unroll reference shape for a shape_cast is the shape of its result
/// vector.
std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
}
6249+
62466250
LogicalResult ShapeCastOp::verify() {
62476251

62486252
VectorType sourceType = getSourceVectorType();

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10031003
vector::UnrollVectorOptions options;
10041004
};
10051005

1006+
/// Checks whether extractShape is a contiguous slice of shape.
1007+
/// For extractShape to be contiguous in shape:
1008+
/// 1) All but the leading dimension of extractShape and shape must match
1009+
/// exactly. 2) The total number of elements in shape must be evenly divisible
1010+
/// by
1011+
/// the total number of elements in extractShape.
1012+
/// Examples:
1013+
/// isContiguous([4, 4], [8, 4]) == true
1014+
/// isContiguous([2, 4], [8, 4]) == true
1015+
/// isContiguous([2, 2], [8, 4]) == false
1016+
/// Removes leading unit dimensions to handle cases like:
1017+
/// isContiguous([1, 16], [1, 32]) == true
1018+
static bool isContiguous(ArrayRef<int64_t> extractShape,
1019+
ArrayRef<int64_t> shape) {
1020+
1021+
if (extractShape.size() > shape.size())
1022+
return false;
1023+
1024+
while (!extractShape.empty() && extractShape.front() == 1) {
1025+
extractShape = extractShape.drop_front();
1026+
}
1027+
1028+
while (!shape.empty() && shape.front() == 1) {
1029+
shape = shape.drop_front();
1030+
}
1031+
1032+
size_t rankDiff = shape.size() - extractShape.size();
1033+
if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
1034+
return false;
1035+
1036+
int64_t extractElements = ShapedType::getNumElements(extractShape);
1037+
int64_t shapeElements = ShapedType::getNumElements(shape);
1038+
return shapeElements % extractElements == 0;
1039+
}
1040+
1041+
/// Determines what shape to use with `vector.extract_strided_slice` to extract
1042+
/// a contiguous memory region from a source vector. The extraction must be
1043+
/// contiguous and contain exactly the specified number of elements. If such an
1044+
/// extraction shape cannot be determined, returns std::nullopt.
1045+
/// EXAMPLE 1:
1046+
/// sourceShape = [16], targetElements = 8
1047+
/// Working right-to-left:
1048+
/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
1049+
/// remaining = 8/8 = 1
1050+
/// Result: [8]
1051+
///
1052+
/// EXAMPLE 2:
1053+
/// sourceShape = [4, 4], targetElements = 8
1054+
/// Working right-to-left:
1055+
/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
1056+
/// remaining = 8/4 = 2
1057+
/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1058+
/// remaining = 2/2 = 1
1059+
/// Result: [2, 4]
1060+
static std::optional<SmallVector<int64_t>>
1061+
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1062+
int64_t targetElements) {
1063+
SmallVector<int64_t> extractShape;
1064+
int64_t remainingElements = targetElements;
1065+
1066+
// Build extract shape from innermost dimension outward to ensure contiguity.
1067+
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1068+
int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1069+
extractShape.insert(extractShape.begin(), takeFromDim);
1070+
1071+
if (remainingElements % takeFromDim != 0)
1072+
return std::nullopt; // Not evenly divisible.
1073+
remainingElements /= takeFromDim;
1074+
}
1075+
1076+
// Fill remaining dimensions with 1.
1077+
while (extractShape.size() < sourceShape.size())
1078+
extractShape.insert(extractShape.begin(), 1);
1079+
1080+
if (ShapedType::getNumElements(extractShape) != targetElements)
1081+
return std::nullopt;
1082+
1083+
return extractShape;
1084+
}
1085+
1086+
// Convert result offsets to source offsets via linear position.
1087+
static SmallVector<int64_t>
1088+
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1089+
ArrayRef<int64_t> sourceShape,
1090+
ArrayRef<int64_t> resultShape) {
1091+
// Convert result offsets to linear position.
1092+
int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
1093+
// Convert linear position to source offsets.
1094+
return delinearize(linearIndex, computeStrides(sourceShape));
1095+
}
1096+
1097+
/// This pattern unrolls `vector.shape_cast` operations according to the
1098+
/// provided target unroll shape. It unrolls a large shape cast into smaller
1099+
/// shape casts by extracting contiguous slices from the source vector, casting
1100+
/// each slice to the target shape, and assembling the result by inserting each
1101+
/// computed segment into the appropriate offset of the result vector.
1102+
///
1103+
/// This pattern only applies when contiguous slices can be extracted from the
1104+
/// source vector and inserted into the result vector such that each slice
1105+
/// remains a valid vector (and not decompose to scalars). In these cases, the
1106+
/// unrolling proceeds as:
1107+
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1108+
/// vector.insert_strided_slice.
1109+
///
1110+
/// Example:
1111+
/// Given a shape cast operation:
1112+
/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1113+
///
1114+
/// and a target unroll shape of <2x4>, the pattern produces:
1115+
///
1116+
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1117+
/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1118+
/// : vector<8x2xf32> to vector<4x2xf32>
1119+
/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1120+
/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1121+
/// : vector<2x4xf32> into vector<4x4xf32>
1122+
/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1123+
/// : vector<8x2xf32> to vector<4x2xf32>
1124+
/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1125+
/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1126+
/// : vector<2x4xf32> into vector<4x4xf32>
1127+
///
1128+
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1129+
UnrollShapeCastPattern(MLIRContext *context,
1130+
const vector::UnrollVectorOptions &options,
1131+
PatternBenefit benefit = 1)
1132+
: OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1133+
options(options) {}
1134+
1135+
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1136+
PatternRewriter &rewriter) const override {
1137+
std::optional<SmallVector<int64_t>> targetShape =
1138+
getTargetShape(options, shapeCastOp);
1139+
if (!targetShape)
1140+
return failure();
1141+
1142+
VectorType sourceType = shapeCastOp.getSourceVectorType();
1143+
VectorType resultType = shapeCastOp.getResultVectorType();
1144+
ArrayRef<int64_t> sourceShape = sourceType.getShape();
1145+
ArrayRef<int64_t> resultShape = resultType.getShape();
1146+
1147+
if (!isContiguous(*targetShape, resultShape))
1148+
return rewriter.notifyMatchFailure(
1149+
shapeCastOp, "Only supports cases where target shape is "
1150+
"contiguous in result vector shape");
1151+
1152+
int64_t targetElements = ShapedType::getNumElements(*targetShape);
1153+
1154+
// Calculate the shape to extract from source.
1155+
std::optional<SmallVector<int64_t>> extractShape =
1156+
calculateSourceExtractShape(sourceShape, targetElements);
1157+
if (!extractShape)
1158+
return rewriter.notifyMatchFailure(
1159+
shapeCastOp,
1160+
"cannot extract target number of elements contiguously from source");
1161+
1162+
Location loc = shapeCastOp.getLoc();
1163+
1164+
// Create result vector initialized to zero.
1165+
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1166+
rewriter.getZeroAttr(resultType));
1167+
1168+
VectorType targetType =
1169+
VectorType::get(*targetShape, sourceType.getElementType());
1170+
1171+
SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1172+
SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1173+
1174+
for (SmallVector<int64_t> resultOffsets :
1175+
StaticTileOffsetRange(resultShape, *targetShape)) {
1176+
SmallVector<int64_t> sourceOffsets =
1177+
calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1178+
Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1179+
loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1180+
extractStrides);
1181+
Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1182+
loc, targetType, sourceChunk);
1183+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1184+
loc, targetChunk, result, resultOffsets, insertStrides);
1185+
}
1186+
1187+
rewriter.replaceOp(shapeCastOp, result);
1188+
return success();
1189+
}
1190+
1191+
private:
1192+
vector::UnrollVectorOptions options;
1193+
};
1194+
10061195
} // namespace
10071196

10081197
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131202
UnrollReductionPattern, UnrollMultiReductionPattern,
10141203
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151204
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016-
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
1017-
options, benefit);
1205+
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
1206+
patterns.getContext(), options, benefit);
10181207
}
10191208

10201209
void mlir::vector::populateVectorToElementsUnrollPatterns(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,82 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
496496
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
497497
// CHECK-NOT: arith.addf
498498
// CHECK: return
499+
500+
501+
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
502+
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
503+
return %0 : vector<2x2x4xf32>
504+
}
505+
506+
// CHECK-LABEL: func @shape_cast_1D
507+
// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
508+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
509+
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
510+
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
511+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
512+
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
513+
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
514+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
515+
// CHECK: return %[[I1]] : vector<2x2x4xf32>
516+
517+
518+
func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
519+
%0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
520+
return %0 : vector<4x4xf32>
521+
}
522+
523+
// CHECK-LABEL: func @shape_cast_2D
524+
// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
525+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
526+
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
527+
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
528+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
529+
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
530+
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
531+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
532+
// CHECK: return %[[I1]] : vector<4x4xf32>
533+
534+
535+
// This is a negative test case to ensure that such shape casts are not unrolled
536+
// because the targetShape (2x4) is not contiguous in result vector
537+
func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
538+
%0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
539+
return %0 : vector<8x8xf32>
540+
}
541+
542+
// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
543+
// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
544+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
545+
// CHECK: return %[[SC]] : vector<8x8xf32>
546+
547+
548+
// This is a negative test case to ensure that such shape casts are not unrolled
549+
// because it cannot determine the extractShape from source vector (8x3)
550+
// to extract contiguous targetShape (2x4)
551+
func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
552+
%0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
553+
return %0 : vector<6x4xf32>
554+
}
555+
556+
// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
557+
// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
558+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
559+
// CHECK: return %[[SC]] : vector<6x4xf32>
560+
561+
562+
// TargetShape is [1x16]
563+
func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
564+
%0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32>
565+
return %0 : vector<1x32xf32>
566+
}
567+
568+
// CHECK-LABEL: func @shape_cast_leading_unit_dim
569+
// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> {
570+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
571+
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
572+
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32>
573+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
574+
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
575+
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
576+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
577+
// CHECK: return %[[I1]] : vector<1x32xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,28 @@ struct TestVectorUnrollingPatterns
178178
.setFilterConstraint([](Operation *op) {
179179
return success(isa<vector::StepOp>(op));
180180
}));
181+
populateVectorUnrollPatterns(
182+
patterns,
183+
UnrollVectorOptions()
184+
.setNativeShapeFn(
185+
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
186+
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
187+
if (!shapeCast)
188+
return std::nullopt;
189+
190+
auto resultShape = shapeCast.getResultVectorType().getShape();
191+
// Special case with leading unit dims and different inner dim
192+
// for result and target shape.
193+
if (resultShape.size() == 2 && resultShape[0] == 1 &&
194+
resultShape[1] == 32) {
195+
return SmallVector<int64_t>{1, 16};
196+
}
197+
// Default case: [2,4] for all tests.
198+
return SmallVector<int64_t>{2, 4};
199+
})
200+
.setFilterConstraint([](Operation *op) {
201+
return success(isa<vector::ShapeCastOp>(op));
202+
}));
181203
populateVectorUnrollPatterns(
182204
patterns, UnrollVectorOptions()
183205
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

0 commit comments

Comments
 (0)