Skip to content

Commit 70996ea

Browse files
committed
merge code
2 parents 63084a7 + c65f7fa commit 70996ea

File tree

8 files changed

+149
-79
lines changed

8 files changed

+149
-79
lines changed

graph_net/imp_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import importlib.util as imp
2+
3+
4+
def load_module(path, name="unamed"):
5+
spec = imp.spec_from_file_location(name, path)
6+
module = imp.module_from_spec(spec)
7+
spec.loader.exec_module(module)
8+
return module
Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
#!/bin/bash
2-
set -x
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
35

46
# input model path
57
MODEL_PATH_IN_SAMPLES=/timm/resnet18
6-
# extract subgraph 0-8, 8-16
7-
read -r -d '' json_str <<'EOF'
8+
extractor_config_json_str=$(cat <<EOF
89
{
9-
"output_dir": "/tmp/naive_decompose_workspace",
10-
"split_positions": [8, 16, 32],
11-
"group_head_and_tail": true,
12-
"chain_style": true
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/chain_naive_decompose_workspace",
13+
"split_positions": [8, 16, 32],
14+
"group_head_and_tail": true,
15+
"chain_style": true
16+
}
1317
}
1418
EOF
15-
CONFIG=$(echo $json_str | base64 -w 0)
19+
)
20+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
1621

1722
mkdir -p /tmp/naive_decompose_workspace
18-
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
19-
os.path.dirname(graph_net.__file__))")
20-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG
23+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
#!/bin/bash
22

3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
36
# input model path
47
MODEL_PATH_IN_SAMPLES=/timm/resnet18
5-
read -r -d '' json_str <<'EOF'
8+
extractor_config_json_str=$(cat <<EOF
69
{
7-
"output_dir": "/tmp/naive_decompose_workspace",
8-
"split_positions": [8, 32],
9-
"group_head_and_tail": true
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/naive_decompose_workspace",
13+
"split_positions": [8, 16, 32],
14+
"group_head_and_tail": true,
15+
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
16+
"filter_config": {}
17+
}
1018
}
1119
EOF
12-
CONFIG=$(echo $json_str | base64 -w 0)
20+
)
21+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
1322

1423
mkdir -p /tmp/naive_decompose_workspace
15-
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
16-
os.path.dirname(graph_net.__file__))")
17-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG
24+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/torch/decompose_util.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@
66

77

