Skip to content

Commit fa0f2d2

Browse files
authored
[mlir][tosa] Fix crash in tosa.concat verifier (llvm#165966)
The `tosa.concat` verifier crashed when the output rank did not match the input rank. This PR adds a proper check and error emission to prevent the crash. Fixes llvm#159742.
1 parent 5ba0c7c commit fa0f2d2

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,11 @@ LogicalResult tosa::ConcatOp::verify() {
17611761
}
17621762
}
17631763

1764+
const ShapeAdaptor outputShape(outType);
1765+
if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1766+
return emitOpError("expect output rank to match inputs rank, got ")
1767+
<< outputShape.getRank() << " vs " << firstInputRank;
1768+
17641769
// ERROR_IF(axis_sum != shape[axis]);
17651770
int64_t axisSum = 0;
17661771
for (const auto &input : inputList) {
@@ -1772,7 +1777,7 @@ LogicalResult tosa::ConcatOp::verify() {
17721777
}
17731778
axisSum += inputShape.getDimSize(axis);
17741779
}
1775-
const ShapeAdaptor outputShape(outType);
1780+
17761781
if (axisSum >= 0 && outputShape.hasRank() &&
17771782
!outputShape.isDynamicDim(axis) &&
17781783
axisSum != outputShape.getDimSize(axis))

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,14 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
306306

307307
// -----
308308

309+
func.func @test_concat_input_output_rank_mismatch(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2xf32> {
310+
// expected-error@+1 {{'tosa.concat' op expect output rank to match inputs rank, got 1 vs 2}}
311+
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2xf32>
312+
return %0 : tensor<2xf32>
313+
}
314+
315+
// -----
316+
309317
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
310318
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
311319
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>

0 commit comments

Comments
 (0)