Skip to content

Commit 5fae9dd

Browse files
committed
[mlir][tosa] Stop support the custom simplified form of COND_IF
Since the tensor_list_shape for input_list, output_list, then_graph, and else_graph is required to be equal according to the spec, this information must be explicitly provided during operation construction. The current custom simplified form does not meet this requirement. For example, the input_list and output_list can be empty in the simplified form. A new compatible simplified form will be introduced in the future if necessary.
1 parent 53fe3df commit 5fae9dd

File tree

9 files changed

+157
-152
lines changed

9 files changed

+157
-152
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2558,7 +2558,6 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
25582558
SizedRegion<1>:$else_graph
25592559
);
25602560

2561-
let hasCustomAssemblyFormat = 1;
25622561
let hasVerifier = 1;
25632562
}
25642563

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,65 +3518,6 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
35183518
return std::nullopt;
35193519
}
35203520

3521-
// parse and print of IfOp refer to the implementation of SCF dialect.
3522-
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3523-
// Create the regions for 'then'.
3524-
result.regions.reserve(2);
3525-
Region *thenRegion = result.addRegion();
3526-
Region *elseRegion = result.addRegion();
3527-
3528-
auto &builder = parser.getBuilder();
3529-
OpAsmParser::UnresolvedOperand cond;
3530-
// Create a i1 tensor type for the boolean condition.
3531-
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3532-
if (parser.parseOperand(cond) ||
3533-
parser.resolveOperand(cond, i1Type, result.operands))
3534-
return failure();
3535-
// Parse optional results type list.
3536-
if (parser.parseOptionalArrowTypeList(result.types))
3537-
return failure();
3538-
// Parse the 'then' region.
3539-
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3540-
return failure();
3541-
3542-
// If we find an 'else' keyword then parse the 'else' region.
3543-
if (!parser.parseOptionalKeyword("else")) {
3544-
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3545-
return failure();
3546-
}
3547-
3548-
// Parse the optional attribute list.
3549-
if (parser.parseOptionalAttrDict(result.attributes))
3550-
return failure();
3551-
return success();
3552-
}
3553-
3554-
void IfOp::print(OpAsmPrinter &p) {
3555-
bool printBlockTerminators = false;
3556-
3557-
p << " " << getCondition();
3558-
if (!getResults().empty()) {
3559-
p << " -> (" << getResultTypes() << ")";
3560-
// Print yield explicitly if the op defines values.
3561-
printBlockTerminators = true;
3562-
}
3563-
p << ' ';
3564-
p.printRegion(getThenGraph(),
3565-
/*printEntryBlockArgs=*/false,
3566-
/*printBlockTerminators=*/printBlockTerminators);
3567-
3568-
// Print the 'else' regions if it exists and has a block.
3569-
auto &elseRegion = getElseGraph();
3570-
if (!elseRegion.empty()) {
3571-
p << " else ";
3572-
p.printRegion(elseRegion,
3573-
/*printEntryBlockArgs=*/false,
3574-
/*printBlockTerminators=*/printBlockTerminators);
3575-
}
3576-
3577-
p.printOptionalAttrDict((*this)->getAttrs());
3578-
}
3579-
35803521
LogicalResult IfOp::verify() {
35813522
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
35823523
"'then_graph' arguments", getInputList(),

mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,15 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
3636
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
3737
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
3838
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
39-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
40-
41-
// CHECK: scf.yield [[ARG0]]
42-
tosa.yield %arg0 : tensor<f32>
43-
44-
// CHECK: } else {
45-
} else {
46-
47-
// CHECK: scf.yield [[ARG1]]
48-
tosa.yield %arg1 : tensor<f32>
49-
50-
// CHECK: }
51-
// CHECK: return [[IF]]
52-
}
39+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
40+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
41+
// CHECK: scf.yield [[ARG0]]
42+
tosa.yield %arg3 : tensor<f32>
43+
}, {
44+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
45+
// CHECK: scf.yield [[ARG1]]
46+
tosa.yield %arg4 : tensor<f32>
47+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
5348

5449
return %0 : tensor<f32>
5550
}

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -645,13 +645,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
645645
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
646646
// CHECK: tosa.cond_if profiles: [ ]
647647
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
648-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
649-
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
648+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
649+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
650+
%1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
650651
tosa.yield %1 : tensor<f32>
651-
} else {
652-
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
652+
}, {
653+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
654+
%1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
653655
tosa.yield %1 : tensor<f32>
654-
}
656+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
655657
return %0 : tensor<f32>
656658
}
657659

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,15 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
337337
// -----
338338
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
339339
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
340-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
340+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
341+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
341342
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
342343
tosa.yield %1 : tensor<f32>
343-
} else {
344+
}, {
345+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
344346
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
345347
tosa.yield %1 : tensor<f32>
346-
}
348+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
347349
return %0 : tensor<f32>
348350
}
349351

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,40 +1503,87 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
15031503

