Skip to content

Commit 1509d7c

Browse files
committed
[mlir][tosa] Apply 'Symbol' trait to tosa.variable
Implement SymbolOpInterface on tosa.variable so that it's declaration is automatically inserted into its parents SymbolTable. Verifiers for tosa.variable_read/write can now look up the symbol and guarantee it exists, and duplicate names are caught at creation time. Previously this was completed by walking the graph which could be inefficient. Unfortunately, the Symbol trait expects to find a symbol name via a hard-coded attribute name "sym_name". Therefore, "name" is renamed to"sym_name" and a getName() wrapper is provided for backwards compatibility. This change also restricts tosa.variable declarations to ops that carry a SymbolTable (e.g. modules), rather than allowing them to be placed inside a func.func. Note: EXT-VARIABLE is an experimental extension in the TOSA specification, so is not subject to backwards compatibility guarantees. Change-Id: I00a3f8f3b3b4f68cb3c120fe2c928d7b74b214cb
1 parent e92b7e9 commit 1509d7c

File tree

8 files changed

+197
-185
lines changed

8 files changed

+197
-185
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
201201
// and optional initial value. The builder will extract var_shape and element type
202202
// attributes from variable type.
203203
def Tosa_VariableOpBuilder : OpBuilder<
204-
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
204+
(ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
205205
[{
206-
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
206+
buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
207207
}]>;
208208

209209

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
include "mlir/IR/OpBase.td"
1919

2020
include "mlir/Interfaces/SideEffectInterfaces.td"
21+
include "mlir/IR/SymbolInterfaces.td"
2122
include "mlir/Interfaces/LoopLikeInterface.td"
2223
include "mlir/Interfaces/VectorInterfaces.td"
2324
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
8283
//===----------------------------------------------------------------------===//
8384
// Operator: variable
8485
//===----------------------------------------------------------------------===//
85-
def Tosa_VariableOp : Tosa_Op<"variable", []> {
86+
def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
8687
let summary = "Defines a variable";
8788

8889
let description = [{
@@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
9192
}];
9293

9394
let arguments = (ins
94-
SymbolNameAttr:$name,
95+
// Note: "sym_name" is used as opposed to "name" in the specification,
96+
// since a Symbol must be named "sym_name" for it to be recognised by
97+
// the containing SymbolTable.
98+
SymbolNameAttr:$sym_name,
9599
IndexElementsAttr:$var_shape,
96100
TypeAttr:$type,
97101
OptionalAttr<AnyAttr>:$initial_value
@@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
105109
let hasCustomAssemblyFormat = 1;
106110

107111
let assemblyFormat = [{
108-
$name
112+
$sym_name
109113
attr-dict
110114
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
111115
}];
112116

113117
let builders = [Tosa_VariableOpBuilder];
114118

115-
let hasVerifier = 1;
119+
let extraClassDeclaration = [{
120+
::llvm::StringRef getName() {
121+
return getSymName();
122+
}
123+
}];
116124
}
117125

118126
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 15 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -667,56 +667,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
667667
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
668668
}
669669

670-
// Returns the first declaration point prior to this operation or failure if
671-
// not found.
672-
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
673-
StringRef symName) {
674-
ModuleOp module = op->getParentOfType<ModuleOp>();
675-
tosa::VariableOp varOp = nullptr;
676-
677-
// TODO: Adopt SymbolTable trait to Varible ops.
678-
// Currently, the variable's definition point is searched via walk(),
679-
// starting from the top-level ModuleOp and stopping at the point of use. Once
680-
// TOSA control flow and variable extensions reach the complete state, may
681-
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
682-
// the search to a TOSA specific graph traversal over the IR structure.
683-
module.walk([&](Operation *tempOp) {
684-
// Reach this op itself.
685-
if (tempOp == op) {
686-
return WalkResult::interrupt();
687-
}
688-
689-
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
690-
if (symName == tosaOp.getName()) {
691-
varOp = tosaOp;
692-
return WalkResult::interrupt();
693-
}
694-
}
695-
696-
return WalkResult::advance();
697-
});
698-
699-
if (varOp)
700-
return varOp;
701-
702-
return failure();
703-
}
704-
705670
template <typename T>
706671
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
707-
StringRef symName = op.getName();
708-
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
709-
if (failed(varOp))
672+
Operation *symTableOp =
673+
op->template getParentWithTrait<OpTrait::SymbolTable>();
674+
if (!symTableOp)
675+
// If the operation is not the scope of a symbol table, we cannot
676+
// verify it against it's declaration.
677+
return success();
678+
679+
SymbolTable symTable(symTableOp);
680+
const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
681+
682+
// Verify prior declaration
683+
if (!varOp)
710684
return op->emitOpError("'")
711-
<< symName << "' has not been declared by 'tosa.variable'";
685+
<< op.getName() << "' has not been declared by 'tosa.variable'";
712686

