diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index a59f06f3c1ef1..e3c19a078c18b 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns( /// the ops to get converted properly. void populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth); + ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth); /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D /// vector shuffle operations. void populateVectorLinearizeShuffleLikeOpsPatterns( const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth); + ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth); } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 757631944f224..f0bf6276f0e65 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -25,34 +25,44 @@ using namespace mlir; -static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { +static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth, + unsigned targetBitWidth) { auto resultTypes = op->getResultTypes(); for (auto resType : resultTypes) { VectorType vecType = dyn_cast(resType); - // Reject index since getElementTypeBitWidth will abort for Index types. - if (!vecType || vecType.getElementType().isIndex()) + if (!vecType) + return false; + bool isIndexTy = vecType.getElementType().isIndex(); + // Reject index if `indexBitWidth` is not supplied. + if (isIndexTy && indexBitWidth == 0) return false; // There are no dimension to fold if it is a 0-D vector. if (vecType.getRank() == 0) return false; unsigned trailingVecDimBitWidth = - vecType.getShape().back() * vecType.getElementTypeBitWidth(); + vecType.getShape().back() * + (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth()); if (trailingVecDimBitWidth >= targetBitWidth) return false; } return true; } -static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { +static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth, + unsigned targetBitWidth) { VectorType vecType = dyn_cast(t); - // Reject index since getElementTypeBitWidth will abort for Index types. - if (!vecType || vecType.getElementType().isIndex()) + if (!vecType) + return false; + bool isIndexTy = vecType.getElementType().isIndex(); + // Reject index if `indexBitWidth` is not supplied. + if (isIndexTy && indexBitWidth == 0) return false; // There are no dimension to fold if it is a 0-D vector. if (vecType.getRank() == 0) return false; unsigned trailingVecDimBitWidth = - vecType.getShape().back() * vecType.getElementTypeBitWidth(); + vecType.getShape().back() * + (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth()); return trailingVecDimBitWidth <= targetBitWidth; } @@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LinearizeConstant( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern { if (!resType) return rewriter.notifyMatchFailure(loc, "can't convert return type"); - if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth)) return rewriter.notifyMatchFailure( loc, "Can't flatten since targetBitWidth <= OpSize"); auto dstElementsAttr = dyn_cast(constOp.getValue()); @@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern { } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -103,14 +116,16 @@ struct LinearizeVectorizable final public: LinearizeVectorizable( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth)) return rewriter.notifyMatchFailure( op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); FailureOr newOp = @@ -123,6 +138,7 @@ struct LinearizeVectorizable final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtractStridedSlice( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, @@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(extractOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final using OpConversionPattern::OpConversionPattern; LinearizeVectorShuffle( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, @@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final shuffleOp.getV2VectorType().isScalable() || dstType.isScalable()) && "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -329,10 +353,12 @@ struct LinearizeVectorExtract final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtract( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -345,7 +371,8 @@ struct LinearizeVectorExtract final cast(dstTy).isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + if (!isLessThanTargetBitWidth(extractOp, indexBitWidth, + targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -374,6 +401,7 @@ struct LinearizeVectorExtract final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; @@ -392,10 +420,12 @@ struct LinearizeVectorInsert final using OpConversionPattern::OpConversionPattern; LinearizeVectorInsert( const TypeConverter &typeConverter, MLIRContext *context, + unsigned indexBitWidth = 0, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) { + } LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -407,7 +437,7 @@ struct LinearizeVectorInsert final "scalable vectors are not supported."); if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), - targetVectorBitWidth)) + indexBitWidth, targetVectorBitWidth)) return rewriter.notifyMatchFailure( insertOp, "Can't flatten since targetBitWidth < OpSize"); @@ -457,13 +487,14 @@ struct LinearizeVectorInsert final } private: + unsigned indexBitWidth; unsigned targetVectorBitWidth; }; } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth) { + ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) { typeConverter.addConversion([](VectorType type) -> std::optional { if (!isLinearizableVector(type)) @@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( [=](Operation *op) -> std::optional { if ((isa(op) || op->hasTrait())) { - return (isLessThanTargetBitWidth(op, targetBitWidth) + return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth) ? typeConverter.isLegal(op) : true); } @@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( }); patterns.add( - typeConverter, patterns.getContext(), targetBitWidth); + typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned int targetBitWidth) { + ConversionTarget &target, unsigned indexBitWidth, + unsigned int targetBitWidth) { target.addDynamicallyLegalOp( [=](vector::ShuffleOp shuffleOp) -> bool { - return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) + return isLessThanTargetBitWidth(shuffleOp, indexBitWidth, + targetBitWidth) ? (typeConverter.isLegal(shuffleOp) && cast(shuffleOp.getResult().getType()) .getRank() == 1) @@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( }); patterns.add( - typeConverter, patterns.getContext(), targetBitWidth); + typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 543e76b5b26e0..fe169d3e16d68 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0 +// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64 // ALL-LABEL: test_linearize // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) @@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32> + + // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32> %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> // DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> @@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32> // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32> + + // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32> %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> // DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32> @@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32> // ----- -// ALL-LABEL: test_index_no_linearize -func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> { - // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> +// ALL-LABEL: test_index_linearize +func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> { + // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex> %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> return %0 : vector<2x2xindex> } @@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32 // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32> // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32> + // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32> // ALL: return %[[RES]] : vector<2x[2]xf32> return %2 : vector<2x[2]xf32> } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f67a24755ac09..2589782aee144 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -853,6 +853,10 @@ struct TestVectorLinearize final registry.insert(); } + Option indexBitwidth{*this, "index-bitwidth", + llvm::cl::desc("Bitwidth of the index type"), + llvm::cl::init(0)}; + Option targetVectorBitwidth{ *this, "target-vector-bitwidth", llvm::cl::desc( @@ -866,9 +870,9 @@ struct TestVectorLinearize final ConversionTarget target(*context); vector::populateVectorLinearizeTypeConversionsAndLegality( - typeConverter, patterns, target, targetVectorBitwidth); + typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth); vector::populateVectorLinearizeShuffleLikeOpsPatterns( - typeConverter, patterns, target, targetVectorBitwidth); + typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure();