Skip to content

Commit dc1a9a5

Browse files
committed
[mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors
1 parent bbea1de commit dc1a9a5

File tree

2 files changed

+84
-38
lines changed

2 files changed

+84
-38
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
10961096
SmallVector<OpFoldResult> positionVec = getMixedValues(
10971097
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
10981098

1099-
// Extract entire vector. Should be handled by folder, but just to be safe.
1100-
ArrayRef<OpFoldResult> position(positionVec);
1101-
if (position.empty()) {
1102-
rewriter.replaceOp(extractOp, adaptor.getVector());
1103-
return success();
1104-
}
1105-
1106-
// One-shot extraction of vector from array (only requires extractvalue).
1107-
// Except for extracting 1-element vectors.
1108-
if (isa<VectorType>(resultType) &&
1109-
position.size() !=
1110-
static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
1111-
if (extractOp.hasDynamicPosition())
1112-
return failure();
1113-
1114-
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1115-
loc, adaptor.getVector(), getAsIntegers(position));
1116-
rewriter.replaceOp(extractOp, extracted);
1117-
return success();
1118-
}
1099+
// Determine if we need to extract a scalar as the result. We extract
1100+
// a scalar if the extract is full rank i.e. the number of indices is equal
1101+
// to source vector rank.
1102+
bool isScalarExtract =
1103+
positionVec.size() == extractOp.getSourceVectorType().getRank();
1104+
// Determine if we need to extract a slice out of the original vector. We
1105+
// always need to extract a slice if the input rank >= 2.
1106+
bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
11191107

1120-
// Potential extraction of 1-D vector from array.
11211108
Value extracted = adaptor.getVector();
1122-
if (position.size() > 1) {
1123-
if (extractOp.hasDynamicPosition())
1109+
if (isSlicingExtract) {
1110+
ArrayRef<OpFoldResult> position(positionVec);
1111+
if (isScalarExtract) {
1112+
// If we are extracting a scalar from the returned slice, we need to
1113+
// extract a N-1 D slice.
1114+
position = position.drop_back();
1115+
}
1116+
// llvm.extractvalue does not support dynamic dimensions.
1117+
if (!llvm::all_of(position,
1118+
[](OpFoldResult x) { return isa<Attribute>(x); })) {
11241119
return failure();
1120+
}
1121+
extracted = rewriter.create<LLVM::ExtractValueOp>(
1122+
loc, extracted, getAsIntegers(position));
1123+
}
11251124

1126-
SmallVector<int64_t> nMinusOnePosition =
1127-
getAsIntegers(position.drop_back());
1128-
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1129-
nMinusOnePosition);
1125+
if (isScalarExtract) {
1126+
Value position;
1127+
if (positionVec.empty()) {
1128+
// A scalar extract with no position is a 0-D vector extract. The LLVM
1129+
// type converter converts 0-D vectors to 1-D vectors, so we need to add
1130+
// a constant position.
1131+
auto idxType = rewriter.getIndexType();
1132+
position = rewriter.create<LLVM::ConstantOp>(
1133+
loc, typeConverter->convertType(idxType),
1134+
rewriter.getIntegerAttr(idxType, 0));
1135+
} else {
1136+
position = getAsLLVMValue(rewriter, loc, positionVec.back());
1137+
}
1138+
extracted =
1139+
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
11301140
}
11311141

1132-
Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1133-
// Remaining extraction of element from 1-D LLVM vector.
1134-
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1135-
lastPosition);
1142+
rewriter.replaceOp(extractOp, extracted);
11361143
return success();
11371144
}
11381145
};

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,26 +1290,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
12901290

12911291
// -----
12921292

1293-
func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
1293+
func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
12941294
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
12951295
return %0 : f32
12961296
}
12971297

1298-
// Multi-dim vectors are not supported but this test shouldn't crash.
1298+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
12991299

1300-
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
1301-
// CHECK: vector.extract
1300+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
1301+
// CHECK: llvm.extractvalue
1302+
// CHECK: llvm.extractelement
13021303

1303-
func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
1304+
func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
13041305
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
13051306
return %0 : f32
13061307
}
13071308

1308-
// Multi-dim vectors are not supported but this test shouldn't crash.
1309+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1310+
1311+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
1312+
// CHECK: llvm.extractvalue
1313+
// CHECK: llvm.extractelement
1314+
1315+
// -----
13091316

1310-
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
1317+
func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
1318+
%0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
1319+
return %0 : f32
1320+
}
1321+
1322+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1323+
1324+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
13111325
// CHECK: vector.extract
13121326

1327+
func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
1328+
%0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
1329+
return %0 : f32
1330+
}
1331+
1332+
// Multi-dim vectors with outer dimension as dynamic are not supported, but it
1333+
// shouldn't crash.
1334+
1335+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
1336+
// CHECK: vector.extract
1337+
1338+
// -----
1339+
1340+
func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
1341+
%0 = vector.extract %arg0[]: index from vector<index>
1342+
return %0 : index
1343+
}
1344+
// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
1345+
// CHECK-SAME: %[[A:.*]]: vector<index>)
1346+
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
1347+
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
1348+
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
1349+
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
1350+
// CHECK: return %[[T3]] : index
1351+
13131352
// -----
13141353

13151354
func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {

0 commit comments

Comments
 (0)