Skip to content

Commit a4b0153

Browse files
authored
[mlir][vector] Support for extracting 1-element vectors in VectorExtractOpConversion (llvm#107549)
This patch adds support for converting `vector.extract` that extract 1-element vectors into LLVM, fixing a crash in such cases. E.g., `vector.extract %1[0]: vector<1xf32> from vector<2xf32>`. Fix llvm#61372.
1 parent a8f3d30 commit a4b0153

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,10 @@ class VectorExtractOpConversion
11041104
}
11051105

11061106
// One-shot extraction of vector from array (only requires extractvalue).
1107-
if (isa<VectorType>(resultType)) {
1107+
// Except for extracting 1-element vectors.
1108+
if (isa<VectorType>(resultType) &&
1109+
position.size() !=
1110+
static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
11081111
if (extractOp.hasDynamicPosition())
11091112
return failure();
11101113

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,30 @@ func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f
11301130

11311131
// -----
11321132

1133+
func.func @extract_vec_1e_from_vec_1d_f32(%arg0: vector<16xf32>) -> vector<1xf32> {
1134+
%0 = vector.extract %arg0[15]: vector<1xf32> from vector<16xf32>
1135+
return %0 : vector<1xf32>
1136+
}
1137+
// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32(
1138+
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
1139+
// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
1140+
// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<16xf32>
1141+
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
1142+
// CHECK: return %[[T2]] : vector<1xf32>
1143+
1144+
func.func @extract_vec_1e_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> vector<1xf32> {
1145+
%0 = vector.extract %arg0[15]: vector<1xf32> from vector<[16]xf32>
1146+
return %0 : vector<1xf32>
1147+
}
1148+
// CHECK-LABEL: @extract_vec_1e_from_vec_1d_f32_scalable(
1149+
// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
1150+
// CHECK: %[[T0:.*]] = llvm.mlir.constant(15 : i64) : i64
1151+
// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : i64] : vector<[16]xf32>
1152+
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : f32 to vector<1xf32>
1153+
// CHECK: return %[[T2]] : vector<1xf32>
1154+
1155+
// -----
1156+
11331157
func.func @extract_scalar_from_vec_1d_index(%arg0: vector<16xindex>) -> index {
11341158
%0 = vector.extract %arg0[15]: index from vector<16xindex>
11351159
return %0 : index

0 commit comments

Comments
 (0)