Skip to content

Commit 9c606ae

Browse files
[MLIR][TOSA] Update IfOp print/parse to support ranked condition tens… (llvm#149791)
…or and optional block arguments This change extends the TOSA `cond_if` operation's print and parse logic to handle the following: - The condition operand may now have any rank, as long as the total number of elements sums to 1. %1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32> - The `then` and `else` regions can now include optional block arguments. The updated IR syntax reflects this: %1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor<i1> (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - Removed parentheses around single result types in the printed representation, aligning with the `AsmPrinter` conventions. Co-authored-by: Luke Hutton <[email protected]>
1 parent 51194a4 commit 9c606ae

File tree

12 files changed

+228
-68
lines changed

12 files changed

+228
-68
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3647,23 +3647,87 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
36473647
return std::nullopt;
36483648
}
36493649

3650+
static void printInitializationList(OpAsmPrinter &parser,
3651+
Block::BlockArgListType blocksArgs,
3652+
ValueRange initializers,
3653+
StringRef prefix = "") {
3654+
assert(blocksArgs.size() == initializers.size() &&
3655+
"expected same length of arguments and initializers");
3656+
if (initializers.empty())
3657+
return;
3658+
3659+
parser << prefix << '(';
3660+
llvm::interleaveComma(
3661+
llvm::zip(blocksArgs, initializers), parser,
3662+
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3663+
parser << ")";
3664+
}
3665+
36503666
// parse and print of IfOp refer to the implementation of SCF dialect.
36513667
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
36523668
// Create the regions for 'then'.
36533669
result.regions.reserve(2);
36543670
Region *thenRegion = result.addRegion();
36553671
Region *elseRegion = result.addRegion();
36563672

3657-
auto &builder = parser.getBuilder();
36583673
OpAsmParser::UnresolvedOperand cond;
3659-
// Create a i1 tensor type for the boolean condition.
3660-
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3661-
if (parser.parseOperand(cond) ||
3662-
parser.resolveOperand(cond, i1Type, result.operands))
3674+
3675+
if (parser.parseOperand(cond))
36633676
return failure();
3664-
// Parse optional results type list.
3665-
if (parser.parseOptionalArrowTypeList(result.types))
3677+
3678+
SmallVector<OpAsmParser::Argument, 4> regionArgs;
3679+
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3680+
3681+
// Parse the optional block arguments
3682+
OptionalParseResult listResult =
3683+
parser.parseOptionalAssignmentList(regionArgs, operands);
3684+
if (listResult.has_value() && failed(listResult.value()))
36663685
return failure();
3686+
3687+
// Parse a colon.
3688+
if (failed(parser.parseColon()))
3689+
return parser.emitError(parser.getCurrentLocation(),
3690+
"expected type for condition operand");
3691+
3692+
// Parse the type of the condition operand
3693+
Type condType;
3694+
if (failed(parser.parseType(condType)))
3695+
return parser.emitError(parser.getCurrentLocation(),
3696+
"expected type for condition operand");
3697+
3698+
// Resolve operand with provided type
3699+
if (failed(parser.resolveOperand(cond, condType, result.operands)))
3700+
return failure();
3701+
3702+
// Parse optional block arg types
3703+
if (listResult.has_value()) {
3704+
FunctionType functionType;
3705+
3706+
if (failed(parser.parseType(functionType)))
3707+
return parser.emitError(parser.getCurrentLocation())
3708+
<< "expected list of types for block arguments "
3709+
<< "followed by arrow type and list of return types";
3710+
3711+
result.addTypes(functionType.getResults());
3712+
3713+
if (functionType.getNumInputs() != operands.size()) {
3714+
return parser.emitError(parser.getCurrentLocation())
3715+
<< "expected as many input types as operands "
3716+
<< "(expected " << operands.size() << " got "
3717+
<< functionType.getNumInputs() << ")";
3718+
}
3719+
3720+
// Resolve input operands.
3721+
if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3722+
parser.getCurrentLocation(),
3723+
result.operands)))
3724+
return failure();
3725+
} else {
3726+
// Parse optional results type list.
3727+
if (parser.parseOptionalArrowTypeList(result.types))
3728+
return failure();
3729+
}
3730+
36673731
// Parse the 'then' region.
36683732
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
36693733
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
36813745
}
36823746

