Skip to content

Conversation

@Yuvaraj-Venkatesh
Copy link
Contributor

…or and optional block arguments

This change extends the TOSA cond_if operation's print and parse logic to handle the following:

  • The condition operand may now have any rank, as long as the total number of elements sums to 1.

    %1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32>

  • The then and else regions can now include optional block arguments. The updated IR syntax reflects this:

    %1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

  • Removed parentheses around single result types in the printed representation, aligning with the AsmPrinter conventions.

@github-actions
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Yuvaraj Venkatesh (Yuvaraj-Venkatesh)

Changes

…or and optional block arguments

This change extends the TOSA cond_if operation's print and parse logic to handle the following:

  • The condition operand may now have any rank, as long as the total number of elements sums to 1.

    %1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32>

  • The then and else regions can now include optional block arguments. The updated IR syntax reflects this:

    %1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor<i1> (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

  • Removed parentheses around single result types in the printed representation, aligning with the AsmPrinter conventions.


Full diff: https://github.com/llvm/llvm-project/pull/149791.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+91-36)
  • (modified) mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+30-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+32-4)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f0ff430bae882..c708579f7dbb8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3646,6 +3646,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
+static void printInitializationList(OpAsmPrinter &parser,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  parser << prefix << '(';
+  llvm::interleaveComma(
+      llvm::zip(blocksArgs, initializers), parser,
+      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+  parser << ")";
+}
+
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions for 'then'.
@@ -3653,18 +3669,70 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
   Region *thenRegion = result.addRegion();
   Region *elseRegion = result.addRegion();
 
-  auto &builder = parser.getBuilder();
   OpAsmParser::UnresolvedOperand cond;
-  // Create a i1 tensor type for the boolean condition.
-  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
+
+  if (parser.parseOperand(cond))
     return failure();
-  // Parse optional results type list.
-  if (parser.parseOptionalArrowTypeList(result.types))
+
+  SmallVector<OpAsmParser::Argument, 4> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+  // Parse the optional block arguments
+  OptionalParseResult listResult =
+      parser.parseOptionalAssignmentList(regionArgs, operands);
+  if (listResult.has_value() && failed(listResult.value()))
     return failure();
+
+  // Parse a colon.
+  if (failed(parser.parseColon()))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected type for condition operand");
+
+  // Parse the type of the operand
+  Type condType;
+  if (failed(parser.parseType(condType)))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected type for condition operand");
+
+  // Resolve operand with provided type
+  if (failed(parser.resolveOperand(cond, condType, result.operands)))
+    return failure();
+
+  // Parse optional block arg types
+  if (listResult.has_value()) {
+    FunctionType functionType;
+
+    OptionalParseResult typeResult = parser.parseOptionalType(functionType);
+    if (typeResult.has_value() && failed(typeResult.value()))
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected type for block args");
+
+    result.addTypes(functionType.getResults());
+
+    if (functionType.getNumInputs() != operands.size()) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected as many input types as operands "
+             << "(expected " << operands.size() << " got "
+             << functionType.getNumInputs() << ")";
+    }
+
+    // Resolve input operands.
+    if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+                                      parser.getCurrentLocation(),
+                                      result.operands)))
+      return failure();
+
+    // Propagate the types into the region arguments.
+    for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
+      regionArgs[i].type = functionType.getInput(i);
+  } else {
+    // Parse optional results type list.
+    if (parser.parseOptionalArrowTypeList(result.types))
+      return failure();
+  }
+
   // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
+  if (parser.parseRegion(*thenRegion, regionArgs))
     return failure();
 
   // If we find an 'else' keyword then parse the 'else' region.