713687
// Verify type and shape
714-
auto variableType = getVariableType(varOp.value());
688+
auto variableType = getVariableType(varOp);
715689
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
716690
"the input tensor")
717691
.failed())
718692
return failure();
719-
720693
return success();
721694
}
722695

@@ -1180,7 +1153,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
11801153
ArrayRef<int64_t> shape = shapedType.getShape();
11811154
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
11821155

1183-
result.addAttribute("name", nameAttr);
1156+
result.addAttribute("sym_name", nameAttr);
11841157
result.addAttribute("var_shape", varShapeAttr);
11851158
result.addAttribute("type", elementTypeAttr);
11861159
result.addAttribute("initial_value", initialValue);
@@ -3908,16 +3881,6 @@ LogicalResult tosa::SelectOp::verify() {
39083881
return success();
39093882
}
39103883

3911-
LogicalResult tosa::VariableOp::verify() {
3912-
StringRef symName = getName();
3913-
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3914-
if (succeeded(varOp))
3915-
return emitOpError("illegal to have multiple declaration of '")
3916-
<< symName << "'";
3917-
3918-
return success();
3919-
}
3920-
39213884
LogicalResult tosa::VariableReadOp::verify() {
39223885
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
39233886
.failed())

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -564,64 +564,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
564564

565565
// -----
566566

567-
func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
567+
module {
568568
tosa.variable @stored_var : tensor<*xi8>
569569
// expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
570-
return
571570
}
572571

573572
// -----
574573

575-
func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
574+
module {
576575
// expected-error@+1 {{elements literal type must have static shape}}
577576
tosa.variable @stored_var = dense<0> : tensor<*xi8>
578577
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
579-
return
580-
}
581-
582-
// -----
583-
584-
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
585-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
586-
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
587-
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
588-
return
589578
}
590579

591580
// -----
592581

593-
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
582+
module {
594583
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
595-
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
596-
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
597-
return
584+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
585+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
586+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
587+
return
588+
}
598589
}
599590

600591
// -----
601592

602-
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
593+
module {
603594
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
604-
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
605-
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
606-
return
595+
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
596+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
597+
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
598+
return
599+
}
607600
}
608601

609602
// -----
610603

611-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
604+
module {
612605
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
613-
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
614-
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
615-
return
606+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
607+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
608+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
609+
return
610+
}
616611
}
617612

618613
// -----
619614

620-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
615+
module {
621616
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
622-
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
623-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
624-
return
617+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
618+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
619+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
620+
return
621+
}
625622
}
626623

627624
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
310310
}
311311

312312
// -----
313-
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
313+
module {
314314
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
315315
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
316-
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
317-
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
318-
return
316+
317+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
318+
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
319+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
320+
return
321+
}
319322
}
320323

321324
// -----
322-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
325+
module {
323326
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
324327
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
325-
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
326-
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
327-
return
328+
329+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
330+
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
331+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
332+
return
333+
}
328334
}
329335

330336
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,14 +1088,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
10881088

10891089
// -----
10901090

1091-
func.func @test_variable_read_write_tensor_size_invalid() -> () {
1091+
module {
10921092
// expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
10931093
tosa.variable @stored_var : tensor<536870912xf32>
1094-
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1095-
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1096-
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1097-
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
1098-
return
1094+
1095+
func.func @test_variable_read_write_tensor_size_invalid() -> () {
1096+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1097+
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1098+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1099+
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
1100+
return
1101+
}
10991102
}
11001103

11011104
// -----
@@ -1156,14 +1159,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
11561159

11571160
// -----
11581161

1159-
func.func @test_variable_read_write_rank_invalid() -> () {
1162+
module {
11601163
// expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
11611164
tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1162-
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1163-
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1164-
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1165-
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1166-
return
1165+
1166+
func.func @test_variable_read_write_rank_invalid() -> () {
1167+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1168+
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1169+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1170+
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1171+
return
1172+
}
11671173
}
11681174

11691175
// -----

0 commit comments

Comments
 (0)