Skip to content

Commit 221b77e

Browse files
authored
Fold many ranges (#337)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * fold many subgraphs into submodules
1 parent e7c6e03 commit 221b77e

File tree

2 files changed

+125
-94
lines changed

2 files changed

+125
-94
lines changed

graph_net/test/torch_extractor_test.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,36 +19,22 @@ def forward(self, x):
1919

2020

2121
class WrapperModule(torch.nn.Module):
22-
def __init__(self, submodule):
22+
def __init__(self, submodule, seq_no):
2323
super().__init__()
2424
self.submodule = submodule
25+
self.seq_no = seq_no
2526

2627
def forward(self, *args):
2728
print("Args:")
2829
print(args)
2930
return self.submodule(*args)
3031

3132

32-
def submodule_hook(submodule: torch.fx.GraphModule):
33-
print(f"{'-'*8} [submodule] {'-'*8}\n")
33+
def submodule_hook(submodule: torch.fx.GraphModule, seq_no):
34+
print(f"{'-'*8} [submodule-{seq_no}] {'-'*8}\n")
3435
print(submodule.graph)
35-
"""
36-
graph():
37-
%add : [num_users=1] = placeholder[target=add]
38-
%mul : [num_users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
39-
%clamp : [num_users=1] = call_method[target=clamp](args = (%mul,), kwargs = {min: 0.0, max: 1.0})
40-
return (clamp,)
41-
42-
"""
4336
print(submodule.code)
44-
"""
45-
def forward(self, add):
46-
mul = add * 2; add = None
47-
clamp = mul.clamp(min = 0.0, max = 1.0); mul = None
48-
return (clamp,)
49-
"""
50-
51-
return WrapperModule(submodule)
37+
return WrapperModule(submodule, seq_no)
5238

5339

5440
class TestExtractorSubmodule(unittest.TestCase):
@@ -87,9 +73,10 @@ def forward(self, x):
8773

8874
folded = fold_range_to_submodule(
8975
symbolic_traced,
90-
start_node_idx=2,
91-
end_node_idx=4,
76+
start_node_idx=0,
77+
end_node_idx=2,
9278
submodule_hook=submodule_hook,
79+
# group_head_and_tail=False,
9380
)
9481
folded_output = folded(inp)
9582

graph_net/torch/decompose_util.py

Lines changed: 117 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,104 +5,148 @@
55
from dataclasses import dataclass
66

77

8-
def fold_range_to_submodule(
8+
def convert_to_submodules_graph(
99
original_gm: torch.fx.GraphModule,
10-
start_node_idx: int,
11-
end_node_idx: int,
10+
split_positions: list[int],
1211
submodule_hook=None,
13-
submodule_name="extraced_submodule",
12+
submodule_name_prefix="extraced_submodule",
13+
group_head_and_tail=True,
1414
):
1515
original_gm = copy.deepcopy(original_gm)
16-
submodule_body_nodes = list(original_gm.graph.nodes)[start_node_idx:end_node_idx]
17-
18-
def get_body_nodes():
19-
return submodule_body_nodes
20-
21-
assert len(get_body_nodes()) > 0
22-
23-
for idx, original_node in enumerate(get_body_nodes()):
24-
assert original_node.op not in {
16+
num_placeholders = len(
17+
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
18+
)
19+
submodules_body_nodes = [
20+
node
21+
for node in original_gm.graph.nodes
22+
if node.op
23+
not in {
2524
"placeholder",
2625
"output",
27-
}, f"{idx=}, {original_node.op=}"
28-
29-
submodule_input_nodes, submodule_output_nodes = _get_submodule_inputs_and_outputs(
30-
original_gm=original_gm,
31-
start_node_idx=start_node_idx,
32-
end_node_idx=end_node_idx,
26+
}
27+
]
28+
split_positions = (
29+
[0, *split_positions, len(submodules_body_nodes)]
30+
if group_head_and_tail
31+
else split_positions
3332
)
33+
split_positions = [
34+
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
35+
]
36+
submodule_ranges = [
37+
(start, end)
38+
for i in range(len(split_positions) - 1)
39+
for start in [split_positions[i]]
40+
for end in [split_positions[i + 1]]
41+
if end > start
42+
]
3443

35-
def get_input_nodes():
36-
return submodule_input_nodes
37-
38-
def get_output_nodes():
39-
return submodule_output_nodes
44+
def get_body_nodes(range_idx):
45+
start, end = submodule_ranges[range_idx]
46+
return submodules_body_nodes[start:end]
4047

4148
def get_name2sub_submodule():
42-
used_module_names = set()
43-
for node in get_body_nodes():
44-
if node.op == "call_module":
45-
used_module_names.add(node.target)
49+
used_module_names = set(
50+
[node.target for node in submodules_body_nodes if node.op == "call_module"]
51+
)
4652
return {
4753
name: module
4854
for name, module in original_gm.named_modules()
4955
if name in used_module_names
5056
}
5157

52-
new_graph = torch.fx.Graph()
53-
# Create a mapping for nodes from original graph to new graph
54-
node_map = {}
55-
56-
# Add placeholder nodes for inputs
57-
for original_node in get_input_nodes():
58-
new_node = new_graph.placeholder(original_node.name)
59-
node_map[original_node] = new_node
60-
61-
# Copy body nodes
62-
for original_node in get_body_nodes():
63-
print(original_node)
64-
new_node = new_graph.node_copy(original_node, lambda x: node_map[x])
65-
node_map[original_node] = new_node
66-
67-
# Add output nodes
68-
output_args = []
69-
for original_node in get_output_nodes():
70-
output_args.append(node_map[original_node])
71-
new_graph.output(tuple(output_args))
72-
73-
# Create the new GraphModule
74-
# This assumes no submodules are being extracted, or they are handled separately
75-
new_sub_module = torch.fx.GraphModule(get_name2sub_submodule(), new_graph)
76-
if submodule_hook is not None:
77-
new_sub_module = submodule_hook(new_sub_module)
78-
# Replace with submodule node
79-
original_gm.add_submodule(submodule_name, new_sub_module)
80-
with original_gm.graph.inserting_after(get_body_nodes()[-1]):
81-
submodule_node = original_gm.graph.call_module(
82-
submodule_name, tuple(get_input_nodes())
58+
for range_idx in range(len(submodule_ranges)):
59+
start_node_idx, end_node_idx = submodule_ranges[range_idx]
60+
(
61+
submodule_input_nodes,
62+
submodule_output_nodes,
63+
) = _get_submodule_inputs_and_outputs(
64+
original_gm=original_gm,
65+
start_node_idx=(num_placeholders + start_node_idx),
66+
end_node_idx=(num_placeholders + end_node_idx),
8367
)
84-
prev_node = submodule_node
85-
for idx, original_output in enumerate(get_output_nodes()):
86-
with original_gm.graph.inserting_after(prev_node):
87-
new_output_node = original_gm.graph.call_function(
88-
operator.getitem, (submodule_node, idx)
89-
)
90-
node_map[original_output] = new_output_node
91-
prev_node = new_output_node
9268

93-
# Replace all use of outputs
94-
for original_output in get_output_nodes():
95-
original_output.replace_all_uses_with(node_map[original_output])
69+
def get_input_nodes(range_idx):
70+
return submodule_input_nodes
9671

97-
# Erase old nodes
98-
for node in reversed(get_body_nodes()):
99-
original_gm.graph.erase_node(node)
72+
def get_output_nodes(range_idx):
73+
return submodule_output_nodes
74+
75+
submodule_name = (
76+
f"{submodule_name_prefix}_{range_idx}"
77+
if range_idx > 0
78+
else submodule_name_prefix
79+
)
80+
new_graph = torch.fx.Graph()
81+
# Create a mapping for nodes from original graph to new graph
82+
node_map = {}
83+
84+
# Add placeholder nodes for inputs
85+
for original_node in get_input_nodes(range_idx):
86+
new_node = new_graph.placeholder(original_node.name)
87+
node_map[original_node] = new_node
88+
89+
# Copy body nodes
90+
for original_node in get_body_nodes(range_idx):
91+
new_node = new_graph.node_copy(original_node, lambda x: node_map[x])
92+
node_map[original_node] = new_node
93+
94+
# Add output nodes
95+
output_args = []
96+
for original_node in get_output_nodes(range_idx):
97+
output_args.append(node_map[original_node])
98+
new_graph.output(tuple(output_args))
99+
100+
# Create the new GraphModule
101+
# This assumes no submodules are being extracted, or they are handled separately
102+
new_sub_module = torch.fx.GraphModule(get_name2sub_submodule(), new_graph)
103+
if submodule_hook is not None:
104+
new_sub_module = submodule_hook(new_sub_module, range_idx)
105+
# Replace with submodule node
106+
original_gm.add_submodule(submodule_name, new_sub_module)
107+
with original_gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
108+
submodule_node = original_gm.graph.call_module(
109+
submodule_name, tuple(get_input_nodes(range_idx))
110+
)
111+
prev_node = submodule_node
112+
for idx, original_output in enumerate(get_output_nodes(range_idx)):
113+
with original_gm.graph.inserting_after(prev_node):
114+
new_output_node = original_gm.graph.call_function(
115+
operator.getitem, (submodule_node, idx)
116+
)
117+
node_map[original_output] = new_output_node
118+
prev_node = new_output_node
119+
120+
# Replace all use of outputs
121+
for original_output in get_output_nodes(range_idx):
122+
original_output.replace_all_uses_with(node_map[original_output])
123+
124+
# Erase old nodes
125+
for node in reversed(get_body_nodes(range_idx)):
126+
original_gm.graph.erase_node(node)
100127

101128
original_gm.recompile()
102129

103130
return original_gm
104131

105132

133+
def fold_range_to_submodule(
134+
original_gm: torch.fx.GraphModule,
135+
start_node_idx: int,
136+
end_node_idx: int,
137+
submodule_hook=None,
138+
submodule_name="extraced_submodule",
139+
group_head_and_tail=True,
140+
):
141+
return convert_to_submodules_graph(
142+
original_gm,
143+
split_positions=[start_node_idx, end_node_idx],
144+
submodule_hook=submodule_hook,
145+
submodule_name_prefix=submodule_name,
146+
group_head_and_tail=group_head_and_tail,
147+
)
148+
149+
106150
@dataclass
107151
class NodeProducedOrConsumedCountCtx:
108152
node2before_input: defaultdict(int)

0 commit comments

Comments (0)