Skip to content

Commit fee71a3

Browse files
authored
[mlir][tosa] Apply 'Symbol' trait to tosa.variable (#153223)
Implement SymbolOpInterface on tosa.variable so that it's declaration is automatically inserted into its parents SymbolTable. Verifiers for tosa.variable_read/write can now look up the symbol and guarantee it exists, and duplicate names are caught at creation time. Previously this was completed by walking the graph which could be inefficient. Unfortunately, the Symbol trait expects to find a symbol name via a hard-coded attribute name "sym_name". Therefore, "name" is renamed to"sym_name" and a getName() wrapper is provided for backwards compatibility. This change also restricts tosa.variable declarations to ops that carry a SymbolTable (e.g. modules), rather than allowing them to be placed inside a func.func. Note: EXT-VARIABLE is an experimental extension in the TOSA specification, so is not subject to backwards compatibility guarantees.
1 parent 5e07093 commit fee71a3

File tree

8 files changed

+197
-185
lines changed

8 files changed

+197
-185
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
201201
// and optional initial value. The builder will extract var_shape and element type
202202
// attributes from variable type.
203203
def Tosa_VariableOpBuilder : OpBuilder<
204-
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
204+
(ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
205205
[{
206-
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
206+
buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
207207
}]>;
208208

209209

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
include "mlir/IR/OpBase.td"
1919

2020
include "mlir/Interfaces/SideEffectInterfaces.td"
21+
include "mlir/IR/SymbolInterfaces.td"
2122
include "mlir/Interfaces/LoopLikeInterface.td"
2223
include "mlir/Interfaces/VectorInterfaces.td"
2324
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
8283
//===----------------------------------------------------------------------===//
8384
// Operator: variable
8485
//===----------------------------------------------------------------------===//
85-
def Tosa_VariableOp : Tosa_Op<"variable", []> {
86+
def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
8687
let summary = "Defines a variable";
8788

8889
let description = [{
@@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
9192
}];
9293

9394
let arguments = (ins
94-
SymbolNameAttr:$name,
95+
// Note: "sym_name" is used as opposed to "name" in the specification,
96+
// since a Symbol must be named "sym_name" for it to be recognised by
97+
// the containing SymbolTable.
98+
SymbolNameAttr:$sym_name,
9599
IndexElementsAttr:$var_shape,
96100
TypeAttr:$type,
97101
OptionalAttr<AnyAttr>:$initial_value
@@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
105109
let hasCustomAssemblyFormat = 1;
106110

107111
let assemblyFormat = [{
108-
$name
112+
$sym_name
109113
attr-dict
110114
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
111115
}];
112116

113117
let builders = [Tosa_VariableOpBuilder];
114118

115-
let hasVerifier = 1;
119+
let extraClassDeclaration = [{
120+
::llvm::StringRef getName() {
121+
return getSymName();
122+
}
123+
}];
116124
}
117125

118126
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 15 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
905905
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
906906
}
907907

908-
// Returns the first declaration point prior to this operation or failure if
909-
// not found.
910-
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
911-
StringRef symName) {
912-
ModuleOp module = op->getParentOfType<ModuleOp>();
913-
tosa::VariableOp varOp = nullptr;
914-
915-
// TODO: Adopt SymbolTable trait to Varible ops.
916-
// Currently, the variable's definition point is searched via walk(),
917-
// starting from the top-level ModuleOp and stopping at the point of use. Once
918-
// TOSA control flow and variable extensions reach the complete state, may
919-
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
920-
// the search to a TOSA specific graph traversal over the IR structure.
921-
module.walk([&](Operation *tempOp) {
922-
// Reach this op itself.
923-
if (tempOp == op) {
924-
return WalkResult::interrupt();
925-
}
926-
927-
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
928-
if (symName == tosaOp.getName()) {
929-
varOp = tosaOp;
930-
return WalkResult::interrupt();
931-
}
932-
}
933-
934-
return WalkResult::advance();
935-
});
936-
937-
if (varOp)
938-
return varOp;
939-
940-
return failure();
941-
}
942-
943908
template <typename T>
944909
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
945-
StringRef symName = op.getName();
946-
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
947-
if (failed(varOp))
910+
Operation *symTableOp =
911+
op->template getParentWithTrait<OpTrait::SymbolTable>();
912+
if (!symTableOp)
913+
// If the operation is not the scope of a symbol table, we cannot
914+
// verify it against it's declaration.
915+
return success();
916+
917+
SymbolTable symTable(symTableOp);
918+
const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
919+
920+
// Verify prior declaration
921+
if (!varOp)
948922
return op->emitOpError("'")
949-
<< symName << "' has not been declared by 'tosa.variable'";
923+
<< op.getName() << "' has not been declared by 'tosa.variable'";
950924

