@@ -1229,66 +1229,73 @@ class VectorInsertOpConversion
12291229 // its explanatory comment about how N-D vectors are converted as nested
12301230 // aggregates (llvm.array's) of 1D vectors.
12311231 //
1232- // There are 3 steps here, vs 2 in VectorExtractOpConversion:
1233- // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
1234- // - Insertion into the 1D vector: llvm.insertelement.
1235- // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
1232+ // The innermost dimension of the destination vector, when converted to a
1233+ // nested aggregate form, will always be a 1D vector.
1234+ //
1235+ // * If the insertion is happening into the innermost dimension of the
1236+ // destination vector:
1237+ // - If the destination is a nested aggregate, extract a 1D vector out of
1238+ // the aggregate. This can be done using llvm.extractvalue. The
1239+ // destination is now guaranteed to be a 1D vector, to which we are
1240+ // inserting.
1241+ // - Do the insertion into the 1D destination vector, and make the result
1242+ // the new source nested aggregate. This can be done using
1243+ // llvm.insertelement.
1244+ // * Insert the source nested aggregate into the destination nested
1245+ // aggregate.
12361246
12371247 // Determine if we need to extract/insert a 1D vector out of the aggregate.
1238- bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1248+ bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
12391249 // Determine if we need to insert a scalar into the 1D vector.
1240- bool isScalarWithin1DVector =
1250+ bool insertIntoInnermostDim =
12411251 static_cast <int64_t >(positionVec.size ()) == destVectorType.getRank ();
12421252
12431253 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate (
12441254 positionVec.begin (),
1245- isScalarWithin1DVector ? positionVec.size () - 1 : positionVec.size ());
1255+ insertIntoInnermostDim ? positionVec.size () - 1 : positionVec.size ());
12461256 OpFoldResult positionOfScalarWithin1DVector;
12471257 if (destVectorType.getRank () == 0 ) {
12481258 // Since the LLVM type converter converts 0D vectors to 1D vectors, we
12491259 // need to create a 0 here as the position into the 1D vector.
12501260 Type idxType = typeConverter->convertType (rewriter.getIndexType ());
12511261 positionOfScalarWithin1DVector = rewriter.getZeroAttr (idxType);
1252- } else if (isScalarWithin1DVector ) {
1262+ } else if (insertIntoInnermostDim ) {
12531263 positionOfScalarWithin1DVector = positionVec.back ();
12541264 }
12551265
12561266 // We are going to mutate this 1D vector until it is either the final
12571267 // result (in the non-aggregate case) or the value that needs to be
12581268 // inserted into the aggregate result.
1259- Value vector1d ;
1260- if (isScalarWithin1DVector ) {
1269+ Value sourceAggregate = adaptor. getSource () ;
1270+ if (insertIntoInnermostDim ) {
12611271 // Scalar-into-1D-vector case, so we know we will have to create a
12621272 // InsertElementOp. The question is into what destination.
1263- if (is1DVectorWithinAggregate ) {
1273+ if (isNestedAggregate ) {
12641274 // Aggregate case: the destination for the InsertElementOp needs to be
12651275 // extracted from the aggregate.
12661276 if (!llvm::all_of (positionOf1DVectorWithinAggregate,
12671277 llvm::IsaPred<Attribute>)) {
12681278 // llvm.extractvalue does not support dynamic dimensions.
12691279 return failure ();
12701280 }
1271- vector1d = rewriter.create <LLVM::ExtractValueOp>(
1281+ sourceAggregate = rewriter.create <LLVM::ExtractValueOp>(
12721282 loc, adaptor.getDest (),
12731283 getAsIntegers (positionOf1DVectorWithinAggregate));
12741284 } else {
12751285 // No-aggregate case. The destination for the InsertElementOp is just
12761286 // the insertOp's destination.
1277- vector1d = adaptor.getDest ();
1287+ sourceAggregate = adaptor.getDest ();
12781288 }
12791289 // Insert the scalar into the 1D vector.
1280- vector1d = rewriter.create <LLVM::InsertElementOp>(
1281- loc, vector1d .getType (), vector1d , adaptor.getSource (),
1290+ sourceAggregate = rewriter.create <LLVM::InsertElementOp>(
1291+ loc, sourceAggregate .getType (), sourceAggregate , adaptor.getSource (),
12821292 getAsLLVMValue (rewriter, loc, positionOfScalarWithin1DVector));
1283- } else {
1284- // No scalar insertion. The 1D vector is just the source.
1285- vector1d = adaptor.getSource ();
12861293 }
12871294
1288- Value result = vector1d ;
1289- if (is1DVectorWithinAggregate ) {
1295+ Value result = sourceAggregate ;
1296+ if (isNestedAggregate ) {
12901297 result = rewriter.create <LLVM::InsertValueOp>(
1291- loc, adaptor.getDest (), vector1d ,
1298+ loc, adaptor.getDest (), sourceAggregate ,
12921299 getAsIntegers (positionOf1DVectorWithinAggregate));
12931300 }
12941301
0 commit comments