Skip to content

Commit 856de05

Browse files
authored
[MLIR][Conversion] Vector to LLVM: Remove unneeded vector shuffle (#162946)
when the vector.broadcast source is a scalar and the target is a single-element 1-D vector.
1 parent 2b135b9 commit 856de05

File tree

5 files changed

+42
-3
lines changed

5 files changed

+42
-3
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
973973
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
974974
}];
975975

976+
let hasFolder = 1;
976977
let hasVerifier = 1;
977978

978979
string llvmInstName = "ShuffleVector";

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
17231723
return success();
17241724
}
17251725

1726-
// For 1-d vector, we additionally do a `vectorshuffle`.
17271726
auto v =
17281727
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
17291728
poison, adaptor.getSource(), zero);
17301729

1730+
// For 1-d vector, we additionally do a `shufflevector`.
17311731
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
17321732
SmallVector<int32_t> zeroValues(width, 0);
17331733

17341734
// Shuffle the value across the desired number of elements.
1735-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
1736-
zeroValues);
1735+
auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1736+
broadcast.getLoc(), v, poison, zeroValues);
1737+
rewriter.replaceOp(broadcast, shuffle);
17371738
return success();
17381739
}
17391740
};

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
28262826
return success();
28272827
}
28282828

2829+
// Folding for shufflevector op when v1 is single element 1D vector
2830+
// and the mask is a single zero. OpFoldResult will be v1 in this case.
2831+
OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
2832+
// Check if operand 0 is a single element vector.
2833+
auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
2834+
if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
2835+
return {};
2836+
// Check if the mask is a single zero.
2837+
// Note: The mask is guaranteed to be non-empty.
2838+
if (getMask().size() != 1 || getMask()[0] != 0)
2839+
return {};
2840+
return getV1();
2841+
}
2842+
28292843
//===----------------------------------------------------------------------===//
28302844
// Implementations for LLVM::LLVMFuncOp.
28312845
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
7676

7777
// -----
7878

79+
// Broadcasting a scalar to vector<1xf32> is expected to lower to a single
// insertelement whose result is returned directly, with no shuffle after it.
func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> {
80+
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
81+
return %0 : vector<1xf32>
82+
}
83+
// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32
84+
// CHECK-SAME: %[[A:.*]]: f32)
85+
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
86+
// CHECK-NOT: llvm.shufflevector
87+
// CHECK: return %[[T0]] : vector<1xf32>
88+
89+
// -----
90+
7991
func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
8092
%0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
8193
return %0 : vector<[2]xf32>

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
235235

236236
// -----
237237

238+
// A shuffle of two vector<1xf32> operands with mask [0] selects the only
// element of the first operand, so it is expected to fold to that operand.
// CHECK-LABEL: fold_shufflevector
239+
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32>
240+
llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> {
241+
// CHECK-NOT: llvm.shufflevector
242+
%c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
243+
// CHECK: llvm.return %[[ARG1]]
244+
llvm.return %c : vector<1xf32>
245+
}
246+
247+
// -----
248+
238249
// Check that LLVM constants participate in cross-dialect constant folding. The
239250
// resulting constant is created in the arith dialect because the last folded
240251
// operation belongs to it.

0 commit comments

Comments
 (0)