Skip to content

Commit cd8b818

Browse files
committed
Add unroll pattern for vector.shape_cast
1 parent 0364baf commit cd8b818

File tree

5 files changed

+213
-2
lines changed

5 files changed

+213
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :
24272427

24282428
def Vector_ShapeCastOp :
24292429
Vector_Op<"shape_cast", [Pure,
2430+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
24302431
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
24312432
]>,
24322433
Arguments<(ins AnyVectorOfAnyRank:$source)>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
62416241
setResultRanges(getResult(), argRanges.front());
62426242
}
62436243

6244+
/// VectorUnrollOpInterface hook: reports the shape that unrolling should tile.
/// For a shape cast this is the shape of the result vector, returned as an
/// owning copy.
std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
}
6247+
62446248
LogicalResult ShapeCastOp::verify() {
62456249

62466250
VectorType sourceType = getSourceVectorType();

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10031003
vector::UnrollVectorOptions options;
10041004
};
10051005

1006+
/// Returns true if tiles of shape `targetShape` can be carved out of a vector
/// of shape `resultShape` as contiguous runs of elements.
///
/// Two conditions must hold:
///   1. With both shapes aligned on their trailing dimensions, every target
///      dimension except the leading one equals the matching result
///      dimension. The leading target dimension is allowed to be a partial
///      slice and is therefore not compared.
///   2. The element count of `resultShape` is a multiple of the element count
///      of `targetShape`.
///
/// A target of higher rank than the result is never accepted.
static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
                                ArrayRef<int64_t> resultShape) {
  const size_t targetRank = targetShape.size();
  const size_t resultRank = resultShape.size();
  if (targetRank > resultRank)
    return false;

  // Number of leading result dimensions that have no counterpart in the
  // target shape.
  const size_t leadingDims = resultRank - targetRank;

  // Walk the trailing dimensions from innermost to outermost, stopping before
  // the leading target dimension (index 0).
  for (size_t dim = targetRank; dim > 1; --dim) {
    if (targetShape[dim - 1] != resultShape[leadingDims + dim - 1])
      return false;
  }

  const int64_t tileElements = ShapedType::getNumElements(targetShape);
  const int64_t totalElements = ShapedType::getNumElements(resultShape);
  return totalElements % tileElements == 0;
}
1028+
1029+
// Calculate the shape to extract from source
1030+
static std::optional<SmallVector<int64_t>>
1031+
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1032+
int64_t targetElements) {
1033+
SmallVector<int64_t> extractShape;
1034+
int64_t remainingElements = targetElements;
1035+
1036+
// Build extract shape from innermost dimension outward to ensure contiguity
1037+
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1038+
int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1039+
extractShape.insert(extractShape.begin(), takeFromDim);
1040+
1041+
if (remainingElements % takeFromDim != 0) {
1042+
return std::nullopt; // Not evenly divisible
1043+
}
1044+
remainingElements /= takeFromDim;
1045+
}
1046+
1047+
// Fill remaining dimensions with 1
1048+
while (extractShape.size() < sourceShape.size()) {
1049+
extractShape.insert(extractShape.begin(), 1);
1050+
}
1051+
1052+
if (ShapedType::getNumElements(extractShape) != targetElements) {
1053+
return std::nullopt;
1054+
}
1055+
1056+
return extractShape;
1057+
}
1058+
1059+
// Convert result offsets to source offsets via linear position
1060+
static SmallVector<int64_t>
1061+
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1062+
ArrayRef<int64_t> sourceStrides,
1063+
ArrayRef<int64_t> resultStrides) {
1064+
// Convert result offsets to linear position
1065+
int64_t linearIndex = linearize(resultOffsets, resultStrides);
1066+
// Convert linear position to source offsets
1067+
SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
1068+
return sourceOffsets;
1069+
}
1070+
1071+
/// This pattern unrolls `vector.shape_cast` operations according to the
1072+
/// provided target unroll shape. It unrolls a large shape cast into smaller
1073+
/// shape casts by extracting contiguous slices from the source vector, casting
1074+
/// each slice to the target shape, and assembling the result by inserting each
1075+
/// computed segment into the appropriate offset of the result vector.
1076+
///
1077+
/// This pattern only applies when contiguous slices can be extracted from the
1078+
/// source vector and inserted into the result vector such that each slice
1079+
/// remains a valid vector (and not decompose to scalars). In these cases, the
1080+
/// unrolling proceeds as:
1081+
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1082+
/// vector.insert_strided_slice
1083+
///
1084+
/// Example:
1085+
/// Given a shape cast operation:
1086+
/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1087+
///
1088+
/// and a target unroll shape of <2x4>, the pattern produces:
1089+
///
1090+
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1091+
/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1092+
/// : vector<8x2xf32> to vector<4x2xf32>
1093+
/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1094+
/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1095+
/// : vector<2x4xf32> into vector<4x4xf32>
1096+
/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1097+
/// : vector<8x2xf32> to vector<4x2xf32>
1098+
/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1099+
/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1100+
/// : vector<2x4xf32> into vector<4x4xf32>
1101+
///
1102+
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1103+
UnrollShapeCastPattern(MLIRContext *context,
1104+
const vector::UnrollVectorOptions &options,
1105+
PatternBenefit benefit = 1)
1106+
: OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1107+
options(options) {}
1108+
1109+
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1110+
PatternRewriter &rewriter) const override {
1111+
auto targetShape = getTargetShape(options, shapeCastOp);
1112+
if (!targetShape)
1113+
return failure();
1114+
1115+
VectorType sourceType = shapeCastOp.getSourceVectorType();
1116+
VectorType resultType = shapeCastOp.getResultVectorType();
1117+
ArrayRef<int64_t> sourceShape = sourceType.getShape();
1118+
ArrayRef<int64_t> resultShape = resultType.getShape();
1119+
1120+
if (!isContiguousExtract(*targetShape, resultShape)) {
1121+
return rewriter.notifyMatchFailure(shapeCastOp,
1122+
"Only supports cases where contiguous "
1123+
"extraction is possible");
1124+
}
1125+
1126+
int64_t targetElements = ShapedType::getNumElements(*targetShape);
1127+
1128+
// Calculate the shape to extract from source
1129+
auto extractShape =
1130+
calculateSourceExtractShape(sourceShape, targetElements);
1131+
if (!extractShape) {
1132+
return rewriter.notifyMatchFailure(
1133+
shapeCastOp,
1134+
"cannot extract target number of elements contiguously from source");
1135+
}
1136+
1137+
Location loc = shapeCastOp.getLoc();
1138+
1139+
// Create result vector initialized to zero
1140+
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1141+
rewriter.getZeroAttr(resultType));
1142+
1143+
VectorType targetType =
1144+
VectorType::get(*targetShape, sourceType.getElementType());
1145+
1146+
SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1147+
SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1148+
SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
1149+
SmallVector<int64_t> resultStrides = computeStrides(resultShape);
1150+
1151+
for (SmallVector<int64_t> resultOffsets :
1152+
StaticTileOffsetRange(resultShape, *targetShape)) {
1153+
SmallVector<int64_t> sourceOffsets =
1154+
calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
1155+
Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1156+
loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1157+
extractStrides);
1158+
Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1159+
loc, targetType, sourceChunk);
1160+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1161+
loc, targetChunk, result, resultOffsets, insertStrides);
1162+
}
1163+
1164+
rewriter.replaceOp(shapeCastOp, result);
1165+
return success();
1166+
}
1167+
1168+
private:
1169+
vector::UnrollVectorOptions options;
1170+
};
1171+
10061172
} // namespace
10071173

