1515from torch .export .exported_program import ExportGraphSignature
1616from torch .fx .node import Node
1717from torch .utils import _pytree as pytree
18+ from torch .fx .passes .infra .pass_base import PassResult
1819
1920
2021# pyre-ignore
@@ -44,19 +45,9 @@ def _is_mutable_buffer(
4445 if fqn in graph_signature .buffers_to_mutate .values ():
4546 return True
4647 return False
47-
48-
49- class SpecPropPass (ExportPass ):
50- def __init__ (self ) -> None :
51- super ().__init__ ()
52-
53- def on_attr (self , attr : ProxyValue ) -> None :
54- attr .node .meta ["spec" ] = pytree .tree_map_only (
55- torch .Tensor ,
56- make_spec ,
57- attr .data ,
58- )
59-
48+ class SpecPropPass :
49+ def __call__ (self , gm : torch .fx .GraphModule ) -> PassResult :
50+ return spec_prop_pass (gm )
6051 def update_placeholder_tensor_specs (
6152 self ,
6253 exported_program : torch .export .ExportedProgram ,
@@ -83,71 +74,16 @@ def update_placeholder_tensor_specs(
8374 ):
8475 spec .const = True
8576
86- # pyre-ignore
87- def placeholder (self , name : str , arg , meta ):
88- meta ["spec" ] = make_spec (arg )
89- return super ().placeholder (name , arg , meta )
90-
91- # pyre-ignore
92- def call_operator (self , op , args , kwargs , meta ):
93- args_data , kwargs_data = pytree .tree_map_only (
94- ProxyValue , lambda x : x .data , (args , kwargs )
95- )
96- meta ["spec" ] = pytree .tree_map (make_spec , op (* args_data , ** kwargs_data ))
97- return super ().call_operator (op , args , kwargs , meta )
98-
99- # pyre-ignore
100- def call_getitem (self , value , key : int , meta ):
101- meta ["spec" ] = value .node .meta ["spec" ][key ]
102- return super ().call_getitem (value , key , meta )
103-
104- # pyre-ignore
105- def call_cond (self , pred , true_fn , false_fn , inputs , meta ):
106- # true_fn/false_fn return tensors of the same shape, so we can pick
107- # either one here.
108- * _ , true_out_node = true_fn .graph .nodes
109- meta ["spec" ] = pytree .tree_map (make_spec , true_out_node .meta ["val" ])
110- return super ().call_cond (pred , true_fn , false_fn , inputs , meta )
111-
112- def call_map (
113- self ,
114- f : torch .fx .GraphModule ,
115- mapped_args : List [ProxyValue ],
116- operands : List [ProxyValue ],
117- meta : NodeMetadata ,
118- ) -> ProxyValue :
119- mapped_dim_size = [arg .data for arg in mapped_args ][0 ].size (0 )
120- * _ , body_out_node = f .graph .nodes
121- body_out_node_fake_tensor = body_out_node .meta ["val" ]
122- map_fake_tensor = pytree .tree_map_only (
123- torch .Tensor ,
124- lambda x : x .new_empty (mapped_dim_size , * x .shape ),
125- body_out_node_fake_tensor ,
126- )
127- meta ["spec" ] = pytree .tree_map (make_spec , map_fake_tensor )
128- return super ().call_map (f , mapped_args , operands , meta )
129-
130- # pyre-ignore
131- def call_delegate (self , lowered_module , args , kwargs , meta ):
132- args_data , kwargs_data = pytree .tree_map_only (
133- ProxyValue , lambda x : x .data , (args , kwargs )
134- )
135- # If spec is missing, re-genenrate it with args data
136- if "spec" not in meta :
137- meta ["spec" ] = pytree .tree_map (
138- make_spec ,
139- executorch_call_delegate (lowered_module , * args_data ),
140- )
141- return super ().call_delegate (lowered_module , args , kwargs , meta )
142-
143- # pyre-ignore
144- def output (self , results , meta ):
145- # pyre-ignore
146- def get_spec (x ):
147- if isinstance (x , ProxyValue ):
148- return x .node .meta ["spec" ]
149- else :
150- return make_spec (x )
151-
152- meta ["spec" ] = pytree .tree_map (get_spec , results )
153- return super ().output (results , meta )
77+ def spec_prop_pass (gm : torch .fx .GraphModule ) -> PassResult :
78+ # Update all the meta["val"]
79+ pass_result = ExportPass ()(gm )
80+ assert pass_result is not None
81+ gm = pass_result .graph_module
82+ # set node.meta["spec"] based on meta["val"]
83+ for module in gm .modules ():
84+ if isinstance (module , torch .fx .GraphModule ):
85+ for node in module .graph .nodes :
86+ if node .op == "get_attr" :
87+ continue
88+ node .meta ["spec" ] = pytree .tree_map (lambda meta_val : make_spec (meta_val ), node .meta ["val" ])
89+ return pass_result
0 commit comments