951925
// Verify type and shape
952-
auto variableType = getVariableType(varOp.value());
926+
auto variableType = getVariableType(varOp);
953927
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
954928
"the input tensor")
955929
.failed())
956930
return failure();
957-
958931
return success();
959932
}
960933

@@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
14181391
ArrayRef<int64_t> shape = shapedType.getShape();
14191392
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
14201393

1421-
result.addAttribute("name", nameAttr);
1394+
result.addAttribute("sym_name", nameAttr);
14221395
result.addAttribute("var_shape", varShapeAttr);
14231396
result.addAttribute("type", elementTypeAttr);
14241397
result.addAttribute("initial_value", initialValue);
@@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
41604133
return success();
41614134
}
41624135

4163-
LogicalResult tosa::VariableOp::verify() {
4164-
StringRef symName = getName();
4165-
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
4166-
if (succeeded(varOp))
4167-
return emitOpError("illegal to have multiple declaration of '")
4168-
<< symName << "'";
4169-
4170-
return success();
4171-
}
4172-
41734136
LogicalResult tosa::VariableReadOp::verify() {
41744137
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
41754138
.failed())

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -573,64 +573,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
573573

574574
// -----
575575

576-
func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
576+
module {
577577
tosa.variable @stored_var : tensor<*xi8>
578578
// expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
579-
return
580579
}
581580

582581
// -----
583582

584-
func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
583+
module {
585584
// expected-error@+1 {{elements literal type must have static shape}}
586585
tosa.variable @stored_var = dense<0> : tensor<*xi8>
587586
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
588-
return
589-
}
590-
591-
// -----
592-
593-
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
594-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
595-
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
596-
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
597-
return
598587
}
599588

600589
// -----
601590

602-
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
591+
module {
603592
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
604-
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
605-
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
606-
return
593+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
594+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
595+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
596+
return
597+
}
607598
}
608599

609600
// -----
610601

611-
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
602+
module {
612603
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
613-
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
614-
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
615-
return
604+
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
605+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
606+
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
607+
return
608+
}
616609
}
617610

618611
// -----
619612

620-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
613+
module {
621614
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
622-
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
623-
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
624-
return
615+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
616+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
617+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
618+
return
619+
}
625620
}
626621

627622
// -----
628623

629-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
624+
module {
630625
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
631-
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
632-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
633-
return
626+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
627+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
628+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
629+
return
630+
}
634631
}
635632

636633
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
310310
}
311311

312312
// -----
313-
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
313+
module {
314314
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
315315
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
316-
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
317-
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
318-
return
316+
317+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
318+
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
319+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
320+
return
321+
}
319322
}
320323

321324
// -----
322-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
325+
module {
323326
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
324327
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
325-
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
326-
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
327-
return
328+
329+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
330+
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
331+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
332+
return
333+
}
328334
}
329335

330336
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,14 +1097,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
10971097

10981098
// -----
10991099

1100-
func.func @test_variable_read_write_tensor_size_invalid() -> () {
1100+
module {
11011101
// expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
11021102
tosa.variable @stored_var : tensor<536870912xf32>
1103-
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1104-
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1105-
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1106-
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
1107-
return
1103+
1104+
func.func @test_variable_read_write_tensor_size_invalid() -> () {
1105+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1106+
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1107+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1108+
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
1109+
return
1110+
}
11081111
}
11091112

11101113
// -----
@@ -1165,14 +1168,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
11651168

11661169
// -----
11671170

1168-
func.func @test_variable_read_write_rank_invalid() -> () {
1171+
module {
11691172
// expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
11701173
tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1171-
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1172-
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1173-
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1174-
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1175-
return
1174+
1175+
func.func @test_variable_read_write_rank_invalid() -> () {
1176+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1177+
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1178+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1179+
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1180+
return
1181+
}
11761182
}
11771183

11781184
// -----

0 commit comments

Comments
 (0)