10081174
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131179
UnrollReductionPattern, UnrollMultiReductionPattern,
10141180
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151181
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016-
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
1017-
options, benefit);
1182+
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
1183+
patterns.getContext(), options, benefit);
10181184
}
10191185

10201186
void mlir::vector::populateVectorToElementsUnrollPatterns(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
496496
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
497497
// CHECK-NOT: arith.addf
498498
// CHECK: return
499+
500+
501+
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
502+
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
503+
return %0 : vector<2x2x4xf32>
504+
}
505+
506+
// CHECK-LABEL: func @shape_cast_1D
507+
// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
508+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
509+
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
510+
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
511+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
512+
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
513+
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
514+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
515+
// CHECK: return %[[I1]] : vector<2x2x4xf32>
516+
517+
518+
func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
519+
%0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
520+
return %0 : vector<4x4xf32>
521+
}
522+
523+
// CHECK-LABEL: func @shape_cast_2D
524+
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
525+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
526+
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
527+
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
528+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
529+
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
530+
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
531+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
532+
// CHECK: return %[[I1]] : vector<4x4xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
178178
.setFilterConstraint([](Operation *op) {
179179
return success(isa<vector::StepOp>(op));
180180
}));
181+
populateVectorUnrollPatterns(
182+
patterns, UnrollVectorOptions()
183+
.setNativeShape(ArrayRef<int64_t>{2, 4})
184+
.setFilterConstraint([](Operation *op) {
185+
return success(isa<vector::ShapeCastOp>(op));
186+
}));
181187
populateVectorUnrollPatterns(
182188
patterns, UnrollVectorOptions()
183189
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

0 commit comments

Comments
 (0)