@@ -96,23 +96,24 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
         yield name, m
 
 
-def _move_single_gm_to_device(
-    gm: GraphModule, device: torch.device, recompile_graph: bool = False
-) -> None:
+def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
     """Move one GraphModule and its nodes to the specified device in-place.
     Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
     """
     # move state dict
     gm.to(device)
+    recompile_graph = False
 
     for node in gm.graph.nodes:
         # move all the nodes kwargs with burnt-in device
         if "device" in node.kwargs:
+            recompile_graph = True
             kwargs = node.kwargs.copy()
             kwargs["device"] = device
             node.kwargs = kwargs
 
         if is_op(node, torch.ops.aten.to.device):
+            recompile_graph = True
             args = list(node.args)
             args[1] = device
             node.args = tuple(args)
@@ -135,7 +136,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
 
     for _, subgm in reversed(list(named_graphmodules(gm))):
         # recompile graph to update self generated codes in subgraph
-        _move_single_gm_to_device(subgm, device, subgm is not gm)
+        _move_single_gm_to_device(subgm, device)
 
 
 def _is_impure_node(node: Node) -> bool:
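
The idea behind the change: instead of the caller deciding whether to recompile, the pass itself tracks whether any node actually had a burnt-in device rewritten and only then regenerates the graph's code. Below is a minimal, self-contained sketch of that pattern using plain torch.fx. The names `retarget_device` and `TinyModel` are illustrative, not part of this repository, and the sketch only handles the `"device"` kwarg case (the real pass also rewrites `aten.to.device` args).

```python
import torch
import torch.fx as fx


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # the device kwarg is burnt into the traced graph as a node kwarg
        mask = torch.ones(x.shape, device="cpu")
        return x + mask


def retarget_device(gm: fx.GraphModule, device: torch.device) -> None:
    # move parameters/buffers first
    gm.to(device)
    recompile_graph = False
    for node in gm.graph.nodes:
        # rewrite any burnt-in device kwarg and remember that the graph changed
        if "device" in node.kwargs:
            recompile_graph = True
            kwargs = dict(node.kwargs)
            kwargs["device"] = device
            node.kwargs = kwargs
    if recompile_graph:
        # only regenerate the graph's Python code when a node was touched
        gm.recompile()


gm = fx.symbolic_trace(TinyModel())
retarget_device(gm, torch.device("cpu"))
print(gm.code)  # the torch.ones call now carries the retargeted device
```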