@@ -3680,26 +3748,29 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void IfOp::print(OpAsmPrinter &p) {
-  bool printBlockTerminators = false;
-
   p << " " << getCondition();
-  if (!getResults().empty()) {
-    p << " -> (" << getResultTypes() << ")";
-    // Print yield explicitly if the op defines values.
-    printBlockTerminators = true;
+
+  printInitializationList(p, getThenGraph().front().getArguments(),
+                          getInputList(), " ");
+  p << " : ";
+  p << getCondition().getType();
+
+  if (!getInputList().empty()) {
+    p << " (";
+    llvm::interleaveComma(getInputList().getTypes(), p);
+    p << ")";
   }
-  p << ' ';
+  p.printArrowTypeList(getResultTypes());
+  p << " ";
+
   p.printRegion(getThenGraph(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
+                /*printEntryBlockArgs=*/false);
 
   // Print the 'else' regions if it exists and has a block.
   auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
-    p.printRegion(elseRegion,
-                  /*printEntryBlockArgs=*/false,
-                  /*printBlockTerminators=*/printBlockTerminators);
+    p.printRegion(elseRegion);
   }
 
   p.printOptionalAttrDict((*this)->getAttrs());
@@ -3908,22 +3979,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
 }
 
-static void printInitializationList(OpAsmPrinter &parser,
-                                    Block::BlockArgListType blocksArgs,
-                                    ValueRange initializers,
-                                    StringRef prefix = "") {
-  assert(blocksArgs.size() == initializers.size() &&
-         "expected same length of arguments and initializers");
-  if (initializers.empty())
-    return;
-
-  parser << prefix << '(';
-  llvm::interleaveComma(
-      llvm::zip(blocksArgs, initializers), parser,
-      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
-  parser << ")";
-}
-
 void WhileOp::print(OpAsmPrinter &parser) {
   printInitializationList(parser, getCondGraph().front().getArguments(),
                           getInputList(), " ");
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..b6f2383ac81fc 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
 func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
   // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
   // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
 
   // CHECK:   scf.yield [[ARG0]]
     tosa.yield %arg0 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2883518..6398161126e80 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // CHECK: tosa.cond_if profiles: [ ]
   // CHECK: tosa.cond_if extensions: [ [controlflow] ]
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..4556801144e68 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:
 
 func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
     tosa.yield %arg0 : tensor<f32>
   } else {
     tosa.yield %arg1 : tensor<f32>
@@ -280,3 +280,32 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+2 {{expected ':'}}
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: cond_if_cond_size
+func.func @test_cond_if_cond_size(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op 'condition' must be a size 1 tensor, got 'tensor<1x2xi1>'}}
+  %0 = tosa.cond_if %arg2 : tensor<1x2xi1> -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33639d86..3154f541e0519 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
 // -----
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..cbe0056bafe22 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
 // -----
 
 func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
-      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
-        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
-          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+    %1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> {
+      %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+        %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+          %4 = tosa.cond_if %arg2 : tensor<i1>  -> tensor<f32> {
             // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
-            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+            %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
               %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
               tosa.yield %res : tensor<f32>
             } else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e86d56..30361a882afe5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8fb66d9..e957bdd15e1ec 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
 // CHECK-LABEL: test_regions
 // CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
 func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
-  // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8> 
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
   ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
     // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d0f40279421f4..9ff0ce3b449d1 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1153,8 +1153,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
   %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
 
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     tosa.yield %a : tensor<f32>
   } else {
     tosa.yield %b : tensor<f32>
@@ -1167,8 +1167,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
 // CHECK-LABEL: @if_test_dynamic
 func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<?xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+  // CHECK: -> tensor<?xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
     tosa.yield %arg0 : tensor<2xf32>
   } else {
     tosa.yield %arg1 : tensor<3xf32>
@@ -1181,8 +1181,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_unranked
 func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<*xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+  // CHECK: -> tensor<*xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
     tosa.yield %arg0 : tensor<f32>
   } else {
     tosa.yield %arg1 : tensor<3xf32>
@@ -1195,8 +1195,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_propagate
 func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b3052369b055e..572dfd82676f1 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,37 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
 
 // -----
 
+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+    tosa.yield %arg3 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
 func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +545,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
@@ -531,7 +559,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
 
 func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
@@ -546,7 +574,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>

@lhutton1
Copy link
Contributor

lhutton1 commented Jul 21, 2025

Note: this change makes #144859 redundant.

@Yuvaraj-Venkatesh Yuvaraj-Venkatesh force-pushed the cond-if-print-parse-improvements branch from 5d3232b to 7b2d9bc Compare July 22, 2025 15:33
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @Yuvaraj-Venkatesh! Apologies, I took another look and picked up some things I missed the first time round. Could we also add some positive test cases to check the improved functionality?:

  • One to check that the shape of the condition is retained after parse/print round-trip
  • One to check that the block arguments are retained after parse/print round-trip

Since the cond_if validation check is now merged, it would be good to update this comment: https://github.com/llvm/llvm-project/pull/143772/files#diff-9094a3c264abb63100ab95e0e9e6d5c303b346ec9926b814e58b50ffaa1253e0R1235 as part of this change as well.

Since this change potentially impacts downstream users, tagging a few more people that might be interested: @sjarus @GeorgeARM @amirBish @udaya-ranga

@Yuvaraj-Venkatesh Yuvaraj-Venkatesh force-pushed the cond-if-print-parse-improvements branch from 7b2d9bc to 69d4e76 Compare July 25, 2025 15:56
@Yuvaraj-Venkatesh
Copy link
Contributor Author

Thanks for the updates @Yuvaraj-Venkatesh! Apologies, I took another look and picked up some things I missed the first time round. Could we also add some positive test cases to check the improved functionality?:

  • One to check that the shape of the condition is retained after parse/print round-trip
  • One to check that the block arguments are retained after parse/print round-trip

Since the cond_if validation check is now merged, it would be good to update this comment: https://github.com/llvm/llvm-project/pull/143772/files#diff-9094a3c264abb63100ab95e0e9e6d5c303b346ec9926b814e58b50ffaa1253e0R1235 as part of this change as well.

Since this change potentially impacts downstream users, tagging a few more people that might be interested: @sjarus @GeorgeARM @amirBish @udaya-ranga

Thanks @lhutton1!
I’ve added a couple of positive test cases and updated the comments in the validation logic as well.

…or and optional block arguments

This change extends the TOSA `cond_if` operation's print and parse logic to handle the following:

- The condition operand may now have any rank, as long as the total number of elements sums to 1.

  %1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32>

- The `then` and `else` regions can now include optional block arguments. The updated IR syntax reflects this:

  %1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor<i1> (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

- Removed parentheses around single result types in the printed representation, aligning with the `AsmPrinter` conventions.
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @Yuvaraj-Venkatesh, LGTM!

@Yuvaraj-Venkatesh Yuvaraj-Venkatesh force-pushed the cond-if-print-parse-improvements branch from 69d4e76 to 8787810 Compare July 25, 2025 16:25
@lhutton1 lhutton1 merged commit 9c606ae into llvm:main Jul 28, 2025
9 checks passed
@github-actions
Copy link

@Yuvaraj-Venkatesh Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants