@@ -1096,53 +1096,55 @@ class VectorExtractOpConversion
10961096 SmallVector<OpFoldResult> positionVec = getMixedValues (
10971097 adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
10981098
1099- // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
1100- // The stacking is modeled using arrays. We do this conversion from a
1101- // N-d vector extract to stacked 1-d vector extract in two steps:
1102- // - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
1103- // - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
1104-
1105- // Determine if we need to extract a slice out of the original vector. We
1106- // always need to extract a slice if the input rank >= 2.
1107- bool isSlicingExtract = extractOp.getSourceVectorType ().getRank () >= 2 ;
1099+ // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100+ // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101+ // from a N-d vector extract to a nested aggregate vector extract in two
1102+ // steps:
1103+ // - Extract a member from the nested aggregate. The result can be
1104+ // a lower rank nested aggregate or a vector (1-D). This is done using
1105+ // `llvm.extractvalue`.
1106+ // - Extract a scalar out of the vector if needed. This is done using
1107+ // `llvm.extractelement`.
1108+
1109+ // Determine if we need to extract a member out of the aggregate. We
1110+ // always need to extract a member if the input rank >= 2.
1111+ bool extractsAggregate = extractOp.getSourceVectorType ().getRank () >= 2 ;
11081112 // Determine if we need to extract a scalar as the result. We extract
1109- // a scalar if the extract is full rank i.e. the number of indices is equal
1110- // to source vector rank.
1111- bool isScalarExtract = static_cast <int64_t >(positionVec.size ()) ==
1112- extractOp.getSourceVectorType ().getRank ();
1113+ // a scalar if the extract is full rank, i.e., the number of indices is
1114+ // equal to source vector rank.
1115+ bool extractsScalar = static_cast <int64_t >(positionVec.size ()) ==
1116+ extractOp.getSourceVectorType ().getRank ();
1117+
1118+ // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119+ // need to add a position for this change.
1120+ if (extractOp.getSourceVectorType ().getRank () == 0 ) {
1121+ auto idxType = rewriter.getIndexType ();
1122+ Value position = rewriter.create <LLVM::ConstantOp>(
1123+ loc, typeConverter->convertType (idxType),
1124+ rewriter.getIntegerAttr (idxType, 0 ));
1125+ positionVec.push_back (position);
1126+ }
11131127
11141128 Value extracted = adaptor.getVector ();
1115- if (isSlicingExtract ) {
1129+ if (extractsScalar ) {
11161130 ArrayRef<OpFoldResult> position (positionVec);
1117- if (isScalarExtract) {
1118- // If we are extracting a scalar from the returned slice, we need to
1119- // extract a N-1 D slice.
1131+ if (extractsAggregate) {
1132+ // If we are extracting a scalar from the extracted member, we drop
1133+ // the last index, which will be used to extract the scalar out of the
1134+ // vector.
11201135 position = position.drop_back ();
11211136 }
11221137 // llvm.extractvalue does not support dynamic dimensions.
1123- if (!llvm::all_of (position,
1124- [](OpFoldResult x) { return isa<Attribute>(x); })) {
1138+ if (!llvm::all_of (position, llvm::IsaPred<Attribute>)) {
11251139 return failure ();
11261140 }
11271141 extracted = rewriter.create <LLVM::ExtractValueOp>(
11281142 loc, extracted, getAsIntegers (position));
11291143 }
11301144
1131- if (isScalarExtract) {
1132- Value position;
1133- if (positionVec.empty ()) {
1134- // A scalar extract with no position is a 0-D vector extract. The LLVM
1135- // type converter converts 0-D vectors to 1-D vectors, so we need to add
1136- // a constant position.
1137- auto idxType = rewriter.getIndexType ();
1138- position = rewriter.create <LLVM::ConstantOp>(
1139- loc, typeConverter->convertType (idxType),
1140- rewriter.getIntegerAttr (idxType, 0 ));
1141- } else {
1142- position = getAsLLVMValue (rewriter, loc, positionVec.back ());
1143- }
1144- extracted =
1145- rewriter.create <LLVM::ExtractElementOp>(loc, extracted, position);
1145+ if (extractsScalar) {
1146+ extracted = rewriter.create <LLVM::ExtractElementOp>(
1147+ loc, extracted, getAsLLVMValue (rewriter, loc, positionVec.back ()));
11461148 }
11471149
11481150 rewriter.replaceOp (extractOp, extracted);
0 commit comments