Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
}];

let hasFolder = 1;
let hasVerifier = 1;

string llvmInstName = "ShuffleVector";
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}

// For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);

// For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);

// Shuffle the value across the desired number of elements.
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
zeroValues);
auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
broadcast.getLoc(), v, poison, zeroValues);
rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2824,6 +2824,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}

// Folding for shufflevector op when v1 is single element 1D vector
// and the mask is a single zero. OpFoldResult will be v1 in this case.
OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
// Check if operand 0 is a single element vector.
auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
return {};
// Check if the mask is a single zero.
// Note: The mask is guaranteed to be non-empty.
if (getMask().size() != 1 || getMask()[0] != 0)
return {};
return getV1();
}

//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {

// -----

func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
return %0 : vector<1xf32>
}
// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
// CHECK-NOT: llvm.shufflevector
// CHECK: return %[[T0]] : vector<1xf32>

// -----

func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
return %0 : vector<[2]xf32>
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/LLVMIR/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {

// -----

// CHECK-LABEL: fold_shufflevector
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32>
llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> {
// CHECK-NOT: llvm.shufflevector
%c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
// CHECK: llvm.return %[[ARG1]]
llvm.return %c : vector<1xf32>
}

// -----

// Check that LLVM constants participate in cross-dialect constant folding. The
// resulting constant is created in the arith dialect because the last folded
// operation belongs to it.
Expand Down