Skip to content

Commit eede476

Browse files
authored
[mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (#159910)
Follow up to #159756
1 parent 6438d01 commit eede476

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4073,16 +4073,26 @@ LogicalResult WhileOp::verify() {
40734073
.failed())
40744074
return failure();
40754075

4076-
auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4077-
if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4078-
"'body_graph' results", getInputList(),
4079-
"'input_list'")
4080-
.failed())
4081-
return failure();
4076+
if (getBodyGraph().front().mightHaveTerminator()) {
4077+
auto bodyYield =
4078+
dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4079+
if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4080+
"'body_graph' results",
4081+
getInputList(), "'input_list'")
4082+
.failed())
4083+
return failure();
4084+
}
40824085

40834086
// Condition block output must be a single element tensor with a single bool
40844087
// value.
4085-
auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4088+
if (!getCondGraph().front().mightHaveTerminator())
4089+
return success();
4090+
4091+
auto condYield =
4092+
dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4093+
if (!condYield)
4094+
return success();
4095+
40864096
if (condYield.getInputs().size() != 1)
40874097
return emitOpError() << "require 'cond_graph' only have one result";
40884098

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
686686
return %0 : tensor<f32>
687687
}
688688

689+
// -----
690+
func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
691+
%0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
692+
// expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
693+
%1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
694+
"func.return"(%arg2) : (tensor<i32>) -> ()
695+
} do {
696+
^bb0(%arg2: tensor<i32>):
697+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
698+
%2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
699+
tosa.yield %2 : tensor<i32>
700+
}
701+
return %0 : tensor<i32>
702+
}
703+
704+
// -----
705+
func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
706+
%0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
707+
// expected-error@+1 {{block with no terminator}}
708+
%1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
709+
} do {
710+
^bb0(%arg2: tensor<i32>):
711+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
712+
%2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
713+
tosa.yield %2 : tensor<i32>
714+
}
715+
return %0 : tensor<i32>
716+
}
717+
718+
// -----
719+
func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
720+
%0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
721+
%1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
722+
tosa.yield %1 : tensor<i1>
723+
} do {
724+
^bb0(%arg2: tensor<i32>):
725+
// expected-error@+1 {{block with no terminator}}
726+
%1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
727+
}
728+
return %0 : tensor<i32>
729+
}
730+
689731
// -----
690732

691733
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {

0 commit comments

Comments
 (0)