Skip to content

Commit ea44b6a

Browse files
committed
[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec
* updated $name to $uid * updated AnyType:$value to Tosa_Tensor:$input1 and Tosa_Tensor:$output1 for VariableWrite and VriableRead Operators * updated description discrepancies * note: in the TOSA spec, we had var_shape attr, but it's already included in the TypeAttr:$type in MLIR Signed-off-by: Jerry Ge <[email protected]> Change-Id: I4cd0348cd4e306dbc2e0e53a89a9404d91fb44d4
1 parent eaca60d commit ea44b6a

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
8686
let summary = "Defines a variable";
8787

8888
let description = [{
89-
Defines a new TOSA variable. This is a mutable value.
89+
Defines a new TOSA variable.
90+
This is a persistent mutable value across multiple TOSA graph invocations.
9091
Modifications are expressed using read/write semantics.
9192
}];
9293

9394
let arguments = (ins
94-
SymbolNameAttr:$name,
95+
SymbolNameAttr:$uid,
9596
TypeAttr:$type,
9697
OptionalAttr<AnyAttr>:$initial_value
9798
);
@@ -102,7 +103,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
102103
];
103104

104105
let assemblyFormat = [{
105-
$name
106+
$uid
106107
attr-dict
107108
custom<TypeOrAttr>($type, $initial_value)
108109
}];
@@ -115,12 +116,12 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
115116
let summary = "write_buffer operator";
116117

117118
let description = [{
118-
Assigns a value to pseudo-buffer resource holding a mutable tensor.
119+
Assigns a value to the pseudo-buffer resource holding a persistent mutable tensor.
119120
}];
120121

121122
let arguments = (ins
122-
SymbolNameAttr:$name,
123-
AnyType:$value
123+
SymbolNameAttr:$uid,
124+
Tosa_Tensor:$input1
124125
);
125126

126127
list<Availability> availability = [
@@ -129,7 +130,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
129130
];
130131

131132
let assemblyFormat = [{
132-
$name attr-dict `,` $value `:` type($value)
133+
$uid attr-dict `,` $input1 `:` type($input1)
133134
}];
134135
}
135136

@@ -140,15 +141,15 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
140141
let summary = "read_buffer operator";
141142

142143
let description = [{
143-
Reads the value from a pseudo-buffer resource holding a mutable tensor.
144+
Reads the value from a pseudo-buffer resource holding a persistent mutable tensor.
144145
}];
145146

146147
let arguments = (ins
147-
SymbolNameAttr:$name
148+
SymbolNameAttr:$uid
148149
);
149150

150151
let results = (outs
151-
AnyType:$value
152+
Tosa_Tensor:$output1
152153
);
153154

154155
list<Availability> availability = [
@@ -157,7 +158,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
157158
];
158159

159160
let assemblyFormat = [{
160-
$name attr-dict `:` type($value)
161+
$uid attr-dict `:` type($output1)
161162
}];
162163
}
163164

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
2727
LogicalResult matchAndRewrite(tosa::VariableOp op,
2828
PatternRewriter &rewriter) const final {
2929
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30-
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
30+
op.getLoc(), op.getUid(), op.getType(), /*is_mutable=*/true,
3131
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
3232
newVariable.setPrivate();
3333
rewriter.replaceOp(op, newVariable);
@@ -43,9 +43,9 @@ class VariableWriteOpConverter
4343
LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
4444
PatternRewriter &rewriter) const final {
4545
auto globalSymbolRef =
46-
SymbolRefAttr::get(rewriter.getContext(), op.getName());
46+
SymbolRefAttr::get(rewriter.getContext(), op.getUid());
4747
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48-
op.getLoc(), globalSymbolRef, op.getValue());
48+
op.getLoc(), globalSymbolRef, op.getInput1());
4949
rewriter.replaceOp(op, newVariableWrite);
5050
return success();
5151
}
@@ -58,7 +58,7 @@ class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
5858
LogicalResult matchAndRewrite(tosa::VariableReadOp op,
5959
PatternRewriter &rewriter) const final {
6060
auto globalSymbolRef =
61-
SymbolRefAttr::get(rewriter.getContext(), op.getName());
61+
SymbolRefAttr::get(rewriter.getContext(), op.getUid());
6262
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
6363
op.getLoc(), op.getType(), globalSymbolRef);
6464
rewriter.replaceOp(op, newVariableRead);

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -677,17 +677,17 @@ inline bool CompatibleTypes(const mlir::Type &type,
677677

678678
bool TosaValidation::CheckVariable(Operation *op) {
679679
if (isa<mlir::tosa::VariableOp>(op)) {
680-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
680+
mlir::StringAttr uidAttr = cast<mlir::StringAttr>(op->getAttr("uid"));
681681

682-
if (variablesMap.count(nameAttr)) {
682+
if (variablesMap.count(uidAttr)) {
683683
op->emitOpError() << "name has already been declared";
684684
return false;
685685
}
686686

687687
auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
688688
mlir::Type type = typeAttr.getValue();
689689

690-
variablesMap[nameAttr] = type;
690+
variablesMap[uidAttr] = type;
691691
}
692692

693693
return true;
@@ -696,14 +696,13 @@ bool TosaValidation::CheckVariable(Operation *op) {
696696
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
697697
if (isa<mlir::tosa::VariableReadOp>(op) ||
698698
isa<mlir::tosa::VariableWriteOp>(op)) {
699-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
700-
701-
if (!variablesMap.count(nameAttr)) {
699+
mlir::StringAttr uidAttr = cast<mlir::StringAttr>(op->getAttr("uid"));
700+
if (!variablesMap.count(uidAttr)) {
702701
op->emitOpError() << "name has not been declared";
703702
return false;
704703
}
705704

706-
auto varType = variablesMap[nameAttr];
705+
auto varType = variablesMap[uidAttr];
707706

708707
for (auto v : op->getOperands()) {
709708
auto type = v.getType();

0 commit comments

Comments
 (0)