@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
10961096 SmallVector<OpFoldResult> positionVec = getMixedValues (
10971097 adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
10981098
1099- // Extract entire vector. Should be handled by folder, but just to be safe.
1100- ArrayRef<OpFoldResult> position (positionVec);
1101- if (position.empty ()) {
1102- rewriter.replaceOp (extractOp, adaptor.getVector ());
1103- return success ();
1104- }
1105-
1106- // One-shot extraction of vector from array (only requires extractvalue).
1107- // Except for extracting 1-element vectors.
1108- if (isa<VectorType>(resultType) &&
1109- position.size () !=
1110- static_cast <size_t >(extractOp.getSourceVectorType ().getRank ())) {
1111- if (extractOp.hasDynamicPosition ())
1112- return failure ();
1113-
1114- Value extracted = rewriter.create <LLVM::ExtractValueOp>(
1115- loc, adaptor.getVector (), getAsIntegers (position));
1116- rewriter.replaceOp (extractOp, extracted);
1117- return success ();
1118- }
1099+ // Determine if we need to extract a scalar as the result. We extract
1100+ // a scalar if the extract is full rank i.e. the number of indices is equal
1101+ // to source vector rank.
1102+ bool isScalarExtract =
1103+ positionVec.size () == extractOp.getSourceVectorType ().getRank ();
1104+ // Determine if we need to extract a slice out of the original vector. We
1105+ // always need to extract a slice if the input rank >= 2.
1106+ bool isSlicingExtract = extractOp.getSourceVectorType ().getRank () >= 2 ;
11191107
1120- // Potential extraction of 1-D vector from array.
11211108 Value extracted = adaptor.getVector ();
1122- if (position.size () > 1 ) {
1123- if (extractOp.hasDynamicPosition ())
1109+ if (isSlicingExtract) {
1110+ ArrayRef<OpFoldResult> position (positionVec);
1111+ if (isScalarExtract) {
1112+ // If we are extracting a scalar from the returned slice, we need to
1113+ // extract a N-1 D slice.
1114+ position = position.drop_back ();
1115+ }
1116+ // llvm.extractvalue does not support dynamic dimensions.
1117+ if (!llvm::all_of (position,
1118+ [](OpFoldResult x) { return isa<Attribute>(x); })) {
11241119 return failure ();
1120+ }
1121+ extracted = rewriter.create <LLVM::ExtractValueOp>(
1122+ loc, extracted, getAsIntegers (position));
1123+ }
11251124
1126- SmallVector<int64_t > nMinusOnePosition =
1127- getAsIntegers (position.drop_back ());
1128- extracted = rewriter.create <LLVM::ExtractValueOp>(loc, extracted,
1129- nMinusOnePosition);
1125+ if (isScalarExtract) {
1126+ Value position;
1127+ if (positionVec.empty ()) {
1128+ // A scalar extract with no position is a 0-D vector extract. The LLVM
1129+ // type converter converts 0-D vectors to 1-D vectors, so we need to add
1130+ // a constant position.
1131+ auto idxType = rewriter.getIndexType ();
1132+ position = rewriter.create <LLVM::ConstantOp>(
1133+ loc, typeConverter->convertType (idxType),
1134+ rewriter.getIntegerAttr (idxType, 0 ));
1135+ } else {
1136+ position = getAsLLVMValue (rewriter, loc, positionVec.back ());
1137+ }
1138+ extracted =
1139+ rewriter.create <LLVM::ExtractElementOp>(loc, extracted, position);
11301140 }
11311141
1132- Value lastPosition = getAsLLVMValue (rewriter, loc, position.back ());
1133- // Remaining extraction of element from 1-D LLVM vector.
1134- rewriter.replaceOpWithNewOp <LLVM::ExtractElementOp>(extractOp, extracted,
1135- lastPosition);
1142+ rewriter.replaceOp (extractOp, extracted);
11361143 return success ();
11371144 }
11381145};
0 commit comments