Skip to content

Commit 567412b

Browse files
support insert.
1 parent fa488d5 commit 567412b

File tree

2 files changed

+76
-7
lines changed

2 files changed

+76
-7
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,7 @@ class VectorExtractOpConversion
10991099
for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
11001100
if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
11011101
auto defOp = position.getDefiningOp();
1102-
while (true) {
1103-
if (!defOp) {
1104-
break;
1105-
}
1102+
while (defOp) {
11061103
if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
11071104
Attribute value =
11081105
defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
@@ -1254,6 +1251,25 @@ class VectorInsertOpConversion
12541251

12551252
SmallVector<OpFoldResult> positionVec = getMixedValues(
12561253
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1254+
for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
1255+
if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
1256+
auto defOp = position.getDefiningOp();
1257+
while (defOp) {
1258+
if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
1259+
Attribute value =
1260+
defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
1261+
positionVec[idx] = OpFoldResult{
1262+
rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
1263+
break;
1264+
} else if (auto unrealizedCastOp =
1265+
llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
1266+
defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
1267+
} else {
1268+
break;
1269+
}
1270+
}
1271+
}
1272+
}
12571273

12581274
// Overwrite entire vector with value. Should be handled by folder, but
12591275
// just to be safe.
@@ -1265,8 +1281,9 @@ class VectorInsertOpConversion
12651281

12661282
// One-shot insertion of a vector into an array (only requires insertvalue).
12671283
if (isa<VectorType>(sourceType)) {
1268-
if (insertOp.hasDynamicPosition())
1284+
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
12691285
return failure();
1286+
}
12701287

12711288
Value inserted = rewriter.create<LLVM::InsertValueOp>(
12721289
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
@@ -1278,8 +1295,9 @@ class VectorInsertOpConversion
12781295
Value extracted = adaptor.getDest();
12791296
auto oneDVectorType = destVectorType;
12801297
if (position.size() > 1) {
1281-
if (insertOp.hasDynamicPosition())
1298+
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
12821299
return failure();
1300+
}
12831301

12841302
oneDVectorType = reducedVectorTypeBack(destVectorType);
12851303
extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -1293,8 +1311,9 @@ class VectorInsertOpConversion
12931311

12941312
// Potential insertion of resulting 1-D vector into array.
12951313
if (position.size() > 1) {
1296-
if (insertOp.hasDynamicPosition())
1314+
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
12971315
return failure();
1316+
}
12981317

12991318
inserted = rewriter.create<LLVM::InsertValueOp>(
13001319
loc, adaptor.getDest(), inserted,

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4134,3 +4134,53 @@ module {
41344134
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
41354135
// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
41364136
// CHECK: return %[[VAL_4]] : i32
4137+
4138+
// -----
4139+
4140+
// CHECK-LABEL: @insert_arith_constnt()
4141+
4142+
func.func @insert_arith_constnt() -> vector<32x1xi32> {
4143+
%v = arith.constant dense<0> : vector<32x1xi32>
4144+
%c_0 = arith.constant 0 : index
4145+
%c_1 = arith.constant 1 : i32
4146+
%v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
4147+
return %v_1 : vector<32x1xi32>
4148+
}
4149+
4150+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4151+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4152+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
4153+
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
4154+
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4155+
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64
4156+
// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32>
4157+
// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4158+
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
4159+
// CHECK: return %[[VAL_8]] : vector<32x1xi32>
4160+
4161+
// -----
4162+
4163+
// CHECK-LABEL: @insert_llvm_constnt()
4164+
4165+
module {
4166+
func.func @insert_llvm_constnt() -> vector<32x1xi32> {
4167+
%0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
4168+
%1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
4169+
%2 = llvm.mlir.constant(0 : index) : i64
4170+
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
4171+
%4 = llvm.mlir.constant(1 : i32) : i32
4172+
%5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32>
4173+
return %5 : vector<32x1xi32>
4174+
}
4175+
}
4176+
4177+
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
4178+
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
4179+
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
4180+
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
4181+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
4182+
// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
4183+
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
4184+
// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
4185+
// CHECK: return %[[VAL_7]] : vector<32x1xi32>
4186+
// CHECK: }

0 commit comments

Comments
 (0)