From d82a741ae91b73dff02dbd19c65e0fd7c6198c9e Mon Sep 17 00:00:00 2001
From: Amy Zhuang
Date: Mon, 2 Dec 2024 23:54:21 +0200
Subject: [PATCH] [mlir][vector] Support index type in ND to 1D vector
 linearization

Currently the index type is not supported because getElementTypeBitWidth
aborts for index types. This patch adds an indexBitWidth input to the vector
linearization patterns so that index vectors can be linearized as well.
---
 .../Vector/Transforms/VectorRewritePatterns.h |  4 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 85 +++++++++++++------
 mlir/test/Dialect/Vector/linearize.mlir       | 15 +++-
 .../Dialect/Vector/TestVectorTransforms.cpp   |  8 +-
 4 files changed, 79 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1..e3c19a078c18b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
 /// vector shuffle operations.
 void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224..f0bf6276f0e65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+                                     unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
+    if (!vecType)
+      return false;
+    bool isIndexTy = vecType.getElementType().isIndex();
+    // Reject index if `indexBitWidth` is not supplied.
+    if (isIndexTy && indexBitWidth == 0)
       return false;
     // There are no dimension to fold if it is a 0-D vector.
     if (vecType.getRank() == 0)
       return false;
     unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
+        vecType.getShape().back() *
+        (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
     if (trailingVecDimBitWidth >= targetBitWidth)
       return false;
   }
   return true;
 }
 
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+                                            unsigned targetBitWidth) {
   VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
+  if (!vecType)
+    return false;
+  bool isIndexTy = vecType.getElementType().isIndex();
+  // Reject index if `indexBitWidth` is not supplied.
+  if (isIndexTy && indexBitWidth == 0)
     return false;
   // There are no dimension to fold if it is a 0-D vector.
   if (vecType.getRank() == 0)
     return false;
   unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+      vecType.getShape().back() *
+      (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
   return trailingVecDimBitWidth <= targetBitWidth;
 }
 
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeConstant(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(extractOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final using OpConversionPattern::OpConversionPattern; LinearizeVectorShuffle( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, @@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final shuffleOp.getV2VectorType().isScalable() || dstType.isScalable()) && "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -329,10 +353,12 @@ struct LinearizeVectorExtract final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtract( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -345,7 +371,8 @@ struct LinearizeVectorExtract final cast(dstTy).isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(extractOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -374,6 +401,7 @@ struct LinearizeVectorExtract final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -392,10 +420,12 @@ struct LinearizeVectorInsert final using OpConversionPattern::OpConversionPattern; LinearizeVectorInsert( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -407,7 +437,7 @@ struct LinearizeVectorInsert final "scalable vectors are not supported."); if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), - targetVectorBitWidth)) + indexBitWidth, targetVectorBitWidth)) return rewriter.notifyMatchFailure( insertOp, "Can't 
          insertOp, "Can't flatten since targetBitWidth < OpSize");
 
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
+          return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
                       ? typeConverter.isLegal(op)
                       : true);
         }
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth,
+    unsigned int targetBitWidth) {
   target.addDynamicallyLegalOp<vector::ShuffleOp>(
       [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+        return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                        targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
                       cast<VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0..fe169d3e16d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // -----
 
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-  // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+  // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+  // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+  // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+  // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
   %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
   return %0 : vector<2x2xindex>
 }
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
 
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
   // ALL: return %[[RES]] : vector<2x[2]xf32>
   return %2 : vector<2x[2]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09..2589782aee144 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+                                 llvm::cl::desc("Bitwidth of the index type"),
+                                 llvm::cl::init(0)};
+
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
     ConversionTarget target(*context);
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
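
Usage sketch (illustrative, not part of the patch): the snippet below shows how a downstream pass could wire the new indexBitWidth argument into the two populate entry points, mirroring the TestVectorLinearize changes above. The helper name runLinearizeWithIndex and the hard-coded 64/128 bit widths are assumptions for the example; a real pass would typically take them from pass options or the module's data layout.

// Minimal sketch, assuming the updated signatures from this patch.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <utility>

using namespace mlir;

static LogicalResult runLinearizeWithIndex(Operation *root) {
  MLIRContext *context = root->getContext();
  TypeConverter typeConverter;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);

  // Assumed widths for this example: treat `index` elements as 64 bits and
  // flatten ops whose trailing vector dimension is narrower than 128 bits.
  // Passing 0 for indexBitWidth keeps the old behavior of leaving index
  // vectors untouched.
  unsigned indexBitWidth = 64;
  unsigned targetBitWidth = 128;

  // Both populate functions now take the index bit width before the target
  // bit width, as in the TestVectorLinearize pass above.
  vector::populateVectorLinearizeTypeConversionsAndLegality(
      typeConverter, patterns, target, indexBitWidth, targetBitWidth);
  vector::populateVectorLinearizeShuffleLikeOpsPatterns(
      typeConverter, patterns, target, indexBitWidth, targetBitWidth);

  return applyPartialConversion(root, target, std::move(patterns));
}

The same behavior is exercised through the test pass with mlir-opt -test-vector-linearize=index-bitwidth=64, as the new RUN line and the INDEX-BW-64 checks in linearize.mlir show.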