3535using namespace mlir ;
3636using namespace mlir ::vector;
3737
38- // Helper to reduce vector type by *all* but one rank at back.
39- static VectorType reducedVectorTypeBack (VectorType tp) {
40- assert ((tp.getRank () > 1 ) && " unlowerable vector type" );
41- return VectorType::get (tp.getShape ().take_back (), tp.getElementType (),
42- tp.getScalableDims ().take_back ());
43- }
44-
4538// Helper that picks the proper sequence for inserting.
4639static Value insertOne (ConversionPatternRewriter &rewriter,
4740 const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
12231216 matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
12241217 ConversionPatternRewriter &rewriter) const override {
12251218 auto loc = insertOp->getLoc ();
1226- auto sourceType = insertOp.getSourceType ();
12271219 auto destVectorType = insertOp.getDestVectorType ();
12281220 auto llvmResultType = typeConverter->convertType (destVectorType);
12291221 // Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
12331225 SmallVector<OpFoldResult> positionVec = getMixedValues (
12341226 adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
12351227
1236- // Overwrite entire vector with value. Should be handled by folder, but
1237- // just to be safe.
1238- ArrayRef<OpFoldResult> position (positionVec);
1239- if (position.empty ()) {
1240- rewriter.replaceOp (insertOp, adaptor.getSource ());
1241- return success ();
1242- }
1243-
1244- // One-shot insertion of a vector into an array (only requires insertvalue).
1245- if (isa<VectorType>(sourceType)) {
1246- if (insertOp.hasDynamicPosition ())
1247- return failure ();
1248-
1249- Value inserted = rewriter.create <LLVM::InsertValueOp>(
1250- loc, adaptor.getDest (), adaptor.getSource (), getAsIntegers (position));
1251- rewriter.replaceOp (insertOp, inserted);
1252- return success ();
1228+ // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1229+ // its explanatory comment about how N-D vectors are converted as nested
1230+ // aggregates (llvm.array's) of 1D vectors.
1231+ //
1232+ // There are 3 steps here, vs 2 in VectorExtractOpConversion:
1233+ // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
1234+ // - Insertion into the 1D vector: llvm.insertelement.
1235+ // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
1236+
1237+ // Determine if we need to extract/insert a 1D vector out of the aggregate.
1238+ bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1239+ // Determine if we need to insert a scalar into the 1D vector.
1240+ bool isScalarWithin1DVector =
1241+ static_cast <int64_t >(positionVec.size ()) == destVectorType.getRank ();
1242+
1243+ ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate (
1244+ positionVec.begin (),
1245+ isScalarWithin1DVector ? positionVec.size () - 1 : positionVec.size ());
1246+ OpFoldResult positionOfScalarWithin1DVector;
1247+ if (destVectorType.getRank () == 0 ) {
1248+ // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1249+ // need to create a 0 here as the position into the 1D vector.
1250+ Type idxType = typeConverter->convertType (rewriter.getIndexType ());
1251+ positionOfScalarWithin1DVector = rewriter.getZeroAttr (idxType);
1252+ } else if (isScalarWithin1DVector) {
1253+ positionOfScalarWithin1DVector = positionVec.back ();
12531254 }
12541255
1255- // Potential extraction of 1-D vector from array.
1256- Value extracted = adaptor.getDest ();
1257- auto oneDVectorType = destVectorType;
1258- if (position.size () > 1 ) {
1259- if (insertOp.hasDynamicPosition ())
1260- return failure ();
1261-
1262- oneDVectorType = reducedVectorTypeBack (destVectorType);
1263- extracted = rewriter.create <LLVM::ExtractValueOp>(
1264- loc, extracted, getAsIntegers (position.drop_back ()));
1256+ // We are going to mutate this 1D vector until it is either the final
1257+ // result (in the non-aggregate case) or the value that needs to be
1258+ // inserted into the aggregate result.
1259+ Value vector1d;
1260+ if (isScalarWithin1DVector) {
1261+ // Scalar-into-1D-vector case, so we know we will have to create an
1262+ // InsertElementOp. The question is into what destination.
1263+ if (is1DVectorWithinAggregate) {
1264+ // Aggregate case: the destination for the InsertElementOp needs to be
1265+ // extracted from the aggregate.
1266+ if (!llvm::all_of (positionOf1DVectorWithinAggregate,
1267+ llvm::IsaPred<Attribute>)) {
1268+ // llvm.extractvalue only accepts static (constant) positions.
1269+ return failure ();
1270+ }
1271+ vector1d = rewriter.create <LLVM::ExtractValueOp>(
1272+ loc, adaptor.getDest (),
1273+ getAsIntegers (positionOf1DVectorWithinAggregate));
1274+ } else {
1275+ // No-aggregate case. The destination for the InsertElementOp is just
1276+ // the insertOp's destination.
1277+ vector1d = adaptor.getDest ();
1278+ }
1279+ // Insert the scalar into the 1D vector.
1280+ vector1d = rewriter.create <LLVM::InsertElementOp>(
1281+ loc, vector1d.getType (), vector1d, adaptor.getSource (),
1282+ getAsLLVMValue (rewriter, loc, positionOfScalarWithin1DVector));
1283+ } else {
1284+ // No scalar insertion. The 1D vector is just the source.
1285+ vector1d = adaptor.getSource ();
12651286 }
12661287
1267- // Insertion of an element into a 1-D LLVM vector.
1268- Value inserted = rewriter.create <LLVM::InsertElementOp>(
1269- loc, typeConverter->convertType (oneDVectorType), extracted,
1270- adaptor.getSource (), getAsLLVMValue (rewriter, loc, position.back ()));
1271-
1272- // Potential insertion of resulting 1-D vector into array.
1273- if (position.size () > 1 ) {
1274- if (insertOp.hasDynamicPosition ())
1275- return failure ();
1276-
1277- inserted = rewriter.create <LLVM::InsertValueOp>(
1278- loc, adaptor.getDest (), inserted,
1279- getAsIntegers (position.drop_back ()));
1288+ Value result = vector1d;
1289+ if (is1DVectorWithinAggregate) {
1290+ result = rewriter.create <LLVM::InsertValueOp>(
1291+ loc, adaptor.getDest (), vector1d,
1292+ getAsIntegers (positionOf1DVectorWithinAggregate));
12801293 }
12811294
1282- rewriter.replaceOp (insertOp, inserted );
1295+ rewriter.replaceOp (insertOp, result );
12831296 return success ();
12841297 }
12851298};
0 commit comments