Skip to content

Commit d20d166

Browse files
committed
[mlir][tosa] Print generic cond_if when block arguments are present
The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information. Change-Id: Ia44fde857e6cd3a26dbc40c0a9187b4ddb95666b
1 parent 5a9cc93 commit d20d166

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
36493649

36503650
// parse and print of IfOp refer to the implementation of SCF dialect.
36513651
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3652+
OpAsmParser::UnresolvedOperand cond;
3653+
// Fallback to generic IfOp parser when no immediate conditional
3654+
// operand is provided.
3655+
if (!parser.parseOptionalOperand(cond).has_value()) {
3656+
return parser.parseGenericOperationAfterOpName(result);
3657+
}
3658+
36523659
// Create the regions for 'then'.
36533660
result.regions.reserve(2);
36543661
Region *thenRegion = result.addRegion();
36553662
Region *elseRegion = result.addRegion();
36563663

36573664
auto &builder = parser.getBuilder();
3658-
OpAsmParser::UnresolvedOperand cond;
36593665
// Create a i1 tensor type for the boolean condition.
36603666
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3661-
if (parser.parseOperand(cond) ||
3662-
parser.resolveOperand(cond, i1Type, result.operands))
3667+
if (parser.resolveOperand(cond, i1Type, result.operands))
36633668
return failure();
36643669
// Parse optional results type list.
36653670
if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
36813686
}
36823687

36833688
void IfOp::print(OpAsmPrinter &p) {
3689+
// The simplified syntax drops block-level arguments
3690+
// to the then/else regions. Fallback to the generic
3691+
// parser if these are found
3692+
Region &thenRegion = getThenGraph();
3693+
Region &elseRegion = getElseGraph();
3694+
if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
3695+
!elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
3696+
p.printGenericOp(*this, false);
3697+
return;
3698+
}
3699+
36843700
bool printBlockTerminators = false;
36853701

36863702
p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
36903706
printBlockTerminators = true;
36913707
}
36923708
p << ' ';
3693-
p.printRegion(getThenGraph(),
3709+
p.printRegion(thenRegion,
36943710
/*printEntryBlockArgs=*/false,
36953711
/*printBlockTerminators=*/printBlockTerminators);
36963712

36973713
// Print the 'else' regions if it exists and has a block.
3698-
auto &elseRegion = getElseGraph();
36993714
if (!elseRegion.empty()) {
37003715
p << " else ";
37013716
p.printRegion(elseRegion,
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
3+
4+
// -----
5+
6+
// CHECK-LABEL: test_cond_if_generic_form
7+
// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
8+
// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
9+
// CHECK: %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
10+
// CHECK: tosa.yield %[[THEN_TERM]] : tensor<f32>
11+
// CHECK: }, {
12+
// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
13+
// CHECK: %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
14+
// CHECK: tosa.yield %[[ELSE_TERM]] : tensor<f32>
15+
// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
16+
// CHECK: return %[[OUT]] : tensor<f32>
17+
func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
18+
%0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
19+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
20+
%1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
21+
tosa.yield %1 : tensor<f32>
22+
}, {
23+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
24+
%1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
25+
tosa.yield %1 : tensor<f32>
26+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
27+
return %0 : tensor<f32>
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
33+
// COM: No block arguments means simplified form can be printed
34+
func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
35+
// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
36+
%0 = tosa.cond_if(%arg2) ({
37+
^bb0():
38+
tosa.yield %arg0 : tensor<f32>
39+
}, {
40+
^bb0():
41+
tosa.yield %arg1 : tensor<f32>
42+
}) : (tensor<i1>) -> tensor<f32>
43+
return %0 : tensor<f32>
44+
}
45+
46+
// -----
47+
48+
// CHECK-LABEL: test_cond_if_simplified_form
49+
// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
50+
func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
51+
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
52+
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
53+
tosa.yield %1 : tensor<f32>
54+
} else {
55+
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
56+
tosa.yield %1 : tensor<f32>
57+
}
58+
return %0 : tensor<f32>
59+
}
60+
61+
// -----
62+
63+
// CHECK-LABEL: test_cond_if_simplified_form_just_yield
64+
// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
65+
func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
66+
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
67+
tosa.yield %arg0 : tensor<f32>
68+
} else {
69+
tosa.yield %arg1 : tensor<f32>
70+
}
71+
return %0 : tensor<f32>
72+
}

0 commit comments

Comments
 (0)