Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 65 additions & 102 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
}
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
return {};
}

Expand Down Expand Up @@ -3717,6 +3719,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();

// Constant folding of the slice. A null attribute means no constant value is
// known for the source vector operand, so there is nothing to fold.
Attribute foldInput = adaptor.getVector();
if (!foldInput)
  return {};

// rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
// Every element of a splat is the same value, so the slice is a splat of that
// value with the result type, whatever the offsets, sizes and strides are.
if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

// rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
  // TODO: Handle non-unit strides when they become available.
  if (hasNonUnitStrides())
    return {};

  Value sourceVector = getVector();
  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

  VectorType sliceVecTy = getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t sliceRank = sliceVecTy.getRank();

  // Expand offsets to match the vector rank. Dimensions beyond those listed
  // in the op's offsets keep the initial value 0.
  SmallVector<int64_t, 4> offsets(sliceRank, 0);
  copy(getI64SubArray(getOffsets()), offsets.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them. The enumeration order is lexicographic which yields a
  // sequence of monotonically increasing linearized position indices.
  auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (
      succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}

return {};
}

Expand Down Expand Up @@ -3781,98 +3836,6 @@ class StridedSliceConstantMaskFolder final
}
};

// Canonicalization pattern:
//   ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
//
// When the sliced vector is a constant whose value is a splat, the op is
// replaced by an `arith.constant` holding the same splat value with the
// slice's result type.
class StridedSliceSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // The source must be produced by a constant; capture its attribute.
    Attribute constantValue;
    if (!matchPattern(extractStridedSliceOp.getVector(),
                      m_Constant(&constantValue)))
      return failure();

    // Only splat constants are handled by this pattern.
    auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(constantValue);
    if (!splatAttr)
      return failure();

    // Re-create the splat with the type of the extracted slice.
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        extractStridedSliceOp,
        SplatElementsAttr::get(extractStridedSliceOp.getType(),
                               splatAttr.getSplatValue<Attribute>()));
    return success();
  }
};

// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
//
// The slice elements are gathered from the constant's dense element list and
// the op is replaced by an `arith.constant` of the slice type. The pattern
// bails out (returns failure) when the source is not a constant, when the
// constant is not a dense attribute or is a splat, and when the op has
// non-unit strides.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
    // ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    // The splat case is handled by `StridedSliceSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();

    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

    VectorType sliceVecTy = extractStridedSliceOp.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t sliceRank = sliceVecTy.getRank();

    // Expand offsets to match the vector rank. Dimensions beyond those listed
    // in the op's offsets keep the initial value 0.
    SmallVector<int64_t, 4> offsets(sliceRank, 0);
    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());

    // Calculate the slice elements by enumerating all slice positions and
    // linearizing them. The enumeration order is lexicographic which yields a
    // sequence of monotonically increasing linearized position indices.
    auto denseValuesBegin = dense.value_begin<Attribute>();
    SmallVector<Attribute> sliceValues;
    sliceValues.reserve(sliceVecTy.getNumElements());
    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
    do {
      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
      assert(linearizedPosition < sourceVecTy.getNumElements() &&
             "Invalid index");
      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
    } while (
        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

    assert(static_cast<int64_t>(sliceValues.size()) ==
               sliceVecTy.getNumElements() &&
           "Invalid number of slice elements");
    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
Expand Down Expand Up @@ -4016,8 +3979,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
context);
}
Expand Down Expand Up @@ -5654,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
return DenseElementsAttr::get(resultType,
splatAttr.getSplatValue<Attribute>());
}
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType());

// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
Expand Down Expand Up @@ -6001,10 +5961,13 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto attr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
if (attr.isSplat())
return attr.reshape(getResultVectorType());
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
return splat.reshape(getResultVectorType());

// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
return ub::PoisonAttr::get(getContext());

// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
return %0, %2 : vector<4x8xf32>, vector<2xi32>
}

// -----

// CHECK-LABEL: func @bitcast_f16_to_f32
// bit pattern: 0x40004000
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
Expand All @@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
}

// -----

// CHECK-LABEL: func @bitcast_i8_to_i32
// bit pattern: 0xA0A0A0A0
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
Expand Down Expand Up @@ -1176,6 +1180,28 @@ func.func @broadcast_folding2() -> vector<4x16xi32> {

// -----

// Broadcasting a poison vector is expected to fold to a poison value of the
// broadcast result type.
// CHECK-LABEL: broadcast_poison
// CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8>
// CHECK: return %[[POISON]] : vector<4x6xi8>
func.func @broadcast_poison() -> vector<4x6xi8> {
  %src = ub.poison : vector<6xi8>
  %result = vector.broadcast %src : vector<6xi8> to vector<4x6xi8>
  return %result : vector<4x6xi8>
}

// -----

// Broadcasting a splat constant is expected to fold to a splat constant of
// the broadcast result type.
// CHECK-LABEL: broadcast_splat_constant
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
// CHECK: return %[[CONST]] : vector<4x6xi8>
func.func @broadcast_splat_constant() -> vector<4x6xi8> {
  %src = arith.constant dense<1> : vector<6xi8>
  %result = vector.broadcast %src : vector<6xi8> to vector<4x6xi8>
  return %result : vector<4x6xi8>
}

// -----

// CHECK-LABEL: @fold_consecutive_broadcasts(
// CHECK-SAME: %[[ARG0:.*]]: i32
// CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
Expand Down Expand Up @@ -1710,6 +1736,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
}

// -----

// CHECK-LABEL: func.func @vector_multi_reduction_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>,
Expand Down Expand Up @@ -2251,6 +2278,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
return %0 : vector<8x4xf32>
}

// -----

// CHECK-LABEL: func @transpose_splat2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
Expand All @@ -2264,6 +2293,17 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {

// -----

// Transposing a poison vector is expected to fold to a poison value of the
// transposed result type.
// CHECK-LABEL: transpose_poison
// CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8>
// CHECK: return %[[POISON]] : vector<4x6xi8>
func.func @transpose_poison() -> vector<4x6xi8> {
  %src = ub.poison : vector<6x4xi8>
  %result = vector.transpose %src, [1, 0] : vector<6x4xi8> to vector<4x6xi8>
  return %result : vector<4x6xi8>
}

// -----

// CHECK-LABEL: func.func @insert_1d_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
Expand Down