@@ -6,14 +6,16 @@
 
 # pyre-strict
 
-from typing import List, Optional
+import operator
+from typing import Optional
 
 import torch
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
 
 
@@ -52,12 +54,40 @@ class SpecPropPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()
 
-    def on_attr(self, attr: ProxyValue) -> None:
-        attr.node.meta["spec"] = pytree.tree_map_only(
-            torch.Tensor,
-            make_spec,
-            attr.data,
-        )
+    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Re-trace metadata to ensure it's up to date.
+        res = ExportPass()(graph_module)
+        assert res is not None
+        gm = res.graph_module
+
+        def get_spec(x):
+            if hasattr(x, "meta"):
+                return x.meta.get("spec", None)
+            else:
+                return None
+
+        for module in gm.modules():
+            if isinstance(module, torch.fx.GraphModule):
+                for node in module.graph.nodes:
+                    meta_val = node.meta.get("val", None)
+
+                    if node.op == "output":
+                        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
+                    elif node.op == "call_function" and node.target == operator.getitem:
+                        value_spec = pytree.tree_map(get_spec, node.args[0])
+                        node.meta["spec"] = value_spec[node.args[1]]
+                    elif (
+                        node.op == "call_function"
+                        and node.target == executorch_call_delegate
+                    ):
+                        if "spec" not in node.meta:
+                            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+                    else:
+                        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+        return res
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        return self(graph_module)
 
     def update_placeholder_tensor_specs(
         self,
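
The new `__call__` replaces the per-node-type interpreter overrides deleted in the hunk below: it first re-runs a plain `ExportPass` to refresh each node's `val` (fake tensor) metadata, then derives `spec` for every node in a single graph walk. `output` and `operator.getitem` nodes copy specs from their producers, `executorch_call_delegate` nodes keep any spec they already carry, and everything else gets a fresh spec built from `val`. Below is a minimal sketch of the same val-driven walk on a plain `torch.fx` graph; `fake_make_spec` and `f` are hypothetical stand-ins for executorch's `make_spec` and a real model, and `make_fx` fills `node.meta["val"]` much like the `ExportPass` re-trace does above:

```python
# Sketch of a val-driven spec walk; fake_make_spec / f are stand-ins.
import operator

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils import _pytree as pytree


def fake_make_spec(x):
    # Stand-in spec: just dtype and shape, instead of executorch's TensorSpec.
    return (x.dtype, tuple(x.shape)) if isinstance(x, torch.Tensor) else None


def f(x):
    return torch.relu(x) + 1, x * 2


gm = make_fx(f)(torch.randn(2, 3))  # nodes now carry meta["val"] fake tensors

for node in gm.graph.nodes:
    meta_val = node.meta.get("val", None)
    if node.op == "output":
        # Outputs reuse the specs already attached to their producer nodes.
        node.meta["spec"] = pytree.tree_map(
            lambda a: a.meta.get("spec") if hasattr(a, "meta") else None,
            node.args[0],
        )
    elif node.op == "call_function" and node.target == operator.getitem:
        # getitem selects one entry out of a multi-output node's spec.
        node.meta["spec"] = node.args[0].meta["spec"][node.args[1]]
    else:
        node.meta["spec"] = pytree.tree_map(fake_make_spec, meta_val)

for node in gm.graph.nodes:
    print(node.op, node.name, node.meta["spec"])
```
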
@@ -84,85 +114,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(mapped_dim_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-generate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
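
After this change the pass is invoked as a callable and returns a torch.fx `PassResult`; the `call` method simply forwards to `__call__`, so it stays compatible with pass-manager style invocation. A hedged usage sketch, assuming executorch is installed, that this file lives at `executorch/exir/passes/spec_prop_pass.py`, and that `gm` is a `torch.fx.GraphModule` obtained from an exported program:

```python
# Hedged usage sketch: `gm` is assumed to come from an exported program, and
# the import path is assumed from the file this diff modifies.
import torch
from executorch.exir.passes.spec_prop_pass import SpecPropPass


def run_spec_prop(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    result = SpecPropPass()(gm)  # __call__ re-traces, then annotates specs
    # Every node in the returned module should now carry node.meta["spec"].
    return result.graph_module
```
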