diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index cc78aaed911e6..52bb0eb992b69 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", ); let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ ); let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } include "mlir/Dialect/Tosa/IR/TosaUtilOps.td" diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index d535009f34533..b2e471f2bba93 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) { return success(); } +// Verify whether same type and shape of the given two types. +static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, + StringRef name1, Type type2, + StringRef name2) { + auto shapeType1 = dyn_cast(type1); + auto shapeType2 = dyn_cast(type2); + if (!shapeType1 || !shapeType2) + return failure(); + + auto elemType1 = shapeType1.getElementType(); + auto elemType2 = shapeType2.getElementType(); + if (elemType1 != elemType2) + return op->emitOpError() + << "require same element type for " << name1 << " (" << elemType1 + << ") and " << name2 << " (" << elemType2 << ")"; + + if (failed(verifyCompatibleShape(type1, type2))) + return op->emitOpError() + << "require same shapes for " << name1 << " (" << type1 << ") and " + << name2 << " (" << type2 << ")"; + + return success(); +} + +// Verify whether same length, type, and shape of the given two tensor lists. +static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1, + StringRef name1, + ValueRange list2, + StringRef name2) { + if (list1.size() != list2.size()) + return op->emitOpError() + << "require same number of values in " << name1 << " (" + << list1.size() << ") and " << name2 << " (" << list2.size() << ")"; + + for (auto [type1, type2] : + llvm::zip_equal(list1.getTypes(), list2.getTypes())) { + if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed()) + return failure(); + } + + return success(); +} + +static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) { + ShapeAdaptor shapeAdaptor(type); + if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape()) + return success(); + + return shapeAdaptor.getNumElements() == 1 ? success() : failure(); +} + // verify that inType and outType have same element types template static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { @@ -3473,6 +3524,84 @@ void IfOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); } +LogicalResult IfOp::verify() { + if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(), + "'then_graph' arguments", getInputList(), + "'input_list'") + .failed()) + return failure(); + + if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(), + "'else_graph' arguments", getInputList(), + "'input_list'") + .failed()) + return failure(); + + auto thenYield = cast(getThenGraph().front().getTerminator()); + if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(), + "'then_graph' results", getOutputList(), + "'output_list'") + .failed()) + return failure(); + + auto elseYield = cast(getElseGraph().front().getTerminator()); + if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(), + "'else_graph' results", getOutputList(), + "'output_list'") + .failed()) + return failure(); + + auto condType = getCondition().getType(); + if (errorIfShapeNotSizeOne(*this, condType).failed()) + return emitOpError() << "'condition' must be a size 1 tensor, got " + << condType; + + return success(); +} + +LogicalResult WhileOp::verify() { + if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'", + getOutputList(), "'output_list'") + .failed()) + return failure(); + + if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(), + "'cond_graph' arguments", getInputList(), + "'input_list'") + .failed()) + return failure(); + + if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(), + "'body_graph' arguments", getInputList(), + "'input_list'") + .failed()) + return failure(); + + auto bodyYield = cast(getBodyGraph().front().getTerminator()); + if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(), + "'body_graph' results", getInputList(), + "'input_list'") + .failed()) + return failure(); + + // Condition block output must be a single element tensor with a single bool + // value. + auto condYield = cast(getCondGraph().front().getTerminator()); + if (condYield.getInputs().size() != 1) + return emitOpError() << "require 'cond_graph' only have one result"; + + auto condOutType = condYield.getInputs()[0].getType(); + if (errorIfShapeNotSizeOne(*this, condOutType).failed()) + return emitOpError() << "'cond_graph' result must be a size 1 tensor, got " + << condOutType; + + if (!getElementTypeOrSelf(condOutType).isInteger(1)) + return emitOpError() << "'cond_graph' result must be a boolean tensor, got " + << condOutType; + + return success(); +} + LogicalResult ReverseOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(), /* outType = */ getOutput().getType()) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index df950939645b0..e8b52d48347ab 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -449,6 +449,35 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return true; } + // Recursively perform a bottom-up search to determine the maximum nesting + // depth, starting from a specific operation and continuing up to the function + // or module scope. Tosa nesting_depth starts at 0 and increments by one each + // time a new nested `region` is encountered. + static void getMaxNestedDepth(Operation *op, int32_t &depth) { + if (isa(op) || isa(op)) + return; + + op = op->getParentOp(); + if (!op) + return; + + depth++; + getMaxNestedDepth(op, depth); + return; + } + + bool levelCheckMaxNesting(Operation *op) { + int32_t maxNestedDepth = 0; + getMaxNestedDepth(op, maxNestedDepth); + + if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + op->emitOpError() << "failed level check: " << maxNestedDepth + << " >= MAX_NESTING"; + return false; + } + return true; + } + bool levelCheckListSize(Operation *op) { if (auto concat = dyn_cast(op)) { return levelCheckListSize(op, concat.getInput1().size(), "input1"); @@ -750,6 +779,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) { return failure(); } + if (isa(op) || isa(op)) { + if (!levelCheckMaxNesting(op)) { + return failure(); + } + } + return success(); } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 5307645324b81..d24c1fa57883d 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor) { %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}} %1 = "tosa.cond_if"(%arg0, // condition - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0) ({ - ^bb0(%arg3: tensor<1xi32>): - "tosa.yield"(%arg3) : (tensor<1xi32>) -> () + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0) ({ + ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>, + %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>, + %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>, + %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>, + %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>, + %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>, + %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32> + ): + "tosa.yield"(%64) : (tensor<1xi32>) -> () }, { - ^bb0(%arg3: tensor<1xi32>): - "tosa.yield"(%arg3) : (tensor<1xi32>) -> () + ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>, + %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>, + %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>, + %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>, + %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>, + %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>, + %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32> + ): + "tosa.yield"(%01) : (tensor<1xi32>) -> () }) : ( - tensor, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32> - ) -> tensor<1xi32> - + tensor, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32> + ) -> tensor<1xi32> return } @@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor) { // CHECK-LABEL: test_if_tensor_list_size_outputs func.func @test_if_tensor_list_size_outputs(%arg0 : tensor) { - %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}} - %r:65 = "tosa.cond_if"(%arg0) ({ - ^bb0(%arg3: tensor<1xi32>): - "tosa.yield"(%arg3) : (tensor<1xi32>) -> () + %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({ + ^bb0(%0: tensor<1xi32>): + "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0 + ) : ( + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32> + ) -> () }, { - ^bb0(%arg3: tensor<1xi32>): - "tosa.yield"(%arg3) : (tensor<1xi32>) -> () - }) : (tensor) -> ( - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, - tensor<1xi32> - ) - + ^bb0(%0: tensor<1xi32>): + "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0 + ) : ( + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32> + ) -> () + }) : (tensor, tensor<1xi32>) -> ( + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32> + ) return } @@ -1391,25 +1427,57 @@ func.func @test_if_tensor_list_size_outputs(%arg0 : tensor) { func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) { %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}} - %1:2 = "tosa.while_loop"(%0, %arg0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0, %0, - %0, %0, %0, %0, %0, %0, %0 + %1:65 = "tosa.while_loop"(%0, %arg0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0 ) ({ - ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>): + ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>, + %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>, + %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>, + %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>, + %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>, + %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>, + %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>, + %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32> + ): %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1> "tosa.yield"(%3) : (tensor<1xi1>) -> () }, { - ^bb0(%arg3: tensor, %arg4: tensor<1x1x1x1x1x1x1xf32>): - %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor - %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor - "tosa.yield"(%3, %arg4) : (tensor, tensor<1x1x1x1x1x1x1xf32>) -> () + ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>, + %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>, + %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>, + %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>, + %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>, + %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>, + %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>, + %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32> + ): + %2 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tosa.add"(%arg3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + "tosa.yield"(%3, %arg4, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, + %0, %0, %0 + ) : ( + tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, + tensor<1xi32>, tensor<1xi32>, tensor<1xi32> + ) -> () }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, @@ -1419,28 +1487,7 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32> - ) -> (tensor, tensor<1x1x1x1x1x1x1xf32>) - - return -} - -// ----- - -// CHECK-LABEL: test_while_tensor_list_size_outputs -func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) { - %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}} - %1:65 = "tosa.while_loop"(%0, %arg0) ({ - ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>): - %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1> - "tosa.yield"(%3) : (tensor<1xi1>) -> () - }, { - ^bb0(%arg3: tensor, %arg4: tensor<1x1x1x1x1x1x1xf32>): - %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor - %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor - "tosa.yield"(%3, %arg4) : (tensor, tensor<1x1x1x1x1x1x1xf32>) -> () - }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor, tensor<1x1x1x1x1x1x1xf32>, + ) -> (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, @@ -1453,3 +1500,101 @@ func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, return } + +// ----- + +func.func @test_cond_if_max_nested_depth(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.cond_if %arg3 -> (tensor) { + %2 = tosa.cond_if %arg2 -> (tensor) { + %3 = tosa.cond_if %arg3 -> (tensor) { + %4 = tosa.cond_if %arg2 -> (tensor) { + // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} + %5 = tosa.cond_if %arg3 -> (tensor) { + %res = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %res : tensor + } else { + tosa.yield %arg0 : tensor + } + tosa.yield %5 : tensor + } else { + tosa.yield %arg0 : tensor + } + tosa.yield %4 : tensor + } else { + tosa.yield %arg0 : tensor + } + tosa.yield %3 : tensor + } else { + tosa.yield %arg0 : tensor + } + tosa.yield %2 : tensor + } else { + tosa.yield %arg0 : tensor + } + tosa.yield %1 : tensor + } else { + %res = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %res : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_while_loop_max_nested_depth(%arg0: tensor) { + %init_0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %cst_1 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + + %1:2 = tosa.while_loop (%arg2 = %init_0, %arg3 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg3, %arg2 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg2b: tensor): + %1:2 = tosa.while_loop (%arg4 = %init_0, %arg5 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg5, %arg4 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg4: tensor, %arg4b: tensor): + %1:2 = tosa.while_loop (%arg6 = %init_0, %arg7 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg7, %arg6 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg6: tensor, %arg6b: tensor): + %1:2 = tosa.while_loop (%arg8 = %init_0, %arg9 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg9, %arg8 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg8: tensor, %arg8b: tensor): + %1:2 = tosa.while_loop (%arg10 = %init_0, %arg11 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg11, %arg10 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg10: tensor, %arg10b: tensor): + // expected-error@+1 {{'tosa.while_loop' op failed level check: 6 >= MAX_NESTING}} + %1:2 = tosa.while_loop (%arg12 = %init_0, %arg13 = %arg0) : (tensor, tensor) -> (tensor, tensor) { + %2 = tosa.greater_equal %arg13, %arg12 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg12: tensor, %arg12b: tensor): + %3 = tosa.add %arg12, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + %3 = tosa.add %arg10, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + %3 = tosa.add %arg8, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + %3 = tosa.add %arg6, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + %3 = tosa.add %arg4, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + %3 = tosa.add %arg2, %cst_1 : (tensor, tensor) -> tensor + tosa.yield %arg2, %3: tensor, tensor + } + return +} + diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index c669c36e5452f..7ae8ec470c3dd 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -436,4 +436,352 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x // expected-error@+1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}} %1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8> return %1 : tensor<10xi8> -} \ No newline at end of file +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (1) and 'input_list' (2)}} + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor + +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_then_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (2) and 'input_list' (1)}} + %0 = "tosa.cond_if"(%arg2, %arg0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor) -> tensor + return %0 : tensor + +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}} + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor + +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}} + %0 = "tosa.cond_if"(%arg2, %arg0) ({ + ^bb0(%arg3: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg4: tensor, %arg3: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor) -> tensor + return %0 : tensor + +} + +// ----- + +func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1, %2 : tensor, tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} + %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1, %2 : tensor, tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} + %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + %2 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1, %2 : tensor, tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor, %arg1: tensor, %arg2: tensor<2xi1>) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op 'condition' must be a size 1 tensor, got 'tensor<2xi1>'}} + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor<2xi1>, tensor, tensor) -> tensor + return %0 : tensor + +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}} + %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor, tensor<10xi32>) -> (tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %3, %arg4 : tensor, tensor<10xi32> + } + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_body_block_in_2(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (2) and 'input_list' (3)}} + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0) + : (tensor, tensor<10xi32>, tensor<10xi32>) -> (tensor, tensor<10xi32>, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %3, %arg3 : tensor, tensor + } + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_output_list(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (3) and 'output_list' (2)}} + %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0) + : (tensor, tensor<10xi32>, tensor<10xi32>) -> (tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %3, %arg3 : tensor, tensor + } + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_output_list_2(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (2) and 'output_list' (3)}} + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) + : (tensor, tensor<10xi32>) -> (tensor, tensor<10xi32>, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %3, %arg3 : tensor, tensor + } + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_cond_block(%arg0: tensor<2xf32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (3) and 'input_list' (2)}} + %1:2 = "tosa.while_loop"(%0, %arg0) ({ + ^bb0(%arg3: tensor, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>): + %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor, tensor) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor<2xf32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8> + %4 = "tosa.mul"(%arg3, %2, %3) : (tensor, tensor, tensor<1xi8>) -> tensor + "tosa.yield"(%4, %arg4) : (tensor, tensor<2xf32>) -> () + }) : (tensor, tensor<2xf32>) -> (tensor, tensor<2xf32>) + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_cond_block_2(%arg0: tensor<2xf32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (1) and 'input_list' (3)}} + %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({ + ^bb0(%arg3: tensor): + %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor, tensor) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor<2xf32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8> + %4 = "tosa.mul"(%arg3, %2, %3) : (tensor, tensor, tensor<1xi8>) -> tensor + "tosa.yield"(%4, %arg4) : (tensor, tensor<2xf32>) -> () + }) : (tensor, tensor<2xf32>, tensor) -> (tensor, tensor<2xf32>, tensor) + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_body_block_out(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (3) and 'input_list' (2)}} + %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor, tensor<10xi32>) -> (tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %2, %3, %arg4 : tensor, tensor, tensor<10xi32> + } + return +} + +// ----- + +func.func @test_while_loop_input_list_mismatch_body_block_out_2(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (1) and 'input_list' (2)}} + %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor, tensor<10xi32>) -> (tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg2, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %3 : tensor + } + return +} + +// ----- + +func.func @test_while_loop_type_mismatch(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same element type for 'body_graph' arguments ('f32') and 'input_list' ('i32')}} + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg3, %arg1 : (tensor, tensor) -> tensor + %3 = tosa.logical_not %2 : (tensor) -> tensor + tosa.yield %3 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %6 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %6, %2, %arg4 : tensor, tensor, tensor<10xi32> + } + return +} + +// ----- + +func.func @test_while_loop_type_mismatch_2(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op require same shapes for 'body_graph' arguments ('tensor<10xi32>') and 'input_list' ('tensor')}} + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg3, %arg1 : (tensor, tensor) -> tensor + %3 = tosa.logical_not %2 : (tensor) -> tensor + tosa.yield %3 : tensor + } do { + ^bb0(%arg2: tensor<10xi32>, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %6 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor) -> tensor + tosa.yield %6, %2, %arg4 : tensor, tensor, tensor<10xi32> + } + return +} + +// ----- + +func.func @test_while_loop_cond_output_not_size_one(%arg0: tensor<10xi32>, %arg1: tensor<2xi32>) { + %0 = "tosa.const"() {values = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a size 1 tensor, got 'tensor<2xi1>'}} + %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) { + %2 = tosa.greater_equal %arg3, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + tosa.yield %2 : tensor<2xi1> + } do { + ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<2xi32>, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.const"() {values = dense<[3, 5]> : tensor<2xi32>} : () -> tensor<2xi32> + %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor) -> tensor<10xi32> + tosa.yield %4, %3, %arg4 : tensor<10xi32>, tensor<2xi32>, tensor<10xi32> + } + return +} + +// ----- + +func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<9> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a boolean tensor, got 'tensor'}} + %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor, tensor<10xi32>) -> (tensor<10xi32>, tensor, tensor<10xi32>) { + %2 = tosa.add %arg3, %arg1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } do { + ^bb0(%arg2: tensor<10xi32>, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor) -> tensor<10xi32> + tosa.yield %4, %2, %arg4 : tensor<10xi32>, tensor, tensor<10xi32> + } + return +}