66
77import logging
88from copy import deepcopy
9- from typing import Set
9+ from typing import Any , Set
1010
1111import executorch .backends .vulkan .utils as utils
1212
@@ -190,20 +190,24 @@ def propose_node_layout(
190190 return next (iter (valid_layouts ))
191191
192192 def should_annotate (self , node ) -> bool :
193- if not isinstance (node , torch .fx .Node ):
194- return False
195-
196- if not utils .is_tensor_node (node ):
197- return False
198-
199- # Storage type and memory layout for tensorref will be determined at runtime
200- # so there's no use in setting those attributes ahead of time.
201- if node .meta .get ("vkdg_tensorref" , False ):
202- return False
203-
204- # Skip annotating output node. The output tensors should be annotated by the
205- # time the output node is observed.
206- if node .op == "output" :
193+ if isinstance (node , torch .fx .Node ):
194+ if not utils .is_tensor_node (node ):
195+ return False
196+
197+ # Storage type and memory layout for tensorref will be determined at runtime
198+ # so there's no use in setting those attributes ahead of time.
199+ if node .meta .get ("vkdg_tensorref" , False ):
200+ return False
201+
202+ # Skip annotating output node. The output tensors should be annotated by the
203+ # time the output node is observed.
204+ if node .op == "output" :
205+ return False
206+ elif isinstance (node , (list , tuple )):
207+ return all (
208+ isinstance (n , torch .fx .Node ) and self .should_annotate (n ) for n in node
209+ )
210+ else :
207211 return False
208212
209213 return True
@@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
215219 # time the prepack node is observed.
216220 return node .target == exir_ops .edge .et_vk .prepack .default
217221
222+ def set_or_transition_arg_node (
223+ self ,
224+ i : int ,
225+ arg : torch .fx .Node ,
226+ node : torch .fx .Node ,
227+ graph_module : torch .fx .GraphModule ,
228+ dirty : bool ,
229+ ) -> bool :
230+ assert isinstance (arg , torch .fx .Node )
231+
232+ storage = utils .get_node_storage_type (node )
233+ assert storage is not None
234+ layout = utils .get_node_memory_layout (node )
235+ assert layout is not None
236+
237+ arg_storage = utils .get_node_storage_type (arg )
238+ arg_layout = utils .get_node_memory_layout (arg )
239+
240+ if arg_storage is None :
241+ utils .set_node_spec_attr (arg , "vk_storage_type" , storage )
242+ arg_storage = storage
243+ if arg_layout is None :
244+ utils .set_node_spec_attr (arg , "vk_memory_layout" , layout )
245+ arg_layout = layout
246+
247+ if arg_storage == storage and arg_layout == layout :
248+ return False
249+
250+ if not dirty :
251+ logger .info (
252+ f"[Vulkan Delegate] Inserting transition(s) for { node .format_node ()} :"
253+ )
254+
255+ insert_transition_node (graph_module , node , arg , storage , layout )
256+
257+ logger .info (
258+ f" args { i } ({ arg } ): ({ arg_storage } , { arg_layout } ) -> ({ storage } , { layout } )"
259+ )
260+
261+ return True
262+
263+ def set_or_transition_arg (
264+ self ,
265+ i : int ,
266+ arg : Any ,
267+ node : torch .fx .Node ,
268+ graph_module : torch .fx .GraphModule ,
269+ dirty : bool ,
270+ ) -> bool :
271+ if isinstance (arg , torch .fx .Node ):
272+ return self .set_or_transition_arg_node (i , arg , node , graph_module , dirty )
273+ elif isinstance (arg , (list , tuple )):
274+ need_transition = False
275+ for arg_node in arg :
276+ need_transition = (
277+ self .set_or_transition_arg_node (
278+ i , arg_node , node , graph_module , need_transition
279+ )
280+ or need_transition
281+ )
282+ return need_transition
283+ else :
284+ return False
285+
218286 # noqa
219287 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
220288 for node in graph_module .graph .nodes :
@@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
226294
227295 set_memory_metadata (node , storage , layout )
228296
229- inserting_transitions_for_node = False
297+ need_transition = False
230298 for i , arg in enumerate (node .args ):
231299 if not self .should_annotate (arg ):
232300 continue
233301
234- assert isinstance (arg , torch .fx .Node )
235-
236- arg_storage = utils .get_node_storage_type (arg )
237- arg_layout = utils .get_node_memory_layout (arg )
238-
239- if arg_storage is None :
240- utils .set_node_spec_attr (arg , "vk_storage_type" , storage )
241- arg_storage = storage
242- if arg_layout is None :
243- utils .set_node_spec_attr (arg , "vk_memory_layout" , layout )
244- arg_layout = layout
245-
246- if arg_storage == storage and arg_layout == layout :
247- continue
248-
249- if not inserting_transitions_for_node :
250- inserting_transitions_for_node = True
251- logger .info (
252- f"[Vulkan Delegate] Inserting transition(s) for { node .format_node ()} :"
302+ need_transition = (
303+ self .set_or_transition_arg (
304+ i , arg , node , graph_module , need_transition
253305 )
254-
255- insert_transition_node (graph_module , node , arg , storage , layout )
256-
257- logger .info (
258- f" args { i } ({ arg } ): ({ arg_storage } , { arg_layout } ) -> ({ storage } , { layout } )"
306+ or need_transition
259307 )
260308
261309 return PassResult (graph_module , True )
0 commit comments