1+ import torch
2+ import torch .nn as nn
3+ import torch .fx as fx
4+ from typing import Set , Dict , List
5+ import itertools
6+
7+ import argparse
8+
9+ def subrange_graph_extract (original_gm : fx .GraphModule , start_idx : int , end_idx : int ):
10+ nodes_list = list (original_gm .graph .nodes )
11+ if start_idx < 0 or end_idx >= len (nodes_list ) or start_idx > end_idx :
12+ raise ValueError ("invalid indices" )
13+
14+ selected_nodes = set (nodes_list [start_idx :end_idx + 1 ])
15+
16+ new_graph = fx .Graph ()
17+ node_map = {}
18+
19+ external_inputs = set ()
20+ for node in selected_nodes :
21+ if hasattr (node , 'args' ) and hasattr (node , 'kwargs' ):
22+ for arg in itertools .chain (node .args , node .kwargs .values ()):
23+ if isinstance (arg , fx .Node ) and arg not in selected_nodes :
24+ external_inputs .add (arg )
25+
26+ print (f"found external inputs: { [node .name for node in external_inputs ]} " )
27+
28+ for ext_node in external_inputs :
29+ placeholder_name = f"input_{ ext_node .name } "
30+ new_placeholder = new_graph .placeholder (placeholder_name )
31+ node_map [ext_node ] = new_placeholder
32+ print (f" { ext_node .name } create placeholder: { placeholder_name } " )
33+
34+ for node in selected_nodes :
35+ if node .op == "placeholder" :
36+ new_node = new_graph .placeholder (node .name )
37+ node_map [node ] = new_node
38+ print (f"kepp original placeholder: { node .name } " )
39+
40+ print (f"\n copy node (node_map length: { len (node_map )} ):" )
41+ for node in nodes_list [start_idx :end_idx + 1 ]:
42+ if node .op == "placeholder" :
43+ continue
44+
45+ print (f" copy node: { node .name } " )
46+ missing_deps = []
47+ if hasattr (node , 'args' ):
48+ for arg in node .args :
49+ if isinstance (arg , fx .Node ) and arg not in node_map :
50+ missing_deps .append (arg .name )
51+ if missing_deps :
52+ print (f" missing deps: { missing_deps } " )
53+ else :
54+ print (f" deps ok" )
55+
56+ new_node = new_graph .node_copy (node , lambda n : node_map .get (n ))
57+ node_map [node ] = new_node
58+ print (f" copy as: { new_node .name } " )
59+
60+ last_node = nodes_list [end_idx ]
61+ if last_node .op != "output" :
62+ new_graph .output (node_map [last_node ])
63+ else :
64+ new_graph .output (node_map .get (last_node .args [0 ], last_node .args [0 ]))
65+
66+ return fx .GraphModule (dict (original_gm .named_modules ()), new_graph )
0 commit comments