 
 # pyre-unsafe
 
-from typing import Callable, Dict
+from typing import Callable, cast, Dict, Set
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm.tosa_quant_utils import QuantArgs
+from executorch.backends.transforms.utils import delete_constant_placeholder
 from executorch.exir import ExportedProgram
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
+from torch.fx.node import Node
 from torch.library import impl, Library
 
 lib = Library("tosa", "DEF")
@@ -29,6 +31,59 @@ def _table_impl(*args, **kwargs): # pyre-ignore
     return args[0]
 
 
+class TableOps:
+    """
+    Helper class for finding the corresponding table operator for a given Node.
+    """
+
+    def __init__(self, exported_program: ExportedProgram):
+        self.exported_program = exported_program
+
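+
+    Intended usage (sketch, mirroring InsertTableOpsPass.call below):
+
+        table_ops = TableOps(exported_program)
+        if node in table_ops:
+            torch_op = table_ops[node]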
+        # Targets that follow a straightforward one-to-one mapping to their table op
+        self.unary_table_ops: Dict[
+            EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]
+        ] = {
+            exir_ops.edge.aten.exp.default: torch.exp,
+            exir_ops.edge.aten.floor.default: torch.floor,
+            exir_ops.edge.aten.log.default: torch.log,
+            exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
+            exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
+            exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
+            exir_ops.edge.aten.tanh.default: torch.tanh,
+            exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
+            exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
+        }
+
+        # Targets that must be treated explicitly
+        self.special_table_ops: Set[EdgeOpOverload] = {
+            exir_ops.edge.aten.pow.Tensor_Tensor,
+        }
+
+    def __contains__(self, node: Node) -> bool:
+        return (
+            node.target in self.unary_table_ops or node.target in self.special_table_ops
+        )
+
+    def __getitem__(self, node: Node):
+        target = cast(EdgeOpOverload, node.target)
+        if target in self.unary_table_ops:
+            return self.unary_table_ops[target]
+        elif target in self.special_table_ops:
+            match target:
+                case exir_ops.edge.aten.pow.Tensor_Tensor:
+                    # Exponent is a constant. Retrieve it from the graph and embed it into a lambda.
+                    exp_node = cast(Node, node.args[1])
+                    exp_name = self.exported_program.graph_signature.inputs_to_buffers[
+                        exp_node.name
+                    ]
+                    exp = self.exported_program.state_dict[exp_name]
+                    return lambda x: torch.pow(x, exp).flatten()
+                case _:
+                    raise NotImplementedError("Unhandled table operation")
+        else:
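+                    # Note: the returned callable is what generate_table_values() later
+                    # receives as torch_op, so it maps a tensor of candidate input values
+                    # to the corresponding output values used to fill the table.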
+            raise KeyError(f"Table op for {target} does not exist")
+
+
 class InsertTableOpsPass(ExportPass):
     """
     For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
@@ -37,21 +92,10 @@ class InsertTableOpsPass(ExportPass):
     which will be used to produce the table values in operators/op_table.py.
     """
 
-    table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
-        exir_ops.edge.aten.exp.default: torch.exp,
-        exir_ops.edge.aten.floor.default: torch.floor,
-        exir_ops.edge.aten.log.default: torch.log,
-        exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
-        exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
-        exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
-        exir_ops.edge.aten.tanh.default: torch.tanh,
-        exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
-        exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
-    }
-
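+    # Overview of call() below: for each matching quantized node, a lookup-table buffer is
+    # generated from the node's reference torch op and its input/output quantization
+    # parameters, users of the node are rewired to a newly created table node, and the
+    # original node (plus any now-unused constant inputs) is erased.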
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
+        self.table_ops = TableOps(exported_program)
 
     def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
@@ -86,7 +130,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in self.table_ops:
+            if node.op != "call_function" or node not in self.table_ops:
                 continue
             input_qparams = node.meta["input_qparams"]
             output_qparams = node.meta["output_qparams"]
@@ -104,7 +148,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             assert len(output_qparams) == 1
             # Generate table buffer
             buffer = self.generate_table_values(
-                torch_op=self.table_ops[node.target],
+                torch_op=self.table_ops[node],
                 in_quantargs=input_qparams[0],
                 out_quantargs=output_qparams[0],
             )
@@ -115,7 +159,19 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 buffer_name=table_node.name.replace("_default", ""), buffer=buffer
             )
             node.replace_all_uses_with(table_node)
-            graph_module.graph.erase_node(node)
+
+            if node.target in self.table_ops.special_table_ops:
+                # The node must be treated explicitly
+                match node.target:
+                    case exir_ops.edge.aten.pow.Tensor_Tensor:
+                        exp_node = node.args[1]
+                        graph_module.graph.erase_node(node)
+                        delete_constant_placeholder(self.exported_program, exp_node)
+                    case _:
+                        raise NotImplementedError("Unhandled table operation")
+            else:
+                graph_module.graph.erase_node(node)
+
             table_node.meta["input_qparams"] = input_qparams
             table_node.meta["output_qparams"] = output_qparams
             modified = True
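+                        # The exponent input is a constant placeholder that becomes unused
+                        # once the op is folded into a table, so delete it after erasing
+                        # the node.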