Skip to content

Commit ea149bd

Browse files
committed
Use default overloads when calling custom ops
If a node is created without specifying an overload, A OpOverloadPacket is created, rather than an OpOverload. This works in a GraphModule, but the OpOverloadPacket is not a valid operator type in the _EXIREdgeDialectVerifier, which means that Edge ExportedPrograms can't contain a GraphModule with such ops. In short, specifying using the default overload seems to be the more correct way of creating a custom operator. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I3a1733c0ae88826d88b1e820eaacff765df7fbd2
1 parent f4e77c7 commit ea149bd

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def insert_input_transpose(node, input_node, graph_module):
118118
with graph_module.graph.inserting_before(node):
119119
permute_node = create_node(
120120
graph_module.graph,
121-
torch.ops.passthrough_to_tosa._transpose,
121+
torch.ops.passthrough_to_tosa._transpose.default,
122122
args=(
123123
input_node,
124124
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
@@ -137,7 +137,7 @@ def insert_output_transpose(node, graph_module):
137137
with graph_module.graph.inserting_after(node):
138138
permute_node = create_node(
139139
graph_module.graph,
140-
torch.ops.passthrough_to_tosa._transpose,
140+
torch.ops.passthrough_to_tosa._transpose.default,
141141
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
142142
)
143143
permute_node.meta["tosa_dim_order"] = (

backends/arm/_passes/insert_table_ops.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -92,7 +92,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9292
with graph_module.graph.inserting_before(node):
9393
table_node = create_node(
9494
graph=graph_module.graph,
95-
op_target=torch.ops.tosa._table,
95+
op_target=torch.ops.tosa._table.default,
9696
args=(node.args[0],),
9797
)
9898
assert len(input_qparams) == 1
@@ -104,7 +104,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
104104
out_quantargs=output_qparams[0],
105105
)
106106
# Register buffer in self.exported_program.state_dict
107-
self.register_buffer(buffer_name=table_node.name, buffer=buffer)
107+
# When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
108+
# Remove it here to make it possible to find in the node_visitor
109+
self.register_buffer(
110+
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
111+
)
108112
node.replace_all_uses_with(table_node)
109113
graph_module.graph.erase_node(node)
110114
table_node.meta["input_qparams"] = input_qparams

backends/arm/operators/op_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -21,7 +21,7 @@
2121

2222
@register_node_visitor
2323
class TableVisitor(NodeVisitor):
24-
target = "_table"
24+
target = "_table.default"
2525

2626
def define_node(
2727
self,

backends/arm/operators/op_transpose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -25,7 +25,7 @@ class TransposeVisitor(NodeVisitor):
2525
Inserts a TOSA TRANSPOSE.
2626
"""
2727

28-
target = "_transpose"
28+
target = "_transpose.default"
2929

3030
def define_node(
3131
self,

0 commit comments

Comments
 (0)