15041504
// -----
15051505

1506-
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
1507-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
1508-
%1 = tosa.cond_if %arg3 -> (tensor<f32>) {
1509-
%2 = tosa.cond_if %arg2 -> (tensor<f32>) {
1510-
%3 = tosa.cond_if %arg3 -> (tensor<f32>) {
1511-
%4 = tosa.cond_if %arg2 -> (tensor<f32>) {
1506+
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
1507+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
1508+
1509+
// COM: then graph of IF-1
1510+
^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
1511+
%cond1 = tosa.equal %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
1512+
%1 = "tosa.cond_if"(%cond1, %a1, %b1) ({
1513+
1514+
// COM: then graph of IF-2
1515+
^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
1516+
%cond2 = tosa.equal %a2, %b2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
1517+
%2 = "tosa.cond_if"(%cond2, %a2, %b2) ({
1518+
1519+
// COM: then graph of IF-3
1520+
^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
1521+
%cond3 = tosa.equal %a3, %b3 : (tensor<f32>, tensor<f32>) -> tensor<i1>
1522+
%3 = "tosa.cond_if"(%cond3, %a3, %b3) ({
1523+
1524+
// COM: then graph of IF-4
1525+
^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
1526+
%cond4 = tosa.equal %a4, %b4 : (tensor<f32>, tensor<f32>) -> tensor<i1>
1527+
%4 = "tosa.cond_if"(%cond4, %a4, %b4) ({
1528+
1529+
// COM: then graph of IF-5
1530+
^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
1531+
%cond5 = tosa.equal %a5, %b5 : (tensor<f32>, tensor<f32>) -> tensor<i1>
15121532
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
1513-
%5 = tosa.cond_if %arg3 -> (tensor<f32>) {
1533+
%5 = "tosa.cond_if"(%cond5, %a5, %b5) ({
1534+
1535+
// COM: then graph of IF-6
1536+
^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
15141537
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
15151538
tosa.yield %res : tensor<f32>
1516-
} else {
1539+
}, {
1540+
1541+
// COM: else graph of IF-6
1542+
^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
15171543
tosa.yield %arg0 : tensor<f32>
1518-
}
1544+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1545+
15191546
tosa.yield %5 : tensor<f32>
1520-
} else {
1547+
}, {
1548+
1549+
// COM: else graph of IF-5
1550+
^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
1551+
tosa.yield %arg0 : tensor<f32>
1552+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1553+
1554+
tosa.yield %4 : tensor<f32>
1555+
}, {
1556+
1557+
// COM: else graph of IF-4
1558+
^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
15211559
tosa.yield %arg0 : tensor<f32>
1522-
}
1523-
tosa.yield %4 : tensor<f32>
1524-
} else {
1525-
tosa.yield %arg0 : tensor<f32>
1526-
}
1527-
tosa.yield %3 : tensor<f32>
1528-
} else {
1560+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1561+
1562+
tosa.yield %3 : tensor<f32>
1563+
}, {
1564+
1565+
// COM: else graph of IF-3
1566+
^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
15291567
tosa.yield %arg0 : tensor<f32>
1530-
}
1531-
tosa.yield %2 : tensor<f32>
1532-
} else {
1533-
tosa.yield %arg0 : tensor<f32>
1534-
}
1535-
tosa.yield %1 : tensor<f32>
1536-
} else {
1537-
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
1568+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1569+
1570+
tosa.yield %2 : tensor<f32>
1571+
}, {
1572+
1573+
// COM: else graph of IF-2
1574+
^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
1575+
tosa.yield %arg0 : tensor<f32>
1576+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1577+
1578+
tosa.yield %1 : tensor<f32>
1579+
}, {
1580+
1581+
// COM: else graph of IF-1
1582+
^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
1583+
%res = tosa.sub %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
15381584
tosa.yield %res : tensor<f32>
1539-
}
1585+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
1586+
15401587
return %0 : tensor<f32>
15411588
}
15421589

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -781,13 +781,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
781781
// -----
782782
// CHECK-LABEL: cond_if
783783
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
784-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
784+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
785+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
785786
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
786787
tosa.yield %1 : tensor<f32>
787-
} else {
788+
}, {
789+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
788790
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
789791
tosa.yield %1 : tensor<f32>
790-
}
792+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
791793
return %0 : tensor<f32>
792794
}
793795

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,12 +1121,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
11211121
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
11221122

11231123
// CHECK: tosa.cond_if
1124-
// CHECK: -> (tensor<f32>)
1125-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
1126-
tosa.yield %a : tensor<f32>
1127-
} else {
1128-
tosa.yield %b : tensor<f32>
1129-
}
1124+
// CHECK: -> tensor<f32>
1125+
%0 = "tosa.cond_if"(%arg2, %a, %b) ({
1126+
^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
1127+
tosa.yield %a1 : tensor<f32>
1128+
}, {
1129+
^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
1130+
tosa.yield %b1 : tensor<f32>
1131+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
11301132
return
11311133
}
11321134

@@ -1135,12 +1137,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
11351137
// CHECK-LABEL: @if_test_dynamic
11361138
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
11371139
// CHECK: tosa.cond_if
1138-
// CHECK: -> (tensor<?xf32>)
1139-
%0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
1140-
tosa.yield %arg0 : tensor<2xf32>
1141-
} else {
1142-
tosa.yield %arg1 : tensor<3xf32>
1143-
}
1140+
// CHECK: -> tensor<?xf32>
1141+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
1142+
^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
1143+
tosa.yield %a : tensor<2xf32>
1144+
}, {
1145+
^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
1146+
tosa.yield %b : tensor<3xf32>
1147+
}) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
11441148
return
11451149
}
11461150

@@ -1149,12 +1153,14 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
11491153
// CHECK-LABEL: @if_test_unranked
11501154
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
11511155
// CHECK: tosa.cond_if
1152-
// CHECK: -> (tensor<*xf32>)
1153-
%0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
1154-
tosa.yield %arg0 : tensor<f32>
1155-
} else {
1156-
tosa.yield %arg1 : tensor<3xf32>
1157-
}
1156+
// CHECK: -> tensor<*xf32>
1157+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
1158+
^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
1159+
tosa.yield %a : tensor<f32>
1160+
}, {
1161+
^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
1162+
tosa.yield %b : tensor<3xf32>
1163+
}) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
11581164
return
11591165
}
11601166

@@ -1163,14 +1169,16 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
11631169
// CHECK-LABEL: @if_test_propagate
11641170
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
11651171
// CHECK: tosa.cond_if
1166-
// CHECK: -> (tensor<f32>)
1167-
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
1168-
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
1172+
// CHECK: -> tensor<f32>
1173+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
1174+
^bb0(%a: tensor<f32>, %b: tensor<f32>):
1175+
%1 = tosa.add %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
11691176
tosa.yield %1 : tensor<f32>
1170-
} else {
1171-
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
1177+
}, {
1178+
^bb0(%a: tensor<f32>, %b: tensor<f32>):
1179+
%1 = tosa.sub %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
11721180
tosa.yield %1 : tensor<f32>
1173-
}
1181+
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
11741182
return
11751183
}
11761184

0 commit comments

Comments
 (0)