Skip to content

Commit d8e2291

Browse files
committed
review comments
Signed-off-by: Benoit Jacob <[email protected]>
1 parent d18ba67 commit d8e2291

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,66 +1229,73 @@ class VectorInsertOpConversion
12291229
// its explanatory comment about how N-D vectors are converted as nested
12301230
// aggregates (llvm.array's) of 1D vectors.
12311231
//
1232-
// There are 3 steps here, vs 2 in VectorExtractOpConversion:
1233-
// - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
1234-
// - Insertion into the 1D vector: llvm.insertelement.
1235-
// - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
1232+
// The innermost dimension of the destination vector, when converted to a
1233+
// nested aggregate form, will always be a 1D vector.
1234+
//
1235+
// * If the insertion is happening into the innermost dimension of the
1236+
// destination vector:
1237+
// - If the destination is a nested aggregate, extract a 1D vector out of
1238+
// the aggregate. This can be done using llvm.extractvalue. The
1239+
// destination is now guaranteed to be a 1D vector, to which we are
1240+
// inserting.
1241+
// - Do the insertion into the 1D destination vector, and make the result
1242+
// the new source nested aggregate. This can be done using
1243+
// llvm.insertelement.
1244+
// * Insert the source nested aggregate into the destination nested
1245+
// aggregate.
12361246

12371247
// Determine if we need to extract/insert a 1D vector out of the aggregate.
1238-
bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1248+
bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
12391249
// Determine if we need to insert a scalar into the 1D vector.
1240-
bool isScalarWithin1DVector =
1250+
bool insertIntoInnermostDim =
12411251
static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
12421252

12431253
ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
12441254
positionVec.begin(),
1245-
isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
1255+
insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
12461256
OpFoldResult positionOfScalarWithin1DVector;
12471257
if (destVectorType.getRank() == 0) {
12481258
// Since the LLVM type converter converts 0D vectors to 1D vectors, we
12491259
// need to create a 0 here as the position into the 1D vector.
12501260
Type idxType = typeConverter->convertType(rewriter.getIndexType());
12511261
positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1252-
} else if (isScalarWithin1DVector) {
1262+
} else if (insertIntoInnermostDim) {
12531263
positionOfScalarWithin1DVector = positionVec.back();
12541264
}
12551265

12561266
// We are going to mutate this 1D vector until it is either the final
12571267
// result (in the non-aggregate case) or the value that needs to be
12581268
// inserted into the aggregate result.
1259-
Value vector1d;
1260-
if (isScalarWithin1DVector) {
1269+
Value sourceAggregate = adaptor.getSource();
1270+
if (insertIntoInnermostDim) {
12611271
// Scalar-into-1D-vector case, so we know we will have to create a
12621272
// InsertElementOp. The question is into what destination.
1263-
if (is1DVectorWithinAggregate) {
1273+
if (isNestedAggregate) {
12641274
// Aggregate case: the destination for the InsertElementOp needs to be
12651275
// extracted from the aggregate.
12661276
if (!llvm::all_of(positionOf1DVectorWithinAggregate,
12671277
llvm::IsaPred<Attribute>)) {
12681278
// llvm.extractvalue does not support dynamic dimensions.
12691279
return failure();
12701280
}
1271-
vector1d = rewriter.create<LLVM::ExtractValueOp>(
1281+
sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
12721282
loc, adaptor.getDest(),
12731283
getAsIntegers(positionOf1DVectorWithinAggregate));
12741284
} else {
12751285
// No-aggregate case. The destination for the InsertElementOp is just
12761286
// the insertOp's destination.
1277-
vector1d = adaptor.getDest();
1287+
sourceAggregate = adaptor.getDest();
12781288
}
12791289
// Insert the scalar into the 1D vector.
1280-
vector1d = rewriter.create<LLVM::InsertElementOp>(
1281-
loc, vector1d.getType(), vector1d, adaptor.getSource(),
1290+
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1291+
loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
12821292
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1283-
} else {
1284-
// No scalar insertion. The 1D vector is just the source.
1285-
vector1d = adaptor.getSource();
12861293
}
12871294

1288-
Value result = vector1d;
1289-
if (is1DVectorWithinAggregate) {
1295+
Value result = sourceAggregate;
1296+
if (isNestedAggregate) {
12901297
result = rewriter.create<LLVM::InsertValueOp>(
1291-
loc, adaptor.getDest(), vector1d,
1298+
loc, adaptor.getDest(), sourceAggregate,
12921299
getAsIntegers(positionOf1DVectorWithinAggregate));
12931300
}
12941301

0 commit comments

Comments
 (0)