Skip to content

Commit 7e937c9

Browse files
committed
Fix vector.insert
Signed-off-by: Benoit Jacob <[email protected]>
1 parent c53eb93 commit 7e937c9

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,11 +1233,15 @@ class VectorInsertOpConversion
12331233
SmallVector<OpFoldResult> positionVec = getMixedValues(
12341234
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
12351235

1236-
// Overwrite entire vector with value. Should be handled by folder, but
1237-
// just to be safe.
12381236
ArrayRef<OpFoldResult> position(positionVec);
1237+
// Case of empty position, used with 0-D destination vector. In that case,
1238+
// the converted destination type is a LLVM vector of size 1, and we need
1239+
// a 0 as the position.
12391240
if (position.empty()) {
1240-
rewriter.replaceOp(insertOp, adaptor.getSource());
1241+
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1242+
insertOp, llvmResultType, adaptor.getDest(), adaptor.getSource(),
1243+
rewriter.create<LLVM::ConstantOp>(loc,
1244+
rewriter.getI64IntegerAttr(0)));
12411245
return success();
12421246
}
12431247

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,3 +1787,32 @@ func.func @step() -> vector<4xindex> {
17871787
%0 = vector.step : vector<4xindex>
17881788
return %0 : vector<4xindex>
17891789
}
1790+
1791+
// -----
1792+
1793+
//===----------------------------------------------------------------------===//
1794+
// vector.insert
1795+
//===----------------------------------------------------------------------===//
1796+
1797+
// CHECK-LABEL: @insert_0d
1798+
// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
1799+
func.func @insert_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
1800+
%0 = vector.insert %src, %dst[] : f32 into vector<f32>
1801+
return %0 : vector<f32>
1802+
}
1803+
1804+
// CHECK-LABEL: @insert_1d
1805+
// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
1806+
func.func @insert_1d(%src: f32, %dst: vector<2xf32>) -> vector<2xf32> {
1807+
%0 = vector.insert %src, %dst[1] : f32 into vector<2xf32>
1808+
return %0 : vector<2xf32>
1809+
}
1810+
1811+
// CHECK-LABEL: @insert_2d
1812+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
1813+
// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
1814+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
1815+
func.func @insert_2d(%src: f32, %dst: vector<2x2xf32>) -> vector<2x2xf32> {
1816+
%0 = vector.insert %src, %dst[1, 0] : f32 into vector<2x2xf32>
1817+
return %0 : vector<2x2xf32>
1818+
}

0 commit comments

Comments
 (0)