diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 68535ae5a7a5c..3ecd585c5a26d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { auto resType = getTypeConverter()->convertType<VectorType>(constOp.getType()); + if (!resType) + return rewriter.notifyMatchFailure(loc, "can't convert return type"); + if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue())) return rewriter.notifyMatchFailure( loc, "Cannot linearize a constant scalable vector that's not a splat"); - if (!resType) - return rewriter.notifyMatchFailure(loc, "can't convert return type"); if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) return rewriter.notifyMatchFailure( loc, "Can't flatten since targetBitWidth <= OpSize"); @@ -459,6 +460,45 @@ struct LinearizeVectorInsert final private: unsigned targetVectorBitWidth; }; + +/// This pattern converts the BitCastOp that works on nD (n > 1) +/// vectors to a BitCastOp that works on linearized vectors. 
+/// Following, +/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> +/// is converted to : +/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> +/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> +/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> +struct LinearizeVectorBitCast final + : public OpConversionPattern<vector::BitCastOp> { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorBitCast( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + LogicalResult + matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = castOp.getLoc(); + auto resType = getTypeConverter()->convertType(castOp.getType()); + if (!resType) + return rewriter.notifyMatchFailure(loc, "can't convert return type."); + + if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + loc, "Can't flatten since targetBitWidth <= OpSize"); + + rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, + adaptor.getSource()); + return mlir::success(); + } + +private: + unsigned targetVectorBitWidth; +}; + } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( @@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional<bool> { - if ((isa<arith::ConstantOp>(op) || + if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) || op->hasTrait<OpTrait::Vectorizable>())) { return (isLessThanTargetBitWidth(op, targetBitWidth) ? 
typeConverter.isLegal(op) @@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return std::nullopt; }); - patterns.add<LinearizeConstant, LinearizeVectorizable>( - typeConverter, patterns.getContext(), targetBitWidth); + patterns + .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>( + typeConverter, patterns.getContext(), targetBitWidth); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 543e76b5b26e0..99b1bbab1eede 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf // ALL-LABEL: func.func @test_extract_strided_slice_1_scalable( // ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { -func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { +func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { // ALL-NOT: vector.shuffle // ALL-NOT: vector.shape_cast // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32> @@ -318,3 +318,68 @@ func.func @test_vector_extract_scalar() { %0 = vector.extract %cst[0] : i32 from vector<4xi32> return } + +// ----- + +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32> +func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16> + + // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to 
vector<4x8xf16> + %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> + return %1 : vector<4x8xf16> +} + +// ----- + +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32> +func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16> + %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> + return %1 : vector<4x4xf16> +} + +// ----- + +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32> +func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16> + %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16> 
+ return %1 : vector<4x[4]xf16> +} + +// ----- +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32> +func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16> + %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> + return %1 : vector<[4]x4xf16> +}