Skip to content

Commit 10193c2

Browse files
committed
subrange_graph_extract
1 parent 51d0d7e commit 10193c2

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

graph_net/torch/subrange_graph.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.fx as fx
4+
from typing import Set, Dict, List
5+
import itertools
6+
7+
import argparse
8+
9+
def subrange_graph_extract(original_gm: fx.GraphModule, start_idx: int, end_idx: int):
10+
nodes_list = list(original_gm.graph.nodes)
11+
if start_idx < 0 or end_idx >= len(nodes_list) or start_idx > end_idx:
12+
raise ValueError("invalid indices")
13+
14+
selected_nodes = set(nodes_list[start_idx:end_idx + 1])
15+
16+
new_graph = fx.Graph()
17+
node_map = {}
18+
19+
external_inputs = set()
20+
for node in selected_nodes:
21+
if hasattr(node, 'args') and hasattr(node, 'kwargs'):
22+
for arg in itertools.chain(node.args, node.kwargs.values()):
23+
if isinstance(arg, fx.Node) and arg not in selected_nodes:
24+
external_inputs.add(arg)
25+
26+
print(f"found external inputs: {[node.name for node in external_inputs]}")
27+
28+
for ext_node in external_inputs:
29+
placeholder_name = f"input_{ext_node.name}"
30+
new_placeholder = new_graph.placeholder(placeholder_name)
31+
node_map[ext_node] = new_placeholder
32+
print(f" {ext_node.name} create placeholder: {placeholder_name}")
33+
34+
for node in selected_nodes:
35+
if node.op == "placeholder":
36+
new_node = new_graph.placeholder(node.name)
37+
node_map[node] = new_node
38+
print(f"kepp original placeholder: {node.name}")
39+
40+
print(f"\ncopy node (node_map length: {len(node_map)}):")
41+
for node in nodes_list[start_idx:end_idx + 1]:
42+
if node.op == "placeholder":
43+
continue
44+
45+
print(f" copy node: {node.name}")
46+
missing_deps = []
47+
if hasattr(node, 'args'):
48+
for arg in node.args:
49+
if isinstance(arg, fx.Node) and arg not in node_map:
50+
missing_deps.append(arg.name)
51+
if missing_deps:
52+
print(f" missing deps: {missing_deps}")
53+
else:
54+
print(f" deps ok")
55+
56+
new_node = new_graph.node_copy(node, lambda n: node_map.get(n))
57+
node_map[node] = new_node
58+
print(f" copy as: {new_node.name}")
59+
60+
last_node = nodes_list[end_idx]
61+
if last_node.op != "output":
62+
new_graph.output(node_map[last_node])
63+
else:
64+
new_graph.output(node_map.get(last_node.args[0], last_node.args[0]))
65+
66+
return fx.GraphModule(dict(original_gm.named_modules()), new_graph)

0 commit comments

Comments
 (0)