Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is indexBitWidth and targetBitWidth different here? Aren't they representing the same thing?

Copy link
Author

@ayzhuang ayzhuang Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcaballe Currently index type is not supported because we can't use getElementTypeBitWidth to get the bit width of index type. I add indexBitWidth argument to supply the bit width of index type. When it has non zero value and targetBitWidth is big enough, we can linearize vector of indices. Example: %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> to %0 = arith.addi %arg0, %arg1 : vector<4xindex>.


/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);

} // namespace vector
} // namespace mlir
Expand Down
85 changes: 59 additions & 26 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,44 @@

using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
if (!vecType)
return false;
bool isIndexTy = vecType.getElementType().isIndex();
// Reject index if `indexBitWidth` is not supplied.
if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
vecType.getShape().back() *
(isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}

static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
if (!vecType)
return false;
bool isIndexTy = vecType.getElementType().isIndex();
// Reject index if `indexBitWidth` is not supplied.
if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
vecType.getShape().back() *
(isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
return trailingVecDimBitWidth <= targetBitWidth;
}

Expand All @@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeConstant(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {

if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
Expand All @@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};

Expand All @@ -103,14 +116,16 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
Expand All @@ -123,6 +138,7 @@ struct LinearizeVectorizable final
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};

Expand All @@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}

LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
Expand All @@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");

Expand Down Expand Up @@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};

Expand All @@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}

LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
Expand All @@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable()) &&
"scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

Expand Down Expand Up @@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};

Expand All @@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");

Expand Down Expand Up @@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};

Expand All @@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
"scalable vectors are not supported.");

if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
targetVectorBitWidth))
indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");

Expand Down Expand Up @@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
}

private:
unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {

typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
Expand All @@ -488,29 +519,31 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
return std::nullopt;
});

patterns.add<LinearizeConstant, LinearizeVectorizable>(
typeConverter, patterns.getContext(), targetBitWidth);
typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned int targetBitWidth) {
ConversionTarget &target, unsigned indexBitWidth,
unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
[=](vector::ShuffleOp shuffleOp) -> bool {
return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
targetBitWidth)
? (typeConverter.isLegal(shuffleOp) &&
cast<mlir::VectorType>(shuffleOp.getResult().getType())
.getRank() == 1)
: true;
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
15 changes: 12 additions & 3 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64

// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
Expand All @@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>

// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>

// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>

// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
Expand Down Expand Up @@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>

// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>

// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>

// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
Expand Down Expand Up @@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>

// -----

// ALL-LABEL: test_index_no_linearize
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
// ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
// ALL-LABEL: test_index_linearize
func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
// DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
// BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
// INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
Expand Down Expand Up @@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32

// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}
Expand Down
8 changes: 6 additions & 2 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,10 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}

Option<unsigned> indexBitwidth{*this, "index-bitwidth",
llvm::cl::desc("Bitwidth of the index type"),
llvm::cl::init(0)};

Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
Expand All @@ -866,9 +870,9 @@ struct TestVectorLinearize final
ConversionTarget target(*context);

vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter, patterns, target, targetVectorBitwidth);
typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
typeConverter, patterns, target, targetVectorBitwidth);
typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
Expand Down
Loading