Skip to content

Commit 370ea51

Browse files
authored
[mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion (#159756)
Fixes #159650. The current implementation ICE out if we access an IfOp's terminator when it doesn't have it. Instead the PR defers the job of verifying that a block would have at least a terminator. The current implementation also crashes with cast<YieldOp> if the terminator is not a YieldOp, the PR also defers the job of verification to the op itself.
1 parent add9079 commit 370ea51

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() {
40254025
.failed())
40264026
return failure();
40274027

4028-
auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4029-
if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
4030-
"'then_graph' results", getOutputList(),
4031-
"'output_list'")
4032-
.failed())
4033-
return failure();
4028+
// MLIR will verify the absence of the terminator for us if otherwise.
4029+
if (getThenGraph().front().mightHaveTerminator()) {
4030+
auto thenYield =
4031+
dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4032+
if (thenYield && errorIfTypeOrShapeMismatch(
4033+
*this, thenYield.getInputs(), "'then_graph' results",
4034+
getOutputList(), "'output_list'")
4035+
.failed())
4036+
return failure();
4037+
}
40344038

4035-
auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4036-
if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
4037-
"'else_graph' results", getOutputList(),
4038-
"'output_list'")
4039-
.failed())
4040-
return failure();
4039+
// MLIR will verify the absence of the terminator for us if otherwise.
4040+
if (getElseGraph().front().mightHaveTerminator()) {
4041+
auto elseYield =
4042+
dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4043+
if (elseYield && errorIfTypeOrShapeMismatch(
4044+
*this, elseYield.getInputs(), "'else_graph' results",
4045+
getOutputList(), "'output_list'")
4046+
.failed())
4047+
return failure();
4048+
}
40414049

40424050
auto condType = getCondition().getType();
40434051
if (errorIfShapeNotSizeOne(*this, condType).failed())

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,43 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
438438
return %1 : tensor<10xi8>
439439
}
440440

441+
// -----
442+
func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
443+
%0 = "tosa.cond_if"(%arg0) ({
444+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
445+
"tosa.yield"(%1) : (tensor<i32>) -> ()
446+
}, {
447+
// expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
448+
%2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
449+
"func.return"(%2) : (tensor<i32>) -> ()
450+
}) : (tensor<i1>) -> tensor<i32>
451+
return %0 : tensor<i32>
452+
}
453+
454+
// -----
455+
func.func @test_cond_if_missing_then_terminator(%arg0: tensor<i1>) -> tensor<i32> {
456+
%0 = "tosa.cond_if"(%arg0) ({
457+
// expected-error@+1 {{block with no terminator}}
458+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
459+
}, {
460+
%2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
461+
"tosa.yield"(%2) : (tensor<i32>) -> ()
462+
}) : (tensor<i1>) -> tensor<i32>
463+
return %0 : tensor<i32>
464+
}
465+
466+
// -----
467+
func.func @test_cond_if_missing_else_terminator(%arg0: tensor<i1>) -> tensor<i32> {
468+
%0 = "tosa.cond_if"(%arg0) ({
469+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
470+
"tosa.yield"(%1) : (tensor<i32>) -> ()
471+
}, {
472+
// expected-error@+1 {{block with no terminator}}
473+
%2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
474+
}) : (tensor<i1>) -> tensor<i32>
475+
return %0 : tensor<i32>
476+
}
477+
441478
// -----
442479

443480
func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {

0 commit comments

Comments
 (0)