88
def convert_to_submodules_graph(
9-
original_gm: torch.fx.GraphModule,
9+
gm: torch.fx.GraphModule,
1010
split_positions: list[int],
1111
submodule_hook=None,
1212
submodule_name_prefix="extracted_submodule",
1313
chain_style=False,
1414
group_head_and_tail=True,
1515
):
1616
"""
17-
chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
17+
chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818
"""
19-
original_gm = copy.deepcopy(original_gm)
19+
gm = copy.deepcopy(gm)
2020
num_placeholders = len(
21-
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
21+
[node for node in gm.graph.nodes if node.op == "placeholder"]
2222
)
2323
submodules_body_nodes = [
2424
node
25-
for node in original_gm.graph.nodes
25+
for node in gm.graph.nodes
2626
if node.op
2727
not in {
2828
"placeholder",
@@ -37,13 +37,16 @@ def convert_to_submodules_graph(
3737
split_positions = [
3838
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
3939
]
40-
range_idx2submodule_body_nodes = [
41-
submodules_body_nodes[start:end]
40+
range_idx2range = [
41+
(start, end)
4242
for i in range(len(split_positions) - 1)
4343
for start in [split_positions[i]]
4444
for end in [split_positions[i + 1]]
4545
if end > start
4646
]
47+
range_idx2submodule_body_nodes = [
48+
submodules_body_nodes[start:end] for start, end in range_idx2range
49+
]
4750

4851
def get_body_nodes(range_idx):
4952
return range_idx2submodule_body_nodes[range_idx]
@@ -54,20 +57,20 @@ def get_name2sub_submodule():
5457
)
5558
return {
5659
name: module
57-
for name, module in original_gm.named_modules()
60+
for name, module in gm.named_modules()
5861
if name in used_module_names
5962
}
6063

6164
def get_start_node_idx(range_idx):
6265
start_node = get_body_nodes(range_idx)[0]
63-
for i, node in enumerate(original_gm.graph.nodes):
66+
for i, node in enumerate(gm.graph.nodes):
6467
if node == start_node:
6568
return i
6669
raise NotImplementedError("Dead code.")
6770

6871
def get_end_node_idx(range_idx):
6972
last_node = get_body_nodes(range_idx)[-1]
70-
for i, node in enumerate(original_gm.graph.nodes):
73+
for i, node in enumerate(gm.graph.nodes):
7174
if node == last_node:
7275
return i + 1
7376
raise NotImplementedError("Dead code.")
@@ -78,23 +81,33 @@ def print_submodule_call(prompt, gm):
7881
]
7982
print(f"{prompt} ", submodule_call_stmts)
8083

84+
new_node2original_node = {}
85+
for node in gm.graph.nodes:
86+
new_node2original_node[node] = node
87+
88+
def sort_key(node):
89+
return new_node2original_node[node].name
90+
8191
for range_idx in range(len(range_idx2submodule_body_nodes)):
8292
(
8393
submodule_input_nodes,
8494
submodule_output_nodes,
8595
identity_nodes,
8696
) = _get_submodule_inputs_and_outputs(
87-
original_gm=original_gm,
97+
gm=gm,
8898
start_node_idx=get_start_node_idx(range_idx),
8999
end_node_idx=get_end_node_idx(range_idx),
90100
chain_style=chain_style,
91101
)
92102

93103
def get_input_nodes(range_idx):
94-
return submodule_input_nodes
104+
return sorted(submodule_input_nodes, key=sort_key)
95105

96106
def get_output_nodes(range_idx):
97-
return submodule_output_nodes
107+
end = range_idx2range[range_idx][1]
108+
if end >= len(submodules_body_nodes):
109+
return submodule_output_nodes
110+
return sorted(submodule_output_nodes, key=sort_key)
98111

99112
submodule_name = (
100113
f"{submodule_name_prefix}_{range_idx}"
@@ -107,7 +120,8 @@ def get_output_nodes(range_idx):
107120

108121
# Add placeholder nodes for inputs
109122
for original_node in get_input_nodes(range_idx):
110-
new_node = new_graph.placeholder(original_node.name)
123+
name = new_node2original_node[original_node].name
124+
new_node = new_graph.placeholder(name)
111125
node_map[original_node] = new_node
112126

113127
# Copy body nodes
@@ -116,9 +130,9 @@ def get_output_nodes(range_idx):
116130
node_map[original_node] = new_node
117131

118132
# Add output nodes
119-
output_args = []
120-
for original_node in get_output_nodes(range_idx):
121-
output_args.append(node_map[original_node])
133+
output_args = [
134+
node_map[original_node] for original_node in get_output_nodes(range_idx)
135+
]
122136
new_graph.output(tuple(output_args))
123137

124138
# Create the new GraphModule
@@ -127,15 +141,15 @@ def get_output_nodes(range_idx):
127141
if submodule_hook is not None:
128142
new_sub_module = submodule_hook(new_sub_module, range_idx)
129143
# Replace with submodule node
130-
original_gm.add_submodule(submodule_name, new_sub_module)
131-
with original_gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
132-
submodule_node = original_gm.graph.call_module(
144+
gm.add_submodule(submodule_name, new_sub_module)
145+
with gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
146+
submodule_node = gm.graph.call_module(
133147
submodule_name, tuple(get_input_nodes(range_idx))
134148
)
135149
prev_node = submodule_node
136150
for idx, original_output in enumerate(get_output_nodes(range_idx)):
137-
with original_gm.graph.inserting_after(prev_node):
138-
new_output_node = original_gm.graph.call_function(
151+
with gm.graph.inserting_after(prev_node):
152+
new_output_node = gm.graph.call_function(
139153
operator.getitem, (submodule_node, idx)
140154
)
141155
node_map[original_output] = new_output_node
@@ -146,31 +160,34 @@ def get_output_nodes(range_idx):
146160
for original_output in get_output_nodes(range_idx):
147161
if original_output not in identity_node_set:
148162
original_output.replace_all_uses_with(node_map[original_output])
163+
new_node2original_node[
164+
node_map[original_output]
165+
] = new_node2original_node[original_output]
149166

150167
# Erase old nodes
151168
for node in reversed(get_body_nodes(range_idx)):
152-
original_gm.graph.erase_node(node)
153-
# print_submodule_call("(fx) after Erase old nodes", original_gm)
169+
gm.graph.erase_node(node)
170+
# print_submodule_call("(fx) after Erase old nodes", gm)
154171

155-
# print_submodule_call("(fx) before recompile", original_gm)
172+
# print_submodule_call("(fx) before recompile", gm)
156173

157-
original_gm.recompile()
174+
gm.recompile()
158175

159-
# print_submodule_call("(fx) after recompile", original_gm)
176+
# print_submodule_call("(fx) after recompile", gm)
160177

161-
return original_gm
178+
return gm
162179

163180

164181
def fold_range_to_submodule(
165-
original_gm: torch.fx.GraphModule,
182+
gm: torch.fx.GraphModule,
166183
start_node_idx: int,
167184
end_node_idx: int,
168185
submodule_hook=None,
169186
submodule_name="extracted_submodule",
170187
group_head_and_tail=True,
171188
):
172189
return convert_to_submodules_graph(
173-
original_gm,
190+
gm,
174191
split_positions=[start_node_idx, end_node_idx],
175192
submodule_hook=submodule_hook,
176193
submodule_name_prefix=submodule_name,
@@ -186,7 +203,7 @@ class NodeProducedOrConsumedCountCtx:
186203

187204

188205
def _get_submodule_inputs_and_outputs(
189-
original_gm: torch.fx.GraphModule,
206+
gm: torch.fx.GraphModule,
190207
start_node_idx: int,
191208
end_node_idx: int,
192209
chain_style=False,
@@ -196,7 +213,7 @@ def _get_submodule_inputs_and_outputs(
196213
defaultdict(int),
197214
defaultdict(int),
198215
)
199-
node_list = list(original_gm.graph.nodes)
216+
node_list = list(gm.graph.nodes)
200217

201218
def get_related_node(node):
202219
for arg in node.args:
@@ -240,7 +257,7 @@ def get_related_node(node):
240257
if count_ctx.node2before_input[node] > 0
241258
if count_ctx.node2body[node] == 0
242259
if count_ctx.node2after_output[node] > 0
243-
][:1]
260+
]
244261
input_nodes_set = set(input_nodes)
245262
input_nodes = [
246263
*input_nodes,
@@ -251,5 +268,4 @@ def get_related_node(node):
251268
*output_nodes,
252269
*[node for node in identity_nodes if node not in output_nodes_set],
253270
]
254-
255271
return input_nodes, output_nodes, identity_nodes

graph_net/torch/extractor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def extract(
139139
dynamic=True,
140140
mut_graph_codes=None,
141141
placeholder_auto_rename=False,
142-
custom_extractor_path: str = None,
143-
custom_extractor_config: str = None,
142+
extractor_config: dict = None,
144143
):
145144
"""
146145
Extract computation graphs from PyTorch nn.Module.
@@ -210,7 +209,11 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
210209
>>>
211210
"""
212211

212+
extractor_config = make_extractor_config(extractor_config)
213+
213214
def get_graph_extractor_maker():
215+
custom_extractor_path = extractor_config["custom_extractor_path"]
216+
custom_extractor_config = extractor_config["custom_extractor_config"]
214217
if custom_extractor_path is None:
215218
return GraphExtractor
216219
import importlib.util as imp
@@ -247,3 +250,18 @@ def decorator_or_wrapper(obj):
247250
)
248251

249252
return decorator_or_wrapper
253+
254+
255+
def make_extractor_config(extractor_config):
256+
kwargs = extractor_config if extractor_config is not None else {}
257+
return make_extractor_config_impl(**kwargs)
258+
259+
260+
def make_extractor_config_impl(
261+
custom_extractor_path: str = None, custom_extractor_config: dict = None
262+
):
263+
config = custom_extractor_config if custom_extractor_config is not None else {}
264+
return {
265+
"custom_extractor_path": custom_extractor_path,
266+
"custom_extractor_config": config,
267+
}

0 commit comments

Comments
 (0)