Skip to content

Commit 2e027d8

Browse files
committed
move canonicalizers to folders
1 parent 061f87f commit 2e027d8

File tree

2 files changed

+66
-102
lines changed

2 files changed

+66
-102
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 59 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,6 +3717,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
37173717
return getVector();
37183718
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
37193719
return getResult();
3720+
3721+
Attribute foldInput = adaptor.getVector();
3722+
if (!foldInput) {
3723+
return {};
3724+
}
3725+
3726+
// rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3727+
if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
3728+
DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
3729+
3730+
// rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3731+
if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
3732+
// TODO: Handle non-unit strides when they become available.
3733+
if (hasNonUnitStrides())
3734+
return {};
3735+
3736+
Value sourceVector = getVector();
3737+
auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3738+
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3739+
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3740+
3741+
VectorType sliceVecTy = getType();
3742+
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3743+
int64_t sliceRank = sliceVecTy.getRank();
3744+
3745+
// Expand offsets and sizes to match the vector rank.
3746+
SmallVector<int64_t, 4> offsets(sliceRank, 0);
3747+
copy(getI64SubArray(getOffsets()), offsets.begin());
3748+
3749+
SmallVector<int64_t, 4> sizes(sourceShape);
3750+
copy(getI64SubArray(getSizes()), sizes.begin());
3751+
3752+
// Calculate the slice elements by enumerating all slice positions and
3753+
// linearizing them. The enumeration order is lexicographic which yields a
3754+
// sequence of monotonically increasing linearized position indices.
3755+
auto denseValuesBegin = dense.value_begin<Attribute>();
3756+
SmallVector<Attribute> sliceValues;
3757+
sliceValues.reserve(sliceVecTy.getNumElements());
3758+
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3759+
do {
3760+
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3761+
assert(linearizedPosition < sourceVecTy.getNumElements() &&
3762+
"Invalid index");
3763+
sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3764+
} while (
3765+
succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3766+
3767+
assert(static_cast<int64_t>(sliceValues.size()) ==
3768+
sliceVecTy.getNumElements() &&
3769+
"Invalid number of slice elements");
3770+
return DenseElementsAttr::get(sliceVecTy, sliceValues);
3771+
}
3772+
37203773
return {};
37213774
}
37223775

@@ -3781,98 +3834,6 @@ class StridedSliceConstantMaskFolder final
37813834
}
37823835
};
37833836

3784-
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3785-
class StridedSliceSplatConstantFolder final
3786-
: public OpRewritePattern<ExtractStridedSliceOp> {
3787-
public:
3788-
using OpRewritePattern::OpRewritePattern;
3789-
3790-
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3791-
PatternRewriter &rewriter) const override {
3792-
// Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3793-
// ConstantOp.
3794-
Value sourceVector = extractStridedSliceOp.getVector();
3795-
Attribute vectorCst;
3796-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3797-
return failure();
3798-
3799-
auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3800-
if (!splat)
3801-
return failure();
3802-
3803-
auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3804-
splat.getSplatValue<Attribute>());
3805-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3806-
newAttr);
3807-
return success();
3808-
}
3809-
};
3810-
3811-
// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
3812-
// ConstantOp.
3813-
class StridedSliceNonSplatConstantFolder final
3814-
: public OpRewritePattern<ExtractStridedSliceOp> {
3815-
public:
3816-
using OpRewritePattern::OpRewritePattern;
3817-
3818-
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3819-
PatternRewriter &rewriter) const override {
3820-
// Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3821-
// ConstantOp.
3822-
Value sourceVector = extractStridedSliceOp.getVector();
3823-
Attribute vectorCst;
3824-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3825-
return failure();
3826-
3827-
// The splat case is handled by `StridedSliceSplatConstantFolder`.
3828-
auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3829-
if (!dense || dense.isSplat())
3830-
return failure();
3831-
3832-
// TODO: Handle non-unit strides when they become available.
3833-
if (extractStridedSliceOp.hasNonUnitStrides())
3834-
return failure();
3835-
3836-
auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3837-
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3838-
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3839-
3840-
VectorType sliceVecTy = extractStridedSliceOp.getType();
3841-
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3842-
int64_t sliceRank = sliceVecTy.getRank();
3843-
3844-
// Expand offsets and sizes to match the vector rank.
3845-
SmallVector<int64_t, 4> offsets(sliceRank, 0);
3846-
copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3847-
3848-
SmallVector<int64_t, 4> sizes(sourceShape);
3849-
copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3850-
3851-
// Calculate the slice elements by enumerating all slice positions and
3852-
// linearizing them. The enumeration order is lexicographic which yields a
3853-
// sequence of monotonically increasing linearized position indices.
3854-
auto denseValuesBegin = dense.value_begin<Attribute>();
3855-
SmallVector<Attribute> sliceValues;
3856-
sliceValues.reserve(sliceVecTy.getNumElements());
3857-
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3858-
do {
3859-
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3860-
assert(linearizedPosition < sourceVecTy.getNumElements() &&
3861-
"Invalid index");
3862-
sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3863-
} while (
3864-
succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3865-
3866-
assert(static_cast<int64_t>(sliceValues.size()) ==
3867-
sliceVecTy.getNumElements() &&
3868-
"Invalid number of slice elements");
3869-
auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3870-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3871-
newAttr);
3872-
return success();
3873-
}
3874-
};
3875-
38763837
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
38773838
// BroadcastOp(ExtractStrideSliceOp).
38783839
class StridedSliceBroadcast final
@@ -4016,8 +3977,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
40163977
RewritePatternSet &results, MLIRContext *context) {
40173978
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
40183979
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4019-
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4020-
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3980+
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
40213981
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
40223982
context);
40233983
}
@@ -5657,10 +5617,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56575617

56585618
// shape_cast(constant) -> constant
56595619
if (auto splatAttr =
5660-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
5661-
return DenseElementsAttr::get(resultType,
5662-
splatAttr.getSplatValue<Attribute>());
5663-
}
5620+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5621+
return splatAttr.reshape(getType());
56645622

56655623
// shape_cast(poison) -> poison
56665624
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6004,10 +5962,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
60045962

60055963
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
60065964
// Eliminate splat constant transpose ops.
6007-
if (auto attr =
6008-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
6009-
if (attr.isSplat())
6010-
return attr.reshape(getResultVectorType());
5965+
if (auto splat =
5966+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
5967+
return splat.reshape(getResultVectorType());
60115968

60125969
// Eliminate identity transpose ops. This happens when the dimensions of the
60135970
// input vector remain in their original order after the transpose operation.

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
11211121
return %0, %2 : vector<4x8xf32>, vector<2xi32>
11221122
}
11231123

1124+
// -----
1125+
11241126
// CHECK-LABEL: func @bitcast_f16_to_f32
11251127
// bit pattern: 0x40004000
11261128
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
11351137
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
11361138
}
11371139

1140+
// -----
1141+
11381142
// CHECK-LABEL: func @bitcast_i8_to_i32
11391143
// bit pattern: 0xA0A0A0A0
11401144
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
17101714
}
17111715

17121716
// -----
1717+
17131718
// CHECK-LABEL: func.func @vector_multi_reduction_scalable(
17141719
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
17151720
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
22512256
return %0 : vector<8x4xf32>
22522257
}
22532258

2259+
// -----
2260+
22542261
// CHECK-LABEL: func @transpose_splat2(
22552262
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
22562263
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>

0 commit comments

Comments
 (0)