From c44404b91ff2e99b56fecc613ce21268f6f43109 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Sat, 11 Oct 2025 00:09:38 +0000 Subject: [PATCH 1/4] [MLIR][Conversion] Vector to LLVM: Remove unneeded vectorshuffle if vector.broadcast source is a scalar and target is a single element 1D vector. --- .../Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +++++- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5355909b62a7f..2bab94f82723e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,12 +1723,16 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle`. + // For 1-d vector, we additionally do a `vectorshuffle` if vector width > 1. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); int64_t width = cast(broadcast.getType()).getDimSize(0); + if (width == 1) { + rewriter.replaceOp(broadcast, v); + return success(); + } SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888854ea7..f704b8dba5eed 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -76,6 +76,17 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { // ----- +func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1xf32> + return %0 : vector<1xf32> +} +// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]] +// CHECK: return %[[T0]] : vector<1xf32> + +// ----- + func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> { %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32> return %0 : vector<[2]xf32> From f75fe906f7e388c2e7500ec0be4059c394dbcf34 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 13 Oct 2025 10:48:11 -0700 Subject: [PATCH 2/4] Update comment. --- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 2bab94f82723e..46f9931ae318f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,11 +1723,12 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle` if vector width > 1. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + // For 1-d vector, if vector width > 1, we additionally do a + // `vector shuffle` int64_t width = cast(broadcast.getType()).getDimSize(0); if (width == 1) { rewriter.replaceOp(broadcast, v); From b0d476edbf5e06dbdfb313657d4670c6f6b2d17d Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 13 Oct 2025 12:28:00 -0700 Subject: [PATCH 3/4] Update test. --- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f704b8dba5eed..d669a3bac3336 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -83,6 +83,7 @@ func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> { // CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32 // CHECK-SAME: %[[A:.*]]: f32) // CHECK: %[[T0:.*]] = llvm.insertelement %[[A]] +// CHECK-NOT: llvm.shufflevector // CHECK: return %[[T0]] : vector<1xf32> // ----- From a33e1980c944c95a413b25b96398d14977d1773c Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 14 Oct 2025 20:57:44 +0000 Subject: [PATCH 4/4] Add op fold pattern for llvm.shufflevector and use it for optimizing single element vector. --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 + .../VectorToLLVM/ConvertVectorToLLVM.cpp | 12 ++++-------- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 ++++++++++++++ mlir/test/Dialect/LLVMIR/canonicalize.mlir | 11 +++++++++++ 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9753dca67c23d..b67e4cb435e55 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", custom(ref(type($v1)), type($res), ref($mask)) }]; + let hasFolder = 1; let hasVerifier = 1; string llvmInstName = "ShuffleVector"; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 46f9931ae318f..41d8d532757ad 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1727,18 +1727,14 @@ struct VectorBroadcastScalarToLowRankLowering LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); - // For 1-d vector, if vector width > 1, we additionally do a - // `vector shuffle` + // For 1-d vector, we additionally do a `shufflevector`. int64_t width = cast(broadcast.getType()).getDimSize(0); - if (width == 1) { - rewriter.replaceOp(broadcast, v); - return success(); - } SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. - rewriter.replaceOpWithNewOp(broadcast, v, poison, - zeroValues); + auto shuffle = rewriter.createOrFold( + broadcast.getLoc(), v, poison, zeroValues); + rewriter.replaceOp(broadcast, shuffle); return success(); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5d08cccb4faab..da49b17eee293 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2824,6 +2824,20 @@ LogicalResult ShuffleVectorOp::verify() { return success(); } +// Folding for shufflevector op when v1 is single element 1D vector +// and the mask is a single zero. OpFoldResult will be v1 in this case. +OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) { + // Check if operand 0 is a single element vector. + auto vecType = llvm::dyn_cast(getV1().getType()); + if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1) + return {}; + // Check if the mask is a single zero. + // Note: The mask is guaranteed to be non-empty. + if (getMask().size() != 1 || getMask()[0] != 0) + return {}; + return getV1(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 8accf6e263863..755e3a3a5fa09 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { // ----- +// CHECK-LABEL: fold_shufflevector +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32> +llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> { + // CHECK-NOT: llvm.shufflevector + %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32> + // CHECK: llvm.return %[[ARG1]] + llvm.return %c : vector<1xf32> +} + +// ----- + // Check that LLVM constants participate in cross-dialect constant folding. The // resulting constant is created in the arith dialect because the last folded // operation belongs to it.