diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ecd93ff4c6e7b..3cafb199d2db3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3647,6 +3647,22 @@ std::optional> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } +static void printInitializationList(OpAsmPrinter &parser, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + parser << prefix << '('; + llvm::interleaveComma( + llvm::zip(blocksArgs, initializers), parser, + [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); + parser << ")"; +} + // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. @@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); - auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + + if (parser.parseOperand(cond)) return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) + + SmallVector regionArgs; + SmallVector operands; + + // Parse the optional block arguments + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.has_value() && failed(listResult.value())) return failure(); + + // Parse a colon. + if (failed(parser.parseColon())) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Parse the type of the condition operand + Type condType; + if (failed(parser.parseType(condType))) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Resolve operand with provided type + if (failed(parser.resolveOperand(cond, condType, result.operands))) + return failure(); + + // Parse optional block arg types + if (listResult.has_value()) { + FunctionType functionType; + + if (failed(parser.parseType(functionType))) + return parser.emitError(parser.getCurrentLocation()) + << "expected list of types for block arguments " + << "followed by arrow type and list of return types"; + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + if (failed(parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + } else { + // Parse optional results type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } + // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; + + printInitializationList(p, getThenGraph().front().getArguments(), + getInputList(), " "); + p << " : "; + p << getCondition().getType(); + + if (!getInputList().empty()) { + p << " ("; + llvm::interleaveComma(getInputList().getTypes(), p); + p << ")"; } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printArrowTypeList(getResultTypes()); + p << " "; + + p.printRegion(getThenGraph()); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printRegion(elseRegion); } p.printOptionalAttrDict((*this)->getAttrs()); @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOptionalAttrDictWithKeyword(result.attributes)); } -static void printInitializationList(OpAsmPrinter &parser, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - parser << prefix << '('; - llvm::interleaveComma( - llvm::zip(blocksArgs, initializers), parser, - [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); - parser << ")"; -} - void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCondGraph().front().getArguments(), getInputList(), " "); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 32b5fb63a6ece..8ec77654fb896 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) { // }) // // Simplified: - // %0 = tosa.cond_if %arg2 { - // tosa.yield %arg0 + // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 // } else { - // tosa.yield %arg1 + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 // } - // - // Unfortunately, the simplified syntax does not encapsulate values - // used in then/else regions (see 'simplified' example above), so it - // must be rewritten to use the generic syntax in order to be conformant - // to the specification. + return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index fa7a91cda0a47..b6f2383ac81fc 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor) -> (tensor) { func.func @if_test(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> (tensor) { // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]] // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor) { - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { // CHECK: scf.yield [[ARG0]] tosa.yield %arg0 : tensor diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 0176fc2883518..6398161126e80 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: tosa.cond_if profiles: [ ] // CHECK: tosa.cond_if extensions: [ [controlflow] ] - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir new file mode 100644 index 0000000000000..06312c7ba2df6 --- /dev/null +++ b/mlir/test/Dialect/Tosa/controlflow.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -split-input-file %s | FileCheck %s + +// ----- + +func.func @condif_cond_type_check(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor -> tensor { + %0 = tosa.cond_if %arg2 : tensor -> tensor { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + // CHECK: } else { + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @condif_block_args_check(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor (tensor, tensor) -> tensor { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor, %[[ARG4]]: tensor): + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor (tensor, tensor) -> tensor { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + // CHECK: } else { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor, %[[ARG4]]: tensor): + } else { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index eb25011ff3a9d..fad1bec0e3ecc 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor, %arg1: func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> (tensor) { tosa.yield %arg0 : tensor } else { tosa.yield %arg1 : tensor diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 5630c33639d86..3154f541e0519 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 // ----- func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0dddf26fb1f85..cbe0056bafe22 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: // ----- func.func @test_cond_if_max_nested_depth(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { - %0 = tosa.cond_if %arg2 -> (tensor) { - %1 = tosa.cond_if %arg3 -> (tensor) { - %2 = tosa.cond_if %arg2 -> (tensor) { - %3 = tosa.cond_if %arg3 -> (tensor) { - %4 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { + %1 = tosa.cond_if %arg3 : tensor-> tensor { + %2 = tosa.cond_if %arg2 : tensor -> tensor { + %3 = tosa.cond_if %arg3 : tensor -> tensor { + %4 = tosa.cond_if %arg2 : tensor -> tensor { // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} - %5 = tosa.cond_if %arg3 -> (tensor) { + %5 = tosa.cond_if %arg3 : tensor -> tensor { %res = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %res : tensor } else { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index ef51197e86d56..30361a882afe5 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir index 38ac8d8fb66d9..e957bdd15e1ec 100644 --- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir +++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir @@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { // CHECK-LABEL: test_regions // CHECK: %arg0: tensor, %arg1: tensor func.func @test_regions(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: tosa.cond_if %arg2 -> (tensor) + // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor (tensor, tensor) -> tensor %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ ^bb0(%arg3: tensor, %arg4: tensor): // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 9d43f8998da93..ece4bf8017abd 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1166,8 +1166,8 @@ func.func @if_test_simple(%arg0 : tensor, %arg1 : tensor, %arg2 : tens %b = tosa.log %arg1 : (tensor) -> tensor // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { + // CHECK: -> tensor + %0 = tosa.cond_if %arg2 : tensor -> tensor { tosa.yield %a : tensor } else { tosa.yield %b : tensor @@ -1180,8 +1180,8 @@ func.func @if_test_simple(%arg0 : tensor, %arg1 : tensor, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { + // CHECK: -> tensor + %0 = tosa.cond_if %arg2 : tensor -> tensor { tosa.yield %arg0 : tensor<2xf32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1194,8 +1194,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor, %arg1 : tensor<3xf32>, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { + // CHECK: -> tensor<*xf32> + %0 = tosa.cond_if %arg2 : tensor -> tensor<*xf32> { tosa.yield %arg0 : tensor } else { tosa.yield %arg1 : tensor<3xf32> @@ -1208,8 +1208,8 @@ func.func @if_test_unranked(%arg0 : tensor, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { + // CHECK: -> tensor + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index b3052369b055e..2a937b0a88f28 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor, %ar // ----- +func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor (tensor, tensor) -> tensor { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + ^bb0(%arg3: tensor): + tosa.yield %arg3 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor (tensor) -> tensor { + ^bb0(%arg3: tensor): + tosa.yield %arg3 : tensor + } else { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor tosa.yield %1, %2 : tensor, tensor @@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor, %arg func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { + %0, %2 = tosa.cond_if %arg2 : tensor -> (tensor, tensor) { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { @@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor, %a func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = tosa.cond_if %arg2 : tensor -> tensor { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor } else { @@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor, %arg func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { + %0, %2 = tosa.cond_if %arg2 : tensor -> (tensor, tensor) { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor %2 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1, %2 : tensor, tensor @@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor, %arg1: tenso // ----- +// CHECK-LABEL: cond_if_cond_type +func.func @test_cond_if_cond_type(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+2 {{expected ':'}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}} + %0 = tosa.cond_if %arg2 -> (tensor) { + tosa.yield %arg0 : tensor + } else { + tosa.yield %arg1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor () -> tensor { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_incorrect_type_simple(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+2 {{expected non-function type}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor (%arg3) -> tensor { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor) { %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}