From d067a5ab18604996b9290e98701e3d8ac04efe7b Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 15 Jan 2025 19:10:38 +0000 Subject: [PATCH 1/6] add linearize pattern for bitcast --- .../Vector/Transforms/VectorLinearize.cpp | 37 +++++++++++++++++-- mlir/test/Dialect/Vector/linearize.mlir | 21 ++++++++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 68535ae5a7a5c..b450ea91fef65 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern { auto resType = getTypeConverter()->convertType(constOp.getType()); + if (!resType) + return rewriter.notifyMatchFailure(loc, "can't convert return type"); + if (resType.isScalable() && !isa(constOp.getValue())) return rewriter.notifyMatchFailure( loc, "Cannot linearize a constant scalable vector that's not a splat"); - if (!resType) - return rewriter.notifyMatchFailure(loc, "can't convert return type"); if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) return rewriter.notifyMatchFailure( loc, "Can't flatten since targetBitWidth <= OpSize"); @@ -459,6 +460,35 @@ struct LinearizeVectorInsert final private: unsigned targetVectorBitWidth; }; + +struct LinearizeVectorBitCast final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorBitCast( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + LogicalResult + matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = castOp.getLoc(); + auto resType = getTypeConverter()->convertType(castOp.getType()); + if (!resType) + return rewriter.notifyMatchFailure(loc, "can't convert return type."); + + if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + loc, "Can't flatten since targetBitWidth <= OpSize"); + + rewriter.replaceOpWithNewOp(castOp, resType, adaptor.getSource()); + return mlir::success(); + } +private: + unsigned targetVectorBitWidth; +}; + } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( @@ -486,6 +516,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { if ((isa(op) || + isa(op) || op->hasTrait())) { return (isLessThanTargetBitWidth(op, targetBitWidth) ? typeConverter.isLegal(op) @@ -494,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return std::nullopt; }); - patterns.add( + patterns.add( typeConverter, patterns.getContext(), targetBitWidth); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 543e76b5b26e0..0358c2637f72b 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf // ALL-LABEL: func.func @test_extract_strided_slice_1_scalable( // ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { -func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { +func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { // ALL-NOT: vector.shuffle // ALL-NOT: vector.shape_cast // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32> @@ -318,3 +318,22 @@ func.func @test_vector_extract_scalar() { %0 = vector.extract %cst[0] : i32 from vector<4xi32> return } + +// ----- + +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>) +func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> { + + // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32> + // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16> + // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16> + + // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32> + // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16> + // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16> + + // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16> + %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16> + return %1 : vector<4x2xf16> +} From 5f358831c7f25bc385a3d00f8a340766adbf1170 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 15 Jan 2025 19:36:22 +0000 Subject: [PATCH 2/6] code format --- .../Dialect/Vector/Transforms/VectorLinearize.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b450ea91fef65..a89d2872e2434 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -482,9 +482,11 @@ struct LinearizeVectorBitCast final return rewriter.notifyMatchFailure( loc, "Can't flatten since targetBitWidth <= OpSize"); - rewriter.replaceOpWithNewOp(castOp, resType, adaptor.getSource()); + rewriter.replaceOpWithNewOp(castOp, resType, + adaptor.getSource()); return mlir::success(); } + private: unsigned targetVectorBitWidth; }; @@ -515,8 +517,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { - if ((isa(op) || - isa(op) || + if ((isa(op) || isa(op) || op->hasTrait())) { return (isLessThanTargetBitWidth(op, targetBitWidth) ? typeConverter.isLegal(op) @@ -525,8 +526,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return std::nullopt; }); - patterns.add( - typeConverter, patterns.getContext(), targetBitWidth); + patterns + .add( + typeConverter, patterns.getContext(), targetBitWidth); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( From f0b3ae0d45ebaea3659fed486a8cc973673814f6 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 16 Jan 2025 15:18:09 +0000 Subject: [PATCH 3/6] update tests --- mlir/test/Dialect/Vector/linearize.mlir | 37 +++++++++++++++++-------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 0358c2637f72b..bab5c6c15ed8f 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -322,18 +322,31 @@ func.func @test_vector_extract_scalar() { // ----- // ALL-LABEL: test_vector_bitcast -// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>) -func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> { - - // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32> - // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16> - // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16> +// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32> +func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { + // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32> + // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<16xf32> to vector<32xf16> + // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<32xf16> to vector<4x8xf16> + + // BW-128: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> + // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> + %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> + return %1 : vector<4x8xf16> +} - // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32> - // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16> - // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16> +// ----- - // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16> - %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16> - return %1 : vector<4x2xf16> +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32> +func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { + // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16> + // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16> + // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16> + // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16> + + // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16> + %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> + return %1 : vector<4x4xf16> } From 25a8f39e7841d28c459a390175ea50cebf737b74 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 16 Jan 2025 19:34:32 +0000 Subject: [PATCH 4/6] add test for scalable vector --- .../Vector/Transforms/VectorLinearize.cpp | 8 +++++ mlir/test/Dialect/Vector/linearize.mlir | 33 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index a89d2872e2434..3ecd585c5a26d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -461,6 +461,14 @@ struct LinearizeVectorInsert final unsigned targetVectorBitWidth; }; +/// This pattern converts the BitCastOp that works on nD (n > 1) +/// vectors to a BitCastOp that works on linearized vectors. +/// Following, +/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> +/// is converted to : +/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> +/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> +/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index bab5c6c15ed8f..de757fb9e4c1a 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -350,3 +350,36 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> return %1 : vector<4x4xf16> } + +// ----- + +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32> +func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { + // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16> + // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16> + + // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16> + %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16> + return %1 : vector<4x[4]xf16> +} + +// ----- +// ALL-LABEL: test_vector_bitcast +// ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32> +func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { + // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16> + // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16> + + // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16> + %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> + return %1 : vector<[4]x4xf16> +} From ea2e518149180e77c2652fff69b2c140b9cec95b Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 27 Jan 2025 15:14:26 +0000 Subject: [PATCH 5/6] fix naming in tests --- mlir/test/Dialect/Vector/linearize.mlir | 58 ++++++++++++------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index de757fb9e4c1a..8279aac07245d 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -324,12 +324,12 @@ func.func @test_vector_extract_scalar() { // ALL-LABEL: test_vector_bitcast // ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32> func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { - // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32> - // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<16xf32> to vector<32xf16> - // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<32xf16> to vector<4x8xf16> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16> - // BW-128: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> - // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> + // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> return %1 : vector<4x8xf16> } @@ -339,14 +339,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { // ALL-LABEL: test_vector_bitcast // ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32> func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { - // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> - // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16> - // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16> - // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> - // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16> - // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16> - - // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16> %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> return %1 : vector<4x4xf16> } @@ -356,14 +356,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { // ALL-LABEL: test_vector_bitcast // ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32> func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { - // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> - // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> - // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16> - // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> - // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> - // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16> - - // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16> %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16> return %1 : vector<4x[4]xf16> } @@ -372,14 +372,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { // ALL-LABEL: test_vector_bitcast // ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32> func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { - // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> - // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> - // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16> - // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> - // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16> - // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16> - - // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> + // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> + + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16> %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> return %1 : vector<[4]x4xf16> } From 054b70061c77cc012fb8742d9b2b4c0c6e9fde7a Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 27 Jan 2025 17:17:28 +0000 Subject: [PATCH 6/6] fix naming in tests --- mlir/test/Dialect/Vector/linearize.mlir | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 8279aac07245d..99b1bbab1eede 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -322,14 +322,14 @@ func.func @test_vector_extract_scalar() { // ----- // ALL-LABEL: test_vector_bitcast -// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32> +// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32> func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { - // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32> // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16> // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16> - // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> - // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16> + // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16> %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> return %1 : vector<4x8xf16> } @@ -337,16 +337,16 @@ func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { // ----- // ALL-LABEL: test_vector_bitcast -// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32> +// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32> func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { - // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32> // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> - // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32> // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> - // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16> %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> return %1 : vector<4x4xf16> } @@ -354,32 +354,32 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { // ----- // ALL-LABEL: test_vector_bitcast -// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32> +// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32> func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { - // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32> // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> - // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32> // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> - // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16> %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16> return %1 : vector<4x[4]xf16> } // ----- // ALL-LABEL: test_vector_bitcast -// ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32> +// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32> func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { - // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32> // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> - // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32> + // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32> // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> - // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16> + // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16> %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> return %1 : vector<[4]x4xf16> }