Skip to content

Commit d18ba67

Browse files
committed
like-117731-but-for-insert
Signed-off-by: Benoit Jacob <[email protected]>
1 parent c53eb93 commit d18ba67

File tree

2 files changed

+78
-55
lines changed

2 files changed

+78
-55
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,6 @@
3535
using namespace mlir;
3636
using namespace mlir::vector;
3737

38-
// Helper to reduce vector type by *all* but one rank at back.
39-
static VectorType reducedVectorTypeBack(VectorType tp) {
40-
assert((tp.getRank() > 1) && "unlowerable vector type");
41-
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
42-
tp.getScalableDims().take_back());
43-
}
44-
4538
// Helper that picks the proper sequence for inserting.
4639
static Value insertOne(ConversionPatternRewriter &rewriter,
4740
const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
12231216
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
12241217
ConversionPatternRewriter &rewriter) const override {
12251218
auto loc = insertOp->getLoc();
1226-
auto sourceType = insertOp.getSourceType();
12271219
auto destVectorType = insertOp.getDestVectorType();
12281220
auto llvmResultType = typeConverter->convertType(destVectorType);
12291221
// Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
12331225
SmallVector<OpFoldResult> positionVec = getMixedValues(
12341226
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
12351227

1236-
// Overwrite entire vector with value. Should be handled by folder, but
1237-
// just to be safe.
1238-
ArrayRef<OpFoldResult> position(positionVec);
1239-
if (position.empty()) {
1240-
rewriter.replaceOp(insertOp, adaptor.getSource());
1241-
return success();
1242-
}
1243-
1244-
// One-shot insertion of a vector into an array (only requires insertvalue).
1245-
if (isa<VectorType>(sourceType)) {
1246-
if (insertOp.hasDynamicPosition())
1247-
return failure();
1248-
1249-
Value inserted = rewriter.create<LLVM::InsertValueOp>(
1250-
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1251-
rewriter.replaceOp(insertOp, inserted);
1252-
return success();
1228+
// The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1229+
// its explanatory comment about how N-D vectors are converted as nested
1230+
// aggregates (llvm.array's) of 1D vectors.
1231+
//
1232+
// There are 3 steps here, vs 2 in VectorExtractOpConversion:
1233+
// - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
1234+
// - Insertion into the 1D vector: llvm.insertelement.
1235+
// - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
1236+
1237+
// Determine if we need to extract/insert a 1D vector out of the aggregate.
1238+
bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1239+
// Determine if we need to insert a scalar into the 1D vector.
1240+
bool isScalarWithin1DVector =
1241+
static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1242+
1243+
ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1244+
positionVec.begin(),
1245+
isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
1246+
OpFoldResult positionOfScalarWithin1DVector;
1247+
if (destVectorType.getRank() == 0) {
1248+
// Since the LLVM type converter converts 0D vectors to 1D vectors, we
1249+
// need to create a 0 here as the position into the 1D vector.
1250+
Type idxType = typeConverter->convertType(rewriter.getIndexType());
1251+
positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1252+
} else if (isScalarWithin1DVector) {
1253+
positionOfScalarWithin1DVector = positionVec.back();
12531254
}
12541255

1255-
// Potential extraction of 1-D vector from array.
1256-
Value extracted = adaptor.getDest();
1257-
auto oneDVectorType = destVectorType;
1258-
if (position.size() > 1) {
1259-
if (insertOp.hasDynamicPosition())
1260-
return failure();
1261-
1262-
oneDVectorType = reducedVectorTypeBack(destVectorType);
1263-
extracted = rewriter.create<LLVM::ExtractValueOp>(
1264-
loc, extracted, getAsIntegers(position.drop_back()));
1256+
// We are going to mutate this 1D vector until it is either the final
1257+
// result (in the non-aggregate case) or the value that needs to be
1258+
// inserted into the aggregate result.
1259+
Value vector1d;
1260+
if (isScalarWithin1DVector) {
1261+
// Scalar-into-1D-vector case, so we know we will have to create a
1262+
// InsertElementOp. The question is into what destination.
1263+
if (is1DVectorWithinAggregate) {
1264+
// Aggregate case: the destination for the InsertElementOp needs to be
1265+
// extracted from the aggregate.
1266+
if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1267+
llvm::IsaPred<Attribute>)) {
1268+
// llvm.extractvalue does not support dynamic dimensions.
1269+
return failure();
1270+
}
1271+
vector1d = rewriter.create<LLVM::ExtractValueOp>(
1272+
loc, adaptor.getDest(),
1273+
getAsIntegers(positionOf1DVectorWithinAggregate));
1274+
} else {
1275+
// No-aggregate case. The destination for the InsertElementOp is just
1276+
// the insertOp's destination.
1277+
vector1d = adaptor.getDest();
1278+
}
1279+
// Insert the scalar into the 1D vector.
1280+
vector1d = rewriter.create<LLVM::InsertElementOp>(
1281+
loc, vector1d.getType(), vector1d, adaptor.getSource(),
1282+
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1283+
} else {
1284+
// No scalar insertion. The 1D vector is just the source.
1285+
vector1d = adaptor.getSource();
12651286
}
12661287

1267-
// Insertion of an element into a 1-D LLVM vector.
1268-
Value inserted = rewriter.create<LLVM::InsertElementOp>(
1269-
loc, typeConverter->convertType(oneDVectorType), extracted,
1270-
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1271-
1272-
// Potential insertion of resulting 1-D vector into array.
1273-
if (position.size() > 1) {
1274-
if (insertOp.hasDynamicPosition())
1275-
return failure();
1276-
1277-
inserted = rewriter.create<LLVM::InsertValueOp>(
1278-
loc, adaptor.getDest(), inserted,
1279-
getAsIntegers(position.drop_back()));
1288+
Value result = vector1d;
1289+
if (is1DVectorWithinAggregate) {
1290+
result = rewriter.create<LLVM::InsertValueOp>(
1291+
loc, adaptor.getDest(), vector1d,
1292+
getAsIntegers(positionOf1DVectorWithinAggregate));
12801293
}
12811294

1282-
rewriter.replaceOp(insertOp, inserted);
1295+
rewriter.replaceOp(insertOp, result);
12831296
return success();
12841297
}
12851298
};

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
628628
// vector.insert
629629
//===----------------------------------------------------------------------===//
630630

631+
func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
632+
%0 = vector.insert %src, %dst[] : f32 into vector<f32>
633+
return %0 : vector<f32>
634+
}
635+
636+
// CHECK-LABEL: @insert_scalar_into_vec_0d
637+
// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
638+
639+
// -----
640+
631641
func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
632642
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
633643
return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
780790
return %0 : vector<1x16xf32>
781791
}
782792

783-
// Multi-dim vectors are not supported but this test shouldn't crash.
784-
785793
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
786-
// CHECK: vector.insert
794+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
795+
// CHECK: llvm.insertelement {{.*}} : vector<16xf32>
796+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
787797

788798
// -----
789799

@@ -793,10 +803,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
793803
return %0 : vector<1x[16]xf32>
794804
}
795805

796-
// Multi-dim vectors are not supported but this test shouldn't crash.
797-
798806
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
799-
// CHECK: vector.insert
807+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
808+
// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32>
809+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
800810

801811
// -----
802812

0 commit comments

Comments
 (0)