Skip to content

Commit a33e198

Browse files
committed
Add op fold pattern for llvm.shufflevector and use it for optimizing single
element vector.
1 parent b0d476e commit a33e198

File tree

4 files changed

+30
-8
lines changed

4 files changed

+30
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
973973
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
974974
}];
975975

976+
let hasFolder = 1;
976977
let hasVerifier = 1;
977978

978979
string llvmInstName = "ShuffleVector";

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,18 +1727,14 @@ struct VectorBroadcastScalarToLowRankLowering
17271727
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
17281728
poison, adaptor.getSource(), zero);
17291729

1730-
// For 1-d vector, if vector width > 1, we additionally do a
1731-
// `vector shuffle`
1730+
// For 1-d vector, we additionally do a `shufflevector`.
17321731
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1733-
if (width == 1) {
1734-
rewriter.replaceOp(broadcast, v);
1735-
return success();
1736-
}
17371732
SmallVector<int32_t> zeroValues(width, 0);
17381733

17391734
// Shuffle the value across the desired number of elements.
1740-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
1741-
zeroValues);
1735+
auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1736+
broadcast.getLoc(), v, poison, zeroValues);
1737+
rewriter.replaceOp(broadcast, shuffle);
17421738
return success();
17431739
}
17441740
};

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2824,6 +2824,20 @@ LogicalResult ShuffleVectorOp::verify() {
28242824
return success();
28252825
}
28262826

2827+
// Folding for shufflevector op when v1 is single element 1D vector
2828+
// and the mask is a single zero. OpFoldResult will be v1 in this case.
2829+
OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
2830+
// Check if operand 0 is a single element vector.
2831+
auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
2832+
if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
2833+
return {};
2834+
// Check if the mask is a single zero.
2835+
// Note: The mask is guaranteed to be non-empty.
2836+
if (getMask().size() != 1 || getMask()[0] != 0)
2837+
return {};
2838+
return getV1();
2839+
}
2840+
28272841
//===----------------------------------------------------------------------===//
28282842
// Implementations for LLVM::LLVMFuncOp.
28292843
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
235235

236236
// -----
237237

238+
// CHECK-LABEL: fold_shufflevector
239+
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32>
240+
llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> {
241+
// CHECK-NOT: llvm.shufflevector
242+
%c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
243+
// CHECK: llvm.return %[[ARG1]]
244+
llvm.return %c : vector<1xf32>
245+
}
246+
247+
// -----
248+
238249
// Check that LLVM constants participate in cross-dialect constant folding. The
239250
// resulting constant is created in the arith dialect because the last folded
240251
// operation belongs to it.

0 commit comments

Comments
 (0)