From eeb04346fac4c0c85bbeb53f3a338aac7f8cd3bc Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 24 Apr 2025 23:38:38 -0700 Subject: [PATCH] [mlir][tosa] Add error if checks Variable Operators For VARIABLE, VARIABLE_WRITE & VARIABLE_READ --- .../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 6 ++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 78 ++++++++++++++++++ .../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 --- mlir/test/Dialect/Tosa/invalid.mlir | 10 +-- mlir/test/Dialect/Tosa/variables.mlir | 4 +- mlir/test/Dialect/Tosa/verifier.mlir | 79 +++++++++++++++++++ 6 files changed, 170 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index 0ab0a62f1cf11..5f99162907949 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { attr-dict custom($type, $initial_value) }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> { let assemblyFormat = [{ $name attr-dict `,` $input1 `:` type($input1) }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> { let assemblyFormat = [{ $name attr-dict `:` type($output1) }]; + + let hasVerifier = 1; } #endif // TOSA_UTIL_OPS diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b2e471f2bba93..c669bc4a31d43 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -613,6 +613,58 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) { return shapeAdaptor.getNumElements() == 1 ? success() : failure(); } +// Returns the first declaration point prior to this operation or failure if +// not found. +static FailureOr findVariableDecl(Operation *op, + StringRef symName) { + ModuleOp module = op->getParentOfType(); + tosa::VariableOp varOp = nullptr; + + // TODO: Adopt SymbolTable trait to Varible ops. + // Currently, the variable's definition point is searched via walk(), + // starting from the top-level ModuleOp and stopping at the point of use. Once + // TOSA control flow and variable extensions reach the complete state, may + // leverage MLIR's Symbol Table functionality to look up symbol and enhance + // the search to a TOSA specific graph traversal over the IR structure. + module.walk([&](Operation *tempOp) { + // Reach this op itself. + if (tempOp == op) { + return WalkResult::interrupt(); + } + + if (auto tosaOp = dyn_cast(tempOp)) { + if (symName == tosaOp.getName()) { + varOp = tosaOp; + return WalkResult::interrupt(); + } + } + + return WalkResult::advance(); + }); + + if (varOp) + return varOp; + + return failure(); +} + +template +static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { + StringRef symName = op.getName(); + FailureOr varOp = findVariableDecl(op, symName); + if (failed(varOp)) + return op->emitOpError("'") + << symName << "' has not been declared by 'tosa.variable'"; + + // Verify type and shape + Type varType = cast(varOp.value()).getType(); + if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor") + .failed()) + return failure(); + + return success(); +} + // verify that inType and outType have same element types template static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { @@ -3660,6 +3712,32 @@ LogicalResult tosa::SelectOp::verify() { return success(); } +LogicalResult tosa::VariableOp::verify() { + StringRef symName = getName(); + FailureOr varOp = findVariableDecl(*this, symName); + if (succeeded(varOp)) + return emitOpError("illegal to have multiple declaration of '") + << symName << "'"; + + return success(); +} + +LogicalResult tosa::VariableReadOp::verify() { + if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'") + .failed()) + return failure(); + + return success(); +} + +LogicalResult tosa::VariableWriteOp::verify() { + if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'") + .failed()) + return failure(); + + return success(); +} + // parse and print of WhileOp refer to the implementation of SCF dialect. ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector regionArgs; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir index 37ed5cec00a0d..74706c426ea9c 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir @@ -1,16 +1,6 @@ // RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics -// ----- - -// check that -tosa-validate of stateful ops kick in -func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}} - tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8> - return -} - // ----- // check that -tosa-validate level checking kick in diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c4f95b47628d1..9ccb310c4491d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable' op name has already been declared}} + // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}} tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8> return } @@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () { func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}} + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}} %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16> return } @@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}} + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}} %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32> return } @@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () { func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}} + // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}} tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16> return } @@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}} + // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}} tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8> return } diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir index 6fa6b26155461..25f63331f39df 100644 --- a/mlir/test/Dialect/Tosa/variables.mlir +++ b/mlir/test/Dialect/Tosa/variables.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s -// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s // ----- diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 7ae8ec470c3dd..990e0d954f54e 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -785,3 +785,82 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te } return } + +// ----- + +func.func @test_variable_multiple_declaration() -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}} + tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32> + return +} + +// ----- + +func.func @test_variable_shape_mismatch() -> () { + // expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}} + tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32> + // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} + return +} + +// ----- + +func.func @test_variable_type_mismatch() -> () { + // expected-error@+1 {{expected integer elements, but parsed floating-point}} + tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32> + // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} + return +} + +// ----- + +func.func @test_variable_read_no_declaration() -> () { + // expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}} + %0 = tosa.variable_read @stored_var : tensor + return +} + +// ----- + +func.func @test_variable_read_type_mismatch() -> () { + tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32> + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + return +} + +// ----- + +func.func @test_variable_read_shape_mismatch() -> () { + tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32> + // expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32> + return +} + +// ----- + +func.func @test_variable_write_no_declaration(%arg0: tensor) -> () { + // expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}} + tosa.variable_write @stored_var, %arg0 : tensor + return +} + +// ----- + +func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () { + tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32> + // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32> + return +} + +// ----- + +func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () { + tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32> + // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32> + return +}