3535using namespace mlir ;
3636using namespace mlir ::vector;
3737
38- // Helper to reduce vector type by *all* but one rank at back.
39- static VectorType reducedVectorTypeBack (VectorType tp) {
40- assert ((tp.getRank () > 1 ) && " unlowerable vector type" );
41- return VectorType::get (tp.getShape ().take_back (), tp.getElementType (),
42- tp.getScalableDims ().take_back ());
43- }
44-
4538// Helper that picks the proper sequence for inserting.
4639static Value insertOne (ConversionPatternRewriter &rewriter,
4740 const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
12231216 matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
12241217 ConversionPatternRewriter &rewriter) const override {
12251218 auto loc = insertOp->getLoc ();
1226- auto sourceType = insertOp.getSourceType ();
12271219 auto destVectorType = insertOp.getDestVectorType ();
12281220 auto llvmResultType = typeConverter->convertType (destVectorType);
12291221 // Bail if result type cannot be lowered.
@@ -1233,53 +1225,81 @@ class VectorInsertOpConversion
12331225 SmallVector<OpFoldResult> positionVec = getMixedValues (
12341226 adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
12351227
1236- // Overwrite entire vector with value. Should be handled by folder, but
1237- // just to be safe.
1238- ArrayRef<OpFoldResult> position (positionVec);
1239- if (position.empty ()) {
1240- rewriter.replaceOp (insertOp, adaptor.getSource ());
1241- return success ();
1242- }
1243-
1244- // One-shot insertion of a vector into an array (only requires insertvalue).
1245- if (isa<VectorType>(sourceType)) {
1246- if (insertOp.hasDynamicPosition ())
1247- return failure ();
1248-
1249- Value inserted = rewriter.create <LLVM::InsertValueOp>(
1250- loc, adaptor.getDest (), adaptor.getSource (), getAsIntegers (position));
1251- rewriter.replaceOp (insertOp, inserted);
1252- return success ();
1228+ // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1229+ // its explanatory comment about how N-D vectors are converted as nested
1230+ // aggregates (llvm.array's) of 1D vectors.
1231+ //
1232+ // The innermost dimension of the destination vector, when converted to a
1233+ // nested aggregate form, will always be a 1D vector.
1234+ //
1235+ // * If the insertion is happening into the innermost dimension of the
1236+ // destination vector:
1237+ // - If the destination is a nested aggregate, extract a 1D vector out of
1238+ // the aggregate. This can be done using llvm.extractvalue. The
1239+ // destination is now guaranteed to be a 1D vector, to which we are
1240+ // inserting.
1241+ // - Do the insertion into the 1D destination vector, and make the result
1242+ // the new source nested aggregate. This can be done using
1243+ // llvm.insertelement.
1244+ // * Insert the source nested aggregate into the destination nested
1245+ // aggregate.
1246+
1247+ // Determine if we need to extract/insert a 1D vector out of the aggregate.
1248+ bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1249+ // Determine if we need to insert a scalar into the 1D vector.
1250+ bool insertIntoInnermostDim =
1251+ static_cast <int64_t >(positionVec.size ()) == destVectorType.getRank ();
1252+
1253+ ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate (
1254+ positionVec.begin (),
1255+ insertIntoInnermostDim ? positionVec.size () - 1 : positionVec.size ());
1256+ OpFoldResult positionOfScalarWithin1DVector;
1257+ if (destVectorType.getRank () == 0 ) {
1258+ // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1259+ // need to create a 0 here as the position into the 1D vector.
1260+ Type idxType = typeConverter->convertType (rewriter.getIndexType ());
1261+ positionOfScalarWithin1DVector = rewriter.getZeroAttr (idxType);
1262+ } else if (insertIntoInnermostDim) {
1263+ positionOfScalarWithin1DVector = positionVec.back ();
12531264 }
12541265
1255- // Potential extraction of 1-D vector from array.
1256- Value extracted = adaptor.getDest ();
1257- auto oneDVectorType = destVectorType;
1258- if (position.size () > 1 ) {
1259- if (insertOp.hasDynamicPosition ())
1260- return failure ();
1261-
1262- oneDVectorType = reducedVectorTypeBack (destVectorType);
1263- extracted = rewriter.create <LLVM::ExtractValueOp>(
1264- loc, extracted, getAsIntegers (position.drop_back ()));
1266+ // We are going to mutate this 1D vector until it is either the final
1267+ // result (in the non-aggregate case) or the value that needs to be
1268+ // inserted into the aggregate result.
1269+ Value sourceAggregate = adaptor.getSource ();
1270+ if (insertIntoInnermostDim) {
1271+ // Scalar-into-1D-vector case, so we know we will have to create an
1272+ // InsertElementOp. The question is into what destination.
1273+ if (isNestedAggregate) {
1274+ // Aggregate case: the destination for the InsertElementOp needs to be
1275+ // extracted from the aggregate.
1276+ if (!llvm::all_of (positionOf1DVectorWithinAggregate,
1277+ llvm::IsaPred<Attribute>)) {
1278+ // llvm.extractvalue does not support dynamic positions.
1279+ return failure ();
1280+ }
1281+ sourceAggregate = rewriter.create <LLVM::ExtractValueOp>(
1282+ loc, adaptor.getDest (),
1283+ getAsIntegers (positionOf1DVectorWithinAggregate));
1284+ } else {
1285+ // No-aggregate case. The destination for the InsertElementOp is just
1286+ // the insertOp's destination.
1287+ sourceAggregate = adaptor.getDest ();
1288+ }
1289+ // Insert the scalar into the 1D vector.
1290+ sourceAggregate = rewriter.create <LLVM::InsertElementOp>(
1291+ loc, sourceAggregate.getType (), sourceAggregate, adaptor.getSource (),
1292+ getAsLLVMValue (rewriter, loc, positionOfScalarWithin1DVector));
12651293 }
12661294
1267- // Insertion of an element into a 1-D LLVM vector.
1268- Value inserted = rewriter.create <LLVM::InsertElementOp>(
1269- loc, typeConverter->convertType (oneDVectorType), extracted,
1270- adaptor.getSource (), getAsLLVMValue (rewriter, loc, position.back ()));
1271-
1272- // Potential insertion of resulting 1-D vector into array.
1273- if (position.size () > 1 ) {
1274- if (insertOp.hasDynamicPosition ())
1275- return failure ();
1276-
1277- inserted = rewriter.create <LLVM::InsertValueOp>(
1278- loc, adaptor.getDest (), inserted,
1279- getAsIntegers (position.drop_back ()));
1295+ Value result = sourceAggregate;
1296+ if (isNestedAggregate) {
1297+ result = rewriter.create <LLVM::InsertValueOp>(
1298+ loc, adaptor.getDest (), sourceAggregate,
1299+ getAsIntegers (positionOf1DVectorWithinAggregate));
12801300 }
12811301
1282- rewriter.replaceOp (insertOp, inserted );
1302+ rewriter.replaceOp (insertOp, result );
12831303 return success ();
12841304 }
12851305};
0 commit comments