1111import torch
1212from executorch .backends .arm ._passes .arm_pass_utils import create_node
1313from executorch .backends .arm ._passes .quant_args import QuantArgs
14+ from executorch .backends .transforms .utils import create_constant_placeholder
15+
1416from executorch .exir import ExportedProgram
1517
1618from executorch .exir .dialects ._ops import ops as exir_ops
1719from executorch .exir .dialects .edge ._ops import EdgeOpOverload
1820
1921from executorch .exir .pass_base import ExportPass , PassResult
22+ from torch .export .graph_signature import InputKind
2023from torch .fx import GraphModule
2124from torch .fx .node import Node
22- from torch .library import impl , Library
23-
24- lib = Library ("tosa" , "DEF" )
25- lib .define ("_table(Tensor self) -> Tensor" )
26-
27-
28- @impl (lib , "_table" )
29- def _table_impl (* args , ** kwargs ): # pyre-ignore
30- in_dtype = args [0 ].dtype
31- if in_dtype == torch .int8 :
32- return args [0 ]
33- return args [0 ].to (dtype = torch .int32 )
3425
3526
3627class TableOps :
@@ -242,13 +233,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
242233 # We only want to replace the node if it's quantized
243234 continue
244235 # Create table node
245- with graph_module .graph .inserting_before (node ):
246- table_node = create_node (
247- graph = graph_module .graph ,
248- op_target = torch .ops .tosa ._table .default ,
249- args = (node .args [0 ],),
250- )
251- output_node = table_node
236+ insert_pos = list (node .graph .nodes )[0 ]
237+ with graph_module .graph .inserting_before (insert_pos ):
252238 # Expect exactly one quantization parameter for input and output
253239 if len (input_qparams ) != 1 :
254240 raise ValueError (
@@ -268,27 +254,37 @@ def call(self, graph_module: GraphModule) -> PassResult:
268254 out_quantargs = output_qparams [0 ],
269255 )
270256 # Register buffer in self.exported_program.state_dict
271- # When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
272- # Remove it here to make it possible to find in the node_visitor
273- self .register_buffer (
274- buffer_name = table_node .name .replace ("_default" , "" ), buffer = buffer
257+ const_table_node = create_constant_placeholder (
258+ exp_program = self .exported_program ,
259+ graph = node .graph ,
260+ kind = InputKind .BUFFER ,
261+ name = node .name + "_table_constant" ,
262+ data = buffer ,
263+ persistent_buffer = True ,
275264 )
276265
266+ # Create table node
267+ with graph_module .graph .inserting_before (node ):
268+ table_op_node = create_node (
269+ graph = graph_module .graph ,
270+ op_target = exir_ops .backend .tosa .TABLE .default ,
271+ args = (node .args [0 ], const_table_node ),
272+ )
273+ output_node = table_op_node
274+
277275 if lshift != 0 :
278276 scale = 2.0 ** lshift
279277 rescale_node = create_node (
280278 graph = graph_module .graph ,
281- op_target = torch . ops .tosa ._rescale .default ,
282- args = (table_node , output_qparams [0 ].dtype , scale , 0 , 0 ),
279+ op_target = exir_ops . backend .tosa .RESCALE .default ,
280+ args = (table_op_node , output_qparams [0 ].dtype , scale , 0 , 0 ),
283281 )
284282 output_node = rescale_node
285283
286284 node .replace_all_uses_with (output_node )
287-
288285 graph_module .graph .erase_node (node )
289-
290- output_node .meta ["input_qparams" ] = input_qparams
291- output_node .meta ["output_qparams" ] = output_qparams
286+ table_op_node .meta ["input_qparams" ] = input_qparams
287+ table_op_node .meta ["output_qparams" ] = output_qparams
292288 modified = True
293289
294290 if modified :
0 commit comments