Skip to content

Commit 14f5504

Browse files
authored
[mlir][tosa] Fix validation check on controlflow operators (#159754)
Previoulsy the error_if check for controlflow operators would silently fail on valid controflow operators. This was due to incorrect return logic in the validation function. This commit fixes that logic.
1 parent cae73be commit 14f5504

File tree

3 files changed

+38
-37
lines changed

3 files changed

+38
-37
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,17 +1257,17 @@ bool checkErrorIfCondIf(Operation *op) {
12571257
// tosa.yield %arg4
12581258
// }
12591259

1260-
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1261-
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
1260+
return succeeded(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) &&
1261+
succeeded(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
12621262
}
12631263

12641264
bool checkErrorIfWhileLoop(Operation *op) {
12651265
auto whileOp = dyn_cast<tosa::WhileOp>(op);
12661266
if (!whileOp)
12671267
return true;
12681268

1269-
return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1270-
failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
1269+
return succeeded(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) &&
1270+
succeeded(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
12711271
}
12721272

12731273
bool checkErrorIfScatter(Operation *op) {

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -269,20 +269,6 @@ func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f3
269269

270270
// -----
271271

272-
// Check isolated cond_if's are valid
273-
func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
274-
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
275-
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
276-
tosa.yield %arg3 : tensor<f32>
277-
}, {
278-
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
279-
tosa.yield %arg4 : tensor<f32>
280-
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
281-
return %0 : tensor<f32>
282-
}
283-
284-
// -----
285-
286272
func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
287273
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
288274
// expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
@@ -318,22 +304,3 @@ func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg
318304
}) : (tensor<i32>) -> (tensor<i32>)
319305
return
320306
}
321-
322-
// -----
323-
324-
// Check isolated while_loops are valid
325-
func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
326-
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
327-
%1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
328-
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
329-
%2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
330-
%3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
331-
"tosa.yield"(%3) : (tensor<i1>) -> ()
332-
}, {
333-
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
334-
%2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
335-
%3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
336-
"tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
337-
}) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
338-
return
339-
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s
2+
3+
// -----
4+
5+
// CHECK-LABEL: test_cond_if_isolated_from_above
6+
func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
7+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
8+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
9+
tosa.yield %arg3 : tensor<f32>
10+
}, {
11+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
12+
tosa.yield %arg4 : tensor<f32>
13+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
14+
return %0 : tensor<f32>
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: test_while_loop_isolated_from_above
20+
func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
21+
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
22+
%1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
23+
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
24+
%2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
25+
%3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
26+
"tosa.yield"(%3) : (tensor<i1>) -> ()
27+
}, {
28+
^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
29+
%2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
30+
%3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
31+
"tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
32+
}) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
33+
return
34+
}

0 commit comments

Comments
 (0)