Skip to content

Commit 7ace5ce

Browse files
authored
[Feature Enhancement] add submodule folding func (#328)
* feat(extractor): add submodule folding func * refactor: decompose_utils * rename: decompose_util * update * style: black files
1 parent fad5ec1 commit 7ace5ce

File tree

3 files changed

+293
-6
lines changed

3 files changed

+293
-6
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import os
2+
import torch
3+
import unittest
4+
5+
from torch.fx import symbolic_trace
6+
from graph_net.torch.extractor import extract
7+
from graph_net.torch.decompose_util import fold_range_to_submodule
8+
9+
10+
# Simple module for demonstration
11+
class MyModule(torch.nn.Module):
12+
def __init__(self) -> None:
13+
super().__init__()
14+
15+
def forward(self, x):
16+
y = x + 1
17+
z = y * 2
18+
return z.clamp(min=0.0, max=1.0)
19+
20+
21+
class WrapperModule(torch.nn.Module):
22+
def __init__(self, submodule):
23+
super().__init__()
24+
self.submodule = submodule
25+
26+
def forward(self, *args):
27+
print("Args:")
28+
print(args)
29+
return self.submodule(*args)
30+
31+
32+
def submodule_hook(submodule: torch.fx.GraphModule):
33+
print(f"{'-'*8} [submodule] {'-'*8}\n")
34+
print(submodule.graph)
35+
"""
36+
graph():
37+
%add : [num_users=1] = placeholder[target=add]
38+
%mul : [num_users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
39+
%clamp : [num_users=1] = call_method[target=clamp](args = (%mul,), kwargs = {min: 0.0, max: 1.0})
40+
return (clamp,)
41+
42+
"""
43+
print(submodule.code)
44+
"""
45+
def forward(self, add):
46+
mul = add * 2; add = None
47+
clamp = mul.clamp(min = 0.0, max = 1.0); mul = None
48+
return (clamp,)
49+
"""
50+
51+
return WrapperModule(submodule)
52+
53+
54+
class TestExtractorSubmodule(unittest.TestCase):
55+
"""Test extraction of submodules from traced GraphModule."""
56+
57+
def test_sample(self):
58+
module = MyModule()
59+
60+
# Symbolic tracing frontend - captures the semantics of the module
61+
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
62+
63+
# High-level intermediate representation (IR) - Graph representation
64+
print(symbolic_traced.graph)
65+
"""
66+
graph():
67+
%x : [num_users=1] = placeholder[target=x]
68+
%add : [num_users=1] = call_function[target=operator.add](args = (%x, 1), kwargs = {})
69+
%mul : [num_users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
70+
%clamp : [num_users=1] = call_method[target=clamp](args = (%mul,), kwargs = {min: 0.0, max: 1.0})
71+
return clamp
72+
"""
73+
74+
# Code generation - valid Python code
75+
print(symbolic_traced.code)
76+
"""
77+
def forward(self, x):
78+
add = x + 1; x = None
79+
mul = add * 2; add = None
80+
clamp = mul.clamp(min = 0.0, max = 1.0); mul = None
81+
return clamp
82+
"""
83+
84+
inp = torch.tensor([1.0, 2.0, 3.0, 4.0])
85+
source_output = module(inp)
86+
traced_output = symbolic_traced(inp)
87+
88+
folded = fold_range_to_submodule(
89+
symbolic_traced,
90+
start_node_idx=2,
91+
end_node_idx=4,
92+
submodule_hook=submodule_hook,
93+
)
94+
folded_output = folded(inp)
95+
96+
print(f"{'-'*8} [folded] {'-'*8}\n")
97+
print(folded.graph)
98+
"""
99+
graph():
100+
%x : [num_users=1] = placeholder[target=x]
101+
%add : [num_users=1] = call_function[target=operator.add](args = (%x, 1), kwargs = {})
102+
%extraced_submodule : [num_users=1] = call_module[target=extraced_submodule](args = (%add,), kwargs = {})
103+
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%extraced_submodule, 0), kwargs = {})
104+
return getitem
105+
106+
"""
107+
print(folded.code)
108+
"""
109+
def forward(self, x):
110+
add = x + 1; x = None
111+
extraced_submodule = self.extraced_submodule(add); add = None
112+
getitem = extraced_submodule[0]; extraced_submodule = None
113+
return getitem
114+
"""
115+
116+
# Save to workspace, assumed workspace is ./tmp/
117+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = "./tmp/"
118+
folded = extract("demo_test", False)(folded)
119+
wrapper_output = folded(inp)
120+
121+
self.assertTrue(torch.allclose(source_output, traced_output))
122+
self.assertTrue(torch.allclose(source_output, folded_output))
123+
self.assertTrue(torch.allclose(source_output, wrapper_output))
124+
125+
126+
if __name__ == "__main__":
127+
unittest.main()

graph_net/torch/decompose_util.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import torch
2+
import copy
3+
import operator
4+
from collections import defaultdict
5+
from dataclasses import dataclass
6+
7+
8+
def fold_range_to_submodule(
9+
original_gm: torch.fx.GraphModule,
10+
start_node_idx: int,
11+
end_node_idx: int,
12+
submodule_hook=None,
13+
submodule_name="extraced_submodule",
14+
):
15+
original_gm = copy.deepcopy(original_gm)
16+
submodule_body_nodes = list(original_gm.graph.nodes)[start_node_idx:end_node_idx]
17+
18+
def get_body_nodes():
19+
return submodule_body_nodes
20+
21+
assert len(get_body_nodes()) > 0
22+
23+
for idx, original_node in enumerate(get_body_nodes()):
24+
assert original_node.op not in {
25+
"placeholder",
26+
"output",
27+
}, f"{idx=}, {original_node.op=}"
28+
29+
submodule_input_nodes, submodule_output_nodes = _get_submodule_inputs_and_outputs(
30+
original_gm=original_gm,
31+
start_node_idx=start_node_idx,
32+
end_node_idx=end_node_idx,
33+
)
34+
35+
def get_input_nodes():
36+
return submodule_input_nodes
37+
38+
def get_output_nodes():
39+
return submodule_output_nodes
40+
41+
def get_name2sub_submodule():
42+
used_module_names = set()
43+
for node in get_body_nodes():
44+
if node.op == "call_module":
45+
used_module_names.add(node.target)
46+
return {
47+
name: module
48+
for name, module in original_gm.named_modules()
49+
if name in used_module_names
50+
}
51+
52+
new_graph = torch.fx.Graph()
53+
# Create a mapping for nodes from original graph to new graph
54+
node_map = {}
55+
56+
# Add placeholder nodes for inputs
57+
for original_node in get_input_nodes():
58+
new_node = new_graph.placeholder(original_node.name)
59+
node_map[original_node] = new_node
60+
61+
# Copy body nodes
62+
for original_node in get_body_nodes():
63+
print(original_node)
64+
new_node = new_graph.node_copy(original_node, lambda x: node_map[x])
65+
node_map[original_node] = new_node
66+
67+
# Add output nodes
68+
output_args = []
69+
for original_node in get_output_nodes():
70+
output_args.append(node_map[original_node])
71+
new_graph.output(tuple(output_args))
72+
73+
# Create the new GraphModule
74+
# This assumes no submodules are being extracted, or they are handled separately
75+
new_sub_module = torch.fx.GraphModule(get_name2sub_submodule(), new_graph)
76+
if submodule_hook is not None:
77+
new_sub_module = submodule_hook(new_sub_module)
78+
# Replace with submodule node
79+
original_gm.add_submodule(submodule_name, new_sub_module)
80+
with original_gm.graph.inserting_after(get_body_nodes()[-1]):
81+
submodule_node = original_gm.graph.call_module(
82+
submodule_name, tuple(get_input_nodes())
83+
)
84+
prev_node = submodule_node
85+
for idx, original_output in enumerate(get_output_nodes()):
86+
with original_gm.graph.inserting_after(prev_node):
87+
new_output_node = original_gm.graph.call_function(
88+
operator.getitem, (submodule_node, idx)
89+
)
90+
node_map[original_output] = new_output_node
91+
prev_node = new_output_node
92+
93+
# Replace all use of outputs
94+
for original_output in get_output_nodes():
95+
original_output.replace_all_uses_with(node_map[original_output])
96+
97+
# Erase old nodes
98+
for node in reversed(get_body_nodes()):
99+
original_gm.graph.erase_node(node)
100+
101+
original_gm.recompile()
102+
103+
return original_gm
104+
105+
106+
@dataclass
107+
class NodeProducedOrConsumedCountCtx:
108+
node2before_input: defaultdict(int)
109+
node2body: defaultdict(int)
110+
node2after_output: defaultdict(int)
111+
112+
113+
def _get_submodule_inputs_and_outputs(
114+
original_gm: torch.fx.GraphModule,
115+
start_node_idx: int,
116+
end_node_idx: int,
117+
):
118+
count_ctx = NodeProducedOrConsumedCountCtx(
119+
defaultdict(int),
120+
defaultdict(int),
121+
defaultdict(int),
122+
)
123+
node_list = list(original_gm.graph.nodes)
124+
125+
def get_related_node(node):
126+
yield from node.args
127+
yield node
128+
129+
for node in node_list[0:start_node_idx]:
130+
for related_node in get_related_node(node):
131+
count_ctx.node2before_input[related_node] += 1
132+
133+
for node in node_list[start_node_idx:end_node_idx]:
134+
for related_node in get_related_node(node):
135+
count_ctx.node2body[related_node] += 1
136+
137+
for node in node_list[end_node_idx:]:
138+
for related_node in get_related_node(node):
139+
count_ctx.node2after_output[related_node] += 1
140+
141+
input_nodes = [
142+
node
143+
for node in node_list
144+
if count_ctx.node2before_input[node] > 0
145+
if count_ctx.node2body[node] > 0
146+
]
147+
148+
output_nodes = [
149+
node
150+
for node in node_list
151+
if not (count_ctx.node2before_input[node] > 0)
152+
if count_ctx.node2body[node] > 0
153+
if count_ctx.node2after_output[node] > 0
154+
]
155+
156+
return input_nodes, output_nodes

graph_net/torch/test_compiler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,11 @@ def test_single_model(args):
266266
expected_out = (expected_out,)
267267

268268
eager_types = [
269-
str(x.dtype).replace("torch.", "")
270-
if isinstance(x, torch.Tensor)
271-
else type(x).__name__
269+
(
270+
str(x.dtype).replace("torch.", "")
271+
if isinstance(x, torch.Tensor)
272+
else type(x).__name__
273+
)
272274
for x in expected_out
273275
]
274276
print(
@@ -308,9 +310,11 @@ def test_single_model(args):
308310
compiled_out = tuple(item.to("cpu").to("cuda") for item in compiled_out)
309311

310312
compiled_types = [
311-
str(x.dtype).replace("torch.", "")
312-
if isinstance(x, torch.Tensor)
313-
else type(x).__name__
313+
(
314+
str(x.dtype).replace("torch.", "")
315+
if isinstance(x, torch.Tensor)
316+
else type(x).__name__
317+
)
314318
for x in compiled_out
315319
]
316320
print(

0 commit comments

Comments
 (0)