10
10
import torch
11
11
from executorch .exir .pass_base import ExportPass , PassResult
12
12
13
- from .utils import copy_nn_module_stack
13
+ from .utils import merge_decomposed_graph
14
14
15
15
16
16
class DecomposeWrapWithAutocast (ExportPass ):
@@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
52
52
graph = gm .graph
53
53
for node in graph .nodes :
54
54
if isinstance (node .target , torch ._higher_order_ops .wrap .WrapWithAutocast ):
55
- submod , submod_name = self ._get_submod (gm , node )
55
+ submod , _ = self ._get_submod (gm , node )
56
56
n_args = node .args
57
57
input_submod = n_args [4 ]
58
58
decomposed_module = submod
@@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
61
61
# which ensures that reference to nodes are correctly updated in the new graph
62
62
# remap = {"expand_1": node.args[5], "to_4": node.args[6]}
63
63
remap = {n_args [i ].name : n_args [i ] for i in range (5 , len (n_args ))}
64
-
65
- for decomposed_node in decomposed_module .graph .nodes :
66
- copy_nn_module_stack (node , decomposed_node )
67
- # no need to copy existent 'output'
68
- if decomposed_node .op == "output" :
69
- self ._replace_output (node , decomposed_node , remap )
70
- # no need to copy existent placeholders
71
- elif decomposed_node .op == "placeholder" :
72
- # replace node map from string to graph node
73
- remap [decomposed_node ] = remap .pop (decomposed_node .name )
74
- else :
75
- remap [decomposed_node ] = graph .node_copy (
76
- decomposed_node ,
77
- arg_transform = lambda x , remap = remap : remap [x ],
78
- )
79
-
64
+ merge_decomposed_graph (
65
+ remap = remap ,
66
+ target_node = node ,
67
+ target_graph = graph ,
68
+ decomposed_graph_module = decomposed_module ,
69
+ output_processor = self ._replace_output ,
70
+ )
80
71
graph .erase_node (node )
81
72
82
73
graph .erase_node (input_submod )
0 commit comments