Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
// and optional initial value. The builder will extract var_shape and element type
// attributes from variable type.
def Tosa_VariableOpBuilder : OpBuilder<
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
(ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
[{
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
}]>;


Expand Down
16 changes: 12 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
include "mlir/IR/OpBase.td"

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
Expand Down Expand Up @@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
//===----------------------------------------------------------------------===//
// Operator: variable
//===----------------------------------------------------------------------===//
def Tosa_VariableOp : Tosa_Op<"variable", []> {
def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
let summary = "Defines a variable";

let description = [{
Expand All @@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
}];

let arguments = (ins
SymbolNameAttr:$name,
// Note: "sym_name" is used as opposed to "name" in the specification,
// since a Symbol must be named "sym_name" for it to be recognised by
// the containing SymbolTable.
SymbolNameAttr:$sym_name,
IndexElementsAttr:$var_shape,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
Expand All @@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
let hasCustomAssemblyFormat = 1;

let assemblyFormat = [{
$name
$sym_name
attr-dict
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
}];

let builders = [Tosa_VariableOpBuilder];

let hasVerifier = 1;
let extraClassDeclaration = [{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be getSymbolName() to avoid ambiguity ?

Copy link
Contributor Author

@lhutton1 lhutton1 Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was to avoid breaking dependent code that might have previously relied on getName() (before the renaming). In theory, we could avoid providing this, breaking the API, and make it clear in the PR that this is a breaking change

::llvm::StringRef getName() {
return getSymName();
}
}];
}

//===----------------------------------------------------------------------===//
Expand Down
67 changes: 15 additions & 52 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}

// Returns the first declaration point prior to this operation or failure if
// not found.
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
StringRef symName) {
ModuleOp module = op->getParentOfType<ModuleOp>();
tosa::VariableOp varOp = nullptr;

// TODO: Adopt SymbolTable trait to Varible ops.
// Currently, the variable's definition point is searched via walk(),
// starting from the top-level ModuleOp and stopping at the point of use. Once
// TOSA control flow and variable extensions reach the complete state, may
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
// the search to a TOSA specific graph traversal over the IR structure.
module.walk([&](Operation *tempOp) {
// Reach this op itself.
if (tempOp == op) {
return WalkResult::interrupt();
}

if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
if (symName == tosaOp.getName()) {
varOp = tosaOp;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (varOp)
return varOp;

return failure();
}

template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
StringRef symName = op.getName();
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
if (failed(varOp))
Operation *symTableOp =
op->template getParentWithTrait<OpTrait::SymbolTable>();
if (!symTableOp)
// If the operation is not the scope of a symbol table, we cannot
// verify it against it's declaration.
return success();

SymbolTable symTable(symTableOp);
const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

// Verify prior declaration
if (!varOp)
return op->emitOpError("'")
<< symName << "' has not been declared by 'tosa.variable'";
<< op.getName() << "' has not been declared by 'tosa.variable'";

// Verify type and shape
auto variableType = getVariableType(varOp.value());
auto variableType = getVariableType(varOp);
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();

return success();
}

Expand Down Expand Up @@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));

result.addAttribute("name", nameAttr);
result.addAttribute("sym_name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
Expand Down Expand Up @@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}

LogicalResult tosa::VariableOp::verify() {
StringRef symName = getName();
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
if (succeeded(varOp))
return emitOpError("illegal to have multiple declaration of '")
<< symName << "'";

return success();
}

LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())
Expand Down
55 changes: 26 additions & 29 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -573,64 +573,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten

// -----

func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
module {
tosa.variable @stored_var : tensor<*xi8>
// expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
return
}

// -----

func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
module {
// expected-error@+1 {{elements literal type must have static shape}}
tosa.variable @stored_var = dense<0> : tensor<*xi8>
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
return
}

// -----

func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}

// -----

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
}

// -----

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
}

// -----

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
}

// -----

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
module {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
}

// -----
Expand Down
22 changes: 14 additions & 8 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
}

// -----
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
module {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
return

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
return
}
}

// -----
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
module {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
return

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
return
}
}

// -----
Expand Down
30 changes: 18 additions & 12 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1097,14 +1097,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %

// -----

func.func @test_variable_read_write_tensor_size_invalid() -> () {
module {
// expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
return

func.func @test_variable_read_write_tensor_size_invalid() -> () {
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
return
}
}

// -----
Expand Down Expand Up @@ -1165,14 +1168,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:

// -----

func.func @test_variable_read_write_rank_invalid() -> () {
module {
// expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
return

func.func @test_variable_read_write_rank_invalid() -> () {
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
return
}
}

// -----
Expand Down
Loading