36833747
void IfOp::print(OpAsmPrinter &p) {
3684-
bool printBlockTerminators = false;
3685-
36863748
p << " " << getCondition();
3687-
if (!getResults().empty()) {
3688-
p << " -> (" << getResultTypes() << ")";
3689-
// Print yield explicitly if the op defines values.
3690-
printBlockTerminators = true;
3749+
3750+
printInitializationList(p, getThenGraph().front().getArguments(),
3751+
getInputList(), " ");
3752+
p << " : ";
3753+
p << getCondition().getType();
3754+
3755+
if (!getInputList().empty()) {
3756+
p << " (";
3757+
llvm::interleaveComma(getInputList().getTypes(), p);
3758+
p << ")";
36913759
}
3692-
p << ' ';
3693-
p.printRegion(getThenGraph(),
3694-
/*printEntryBlockArgs=*/false,
3695-
/*printBlockTerminators=*/printBlockTerminators);
3760+
p.printArrowTypeList(getResultTypes());
3761+
p << " ";
3762+
3763+
p.printRegion(getThenGraph());
36963764

36973765
// Print the 'else' regions if it exists and has a block.
36983766
auto &elseRegion = getElseGraph();
36993767
if (!elseRegion.empty()) {
37003768
p << " else ";
3701-
p.printRegion(elseRegion,
3702-
/*printEntryBlockArgs=*/false,
3703-
/*printBlockTerminators=*/printBlockTerminators);
3769+
p.printRegion(elseRegion);
37043770
}
37053771

37063772
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
39093975
parser.parseOptionalAttrDictWithKeyword(result.attributes));
39103976
}
39113977

3912-
static void printInitializationList(OpAsmPrinter &parser,
3913-
Block::BlockArgListType blocksArgs,
3914-
ValueRange initializers,
3915-
StringRef prefix = "") {
3916-
assert(blocksArgs.size() == initializers.size() &&
3917-
"expected same length of arguments and initializers");
3918-
if (initializers.empty())
3919-
return;
3920-
3921-
parser << prefix << '(';
3922-
llvm::interleaveComma(
3923-
llvm::zip(blocksArgs, initializers), parser,
3924-
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3925-
parser << ")";
3926-
}
3927-
39283978
void WhileOp::print(OpAsmPrinter &parser) {
39293979
printInitializationList(parser, getCondGraph().front().getArguments(),
39303980
getInputList(), " ");

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
12481248
// })
12491249
//
12501250
// Simplified:
1251-
// %0 = tosa.cond_if %arg2 {
1252-
// tosa.yield %arg0
1251+
// %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1252+
// ^bb0(%arg3, %arg4):
1253+
// tosa.yield %arg3
12531254
// } else {
1254-
// tosa.yield %arg1
1255+
// ^bb0(%arg3, %arg4):
1256+
// tosa.yield %arg4
12551257
// }
1256-
//
1257-
// Unfortunately, the simplified syntax does not encapsulate values
1258-
// used in then/else regions (see 'simplified' example above), so it
1259-
// must be rewritten to use the generic syntax in order to be conformant
1260-
// to the specification.
1258+
12611259
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
12621260
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
12631261
}

mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
3636
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
3737
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
3838
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
39-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
39+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
4040

4141
// CHECK: scf.yield [[ARG0]]
4242
tosa.yield %arg0 : tensor<f32>

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
645645
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
646646
// CHECK: tosa.cond_if profiles: [ ]
647647
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
648-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
648+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
649649
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
650650
tosa.yield %1 : tensor<f32>
651651
} else {
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-opt -split-input-file %s | FileCheck %s
2+
3+
// -----
4+
5+
func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
6+
// CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
7+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
8+
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
9+
tosa.yield %1 : tensor<f32>
10+
// CHECK: } else {
11+
} else {
12+
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
13+
tosa.yield %1 : tensor<f32>
14+
}
15+
return %0 : tensor<f32>
16+
}
17+
18+
// -----
19+
20+
func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
21+
// CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
22+
// CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
23+
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
24+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
25+
%1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
26+
tosa.yield %1 : tensor<f32>
27+
// CHECK: } else {
28+
// CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
29+
} else {
30+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
31+
%1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
32+
tosa.yield %1 : tensor<f32>
33+
}
34+
return %0 : tensor<f32>
35+
}

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:
259259

260260
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
261261
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
262-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
262+
%0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
263263
tosa.yield %arg0 : tensor<f32>
264264
} else {
265265
tosa.yield %arg1 : tensor<f32>

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
337337
// -----
338338
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
339339
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
340-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
340+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
341341
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
342342
tosa.yield %1 : tensor<f32>
343343
} else {

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
15061506
// -----
15071507

15081508
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
1509-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
1510-
%1 = tosa.cond_if %arg3 -> (tensor<f32>) {
1511-
%2 = tosa.cond_if %arg2 -> (tensor<f32>) {
1512-
%3 = tosa.cond_if %arg3 -> (tensor<f32>) {
1513-
%4 = tosa.cond_if %arg2 -> (tensor<f32>) {
1509+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
1510+
%1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> {
1511+
%2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
1512+
%3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
1513+
%4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
15141514
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
1515-
%5 = tosa.cond_if %arg3 -> (tensor<f32>) {
1515+
%5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
15161516
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
15171517
tosa.yield %res : tensor<f32>
15181518
} else {

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
839839
// -----
840840
// CHECK-LABEL: cond_if
841841
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
842-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
842+
%0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
843843
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
844844
tosa.yield %1 : tensor<f32>
845845
} else {

mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
5454
// CHECK-LABEL: test_regions
5555
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
5656
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
57-
// CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
57+
// CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
5858
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
5959
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
6060
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>

0 commit comments

Comments
 (0)