Skip to content

Commit 552f80a

Browse files
[mlir][vector] Fix crash when folding 0D extract from splat/broadcast (llvm#95918)
There was an assertion in the folder that caused a crash when extracting from a vector that is defined by an op with 0D semantics. This commit removes the assertion and adds test cases to ensure that 0D scenarios are handled correctly.
1 parent 41f6aee commit 552f80a

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16311631
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
16321632
return Value();
16331633

1634-
// 0-D vectors not supported.
1635-
assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1636-
if (hasZeroDimVectors(defOp))
1637-
return Value();
1638-
16391634
Value source = defOp->getOperand(0);
16401635
if (extractOp.getType() == source.getType())
16411636
return source;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
26042604
%0 = vector.extract %v[] : f32 from vector<f32>
26052605
return %0 : f32
26062606
}
2607+
2608+
// -----
2609+
2610+
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
2611+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
2612+
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
2613+
// Splat scalar to 0D and extract scalar.
2614+
%0 = vector.splat %a : vector<f32>
2615+
%1 = vector.extract %0[] : f32 from vector<f32>
2616+
2617+
// Broadcast scalar to 0D and extract scalar.
2618+
%2 = vector.broadcast %a : f32 to vector<f32>
2619+
%3 = vector.extract %2[] : f32 from vector<f32>
2620+
2621+
// Broadcast 0D to 3D and extract scalar.
2622+
// CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
2623+
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
2624+
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
2625+
2626+
// Splat scalar to 2D and extract scalar.
2627+
%6 = vector.splat %a : vector<2x3xf32>
2628+
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
2629+
2630+
// Broadcast scalar to 3D and extract scalar.
2631+
%8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
2632+
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
2633+
2634+
// Extract 2D from 3D that was broadcasted from a scalar.
2635+
// CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
2636+
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
2637+
2638+
// Extract 1D from 2D that was splat'ed from a scalar.
2639+
// CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
2640+
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
2641+
2642+
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
2643+
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
2644+
}

0 commit comments

Comments
 (0)