@@ -96,24 +96,23 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
             yield name, m
 
 
-def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
+def _move_single_gm_to_device(
+    gm: GraphModule, device: torch.device, recompile_graph: bool = False
+) -> None:
     """Move one GraphModule and its nodes to the specified device in-place.
     Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
     """
     # move state dict
     gm.to(device)
-    recompile_graph = False
 
     for node in gm.graph.nodes:
         # move all the nodes kwargs with burnt-in device
         if "device" in node.kwargs:
-            recompile_graph = True
             kwargs = node.kwargs.copy()
             kwargs["device"] = device
             node.kwargs = kwargs
 
         if is_op(node, torch.ops.aten.to.device):
-            recompile_graph = True
             args = list(node.args)
             args[1] = device
             node.args = tuple(args)
@@ -136,7 +135,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
 
     for _, subgm in reversed(list(named_graphmodules(gm))):
         # recompile graph to update self generated codes in subgraph
-        _move_single_gm_to_device(subgm, device)
+        _move_single_gm_to_device(subgm, device, subgm is not gm)
 
 
 def _is_impure_node(node: Node) -> bool:
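
For context, a minimal sketch (not part of this commit) of how the rewritten pass could be exercised: tracing a toy module whose graph carries a burnt-in `device` kwarg and handing the resulting `fx.GraphModule` to `move_to_device`. The `TinyModel` module and the commented-out call are illustrative assumptions; only `move_to_device` and the node-kwarg rewrite come from the diff above.

# Sketch only: `TinyModel` and the exact import path of `move_to_device` are assumed for illustration.
import torch
import torch.fx as fx


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Using the proxied shape keeps torch.zeros as a node in the traced
        # graph, so its `device` kwarg is burnt into that node.
        return x + torch.zeros(x.shape, device="cpu")


gm = fx.symbolic_trace(TinyModel())
for node in gm.graph.nodes:
    if "device" in node.kwargs:
        print(node.target, node.kwargs["device"])  # the node with the burnt-in "cpu" device

# move_to_device(gm, torch.device("cuda"))  # would move the state dict and rewrite the kwarg above in-place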