55
66# pyre-unsafe
77
8- from typing import Callable , Dict
8+ from itertools import chain
9+ from typing import Callable , cast , Dict , Iterator , Set
910
1011import torch
1112from executorch .backends .arm ._passes .arm_pass_utils import create_node
1718
1819from executorch .exir .pass_base import ExportPass , PassResult
1920from torch .fx import GraphModule
20-
21+ from torch . fx . node import Node
2122from torch .library import impl , Library
2223
2324lib = Library ("tosa" , "DEF" )
@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs): # pyre-ignore
3233 return args [0 ].to (dtype = torch .int32 )
3334
3435
35- class InsertTableOpsPass ( ExportPass ) :
36+ class TableOps :
3637 """
37- For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
38- edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
39- When lowering the _table node target_str will be used to find the corresponding torch operator
40- which will be used to produce the table values in operators/op_table.py.
38+ Helper class for finding the corresponding table operator for a given Node.
4139 """
4240
43- table_ops : Dict [EdgeOpOverload , Callable [[torch .Tensor ], torch .Tensor ]] = {
41+ # Targets that follow a straigtforward one-to-one mapping to their table op
42+ unary_table_ops : Dict [EdgeOpOverload , Callable [[torch .Tensor ], torch .Tensor ]] = {
4443 exir_ops .edge .aten .ceil .default : torch .ceil ,
4544 exir_ops .edge .aten .exp .default : torch .exp ,
4645 exir_ops .edge .aten .floor .default : torch .floor ,
@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
5352 exir_ops .edge .aten .hardswish .default : torch .nn .functional .hardswish ,
5453 }
5554
55+ # Targets that must be treated explicitly
56+ special_table_ops : Set [EdgeOpOverload ] = {
57+ exir_ops .edge .aten .pow .Tensor_Scalar ,
58+ }
59+
60+ def __init__ (self , exported_program : ExportedProgram ):
61+ self .exported_program = exported_program
62+
63+ def __contains__ (self , node : Node ) -> bool :
64+ return (
65+ node .target in self .unary_table_ops or node .target in self .special_table_ops
66+ )
67+
68+ def __getitem__ (self , node : Node ):
69+ target = cast (EdgeOpOverload , node .target )
70+ if target in self .unary_table_ops :
71+ return self .unary_table_ops [target ]
72+ elif target in self .special_table_ops :
73+ match target :
74+ case exir_ops .edge .aten .pow .Tensor_Scalar :
75+ # Exponent is a constant. Embed it into a lambda.
76+ exp = cast (int , node .args [1 ])
77+ return lambda x : torch .pow (x , exp ).flatten ()
78+ case _:
79+ # Op must be handled if it's inside self.special_ops
80+ raise AssertionError ("Unhandled table operation" )
81+ else :
82+ raise KeyError ("Table op for {target} does not exist" )
83+
84+ @staticmethod
85+ def included_ops () -> Iterator [EdgeOpOverload ]:
86+ return chain (TableOps .unary_table_ops , TableOps .special_table_ops )
87+
88+
89+ class InsertTableOpsPass (ExportPass ):
90+ """
91+ For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
92+ edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
93+ When lowering the _table node target_str will be used to find the corresponding torch operator
94+ which will be used to produce the table values in operators/op_table.py.
95+ """
96+
5697 def __init__ (self , exported_program : ExportedProgram ) -> None :
5798 super ().__init__ ()
5899 self .exported_program = exported_program
100+ self .table_ops = TableOps (exported_program )
59101
60102 def register_buffer (self , buffer_name : str , buffer : torch .Tensor ) -> None :
61103 """
@@ -166,7 +208,7 @@ def generate_table_values(
166208 def call (self , graph_module : GraphModule ) -> PassResult :
167209 modified = False
168210 for node in graph_module .graph .nodes :
169- if node .op != "call_function" or node . target not in self .table_ops :
211+ if node .op != "call_function" or node not in self .table_ops :
170212 continue
171213 input_qparams = node .meta ["input_qparams" ]
172214 output_qparams = node .meta ["output_qparams" ]
@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
186228
187229 # Generate table buffer and how much to lshift the table output.
188230 buffer , lshift = self .generate_table_values (
189- torch_op = self .table_ops [node . target ],
231+ torch_op = self .table_ops [node ],
190232 in_quantargs = input_qparams [0 ],
191233 out_quantargs = output_qparams [0 ],
192234 )
@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
207249 output_node = rescale_node
208250
209251 node .replace_all_uses_with (output_node )
252+
210253 graph_module .graph .erase_node (node )
254+
211255 output_node .meta ["input_qparams" ] = input_qparams
212256 output_node .meta ["output_qparams" ] = output_qparams
213257 modified = True
0 commit comments