Skip to content

Commit e32fcb6

Browse files
authored
add multi graph (#268)
1 parent 04f3c15 commit e32fcb6

File tree

2 files changed

+129
-70
lines changed

2 files changed

+129
-70
lines changed

graph_net/torch/extractor.py

Lines changed: 96 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import torch
33
import json
4+
import shutil
45
from typing import Union, Callable
56
from . import utils
67

@@ -80,76 +81,105 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
8081
def wrapper(model: torch.nn.Module):
8182
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
8283

83-
def extractor(gm: torch.fx.GraphModule, sample_inputs):
84-
# 1. Get workspace path
85-
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
86-
if not workspace_path:
87-
raise EnvironmentError(
88-
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
84+
class GraphExtractor:
85+
def __init__(self):
86+
self.subgraph_counter = 0
87+
88+
def move_files(self, source_dir, target_dir):
89+
os.makedirs(target_dir, exist_ok=True)
90+
for item in os.listdir(source_dir):
91+
source_path = os.path.join(source_dir, item)
92+
if os.path.isfile(source_path):
93+
target_path = os.path.join(target_dir, item)
94+
shutil.move(source_path, target_path)
95+
96+
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
97+
# 1. Get workspace path
98+
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
99+
if not workspace_path:
100+
raise EnvironmentError(
101+
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
102+
)
103+
model_path = os.path.join(workspace_path, name)
104+
os.makedirs(model_path, exist_ok=True)
105+
106+
if self.subgraph_counter == 0:
107+
subgraph_path = model_path
108+
else:
109+
if self.subgraph_counter == 1:
110+
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
111+
self.move_files(model_path, subgraph_0_path)
112+
113+
subgraph_path = os.path.join(
114+
model_path, f"subgraph_{self.subgraph_counter}"
115+
)
116+
os.makedirs(subgraph_path, exist_ok=True)
117+
118+
self.subgraph_counter += 1
119+
120+
# 2. Get full params
121+
params = {}
122+
input_idx = 0
123+
unique_id = 0
124+
125+
def try_rename_placeholder(node):
126+
assert node.op == "placeholder"
127+
if not placeholder_auto_rename:
128+
return
129+
nonlocal unique_id
130+
node.target = f"v{unique_id}"
131+
unique_id += 1
132+
node.name = f"v{unique_id}"
133+
unique_id += 1
134+
135+
for node in gm.graph.nodes:
136+
if node.op == "placeholder":
137+
try_rename_placeholder(node)
138+
input = sample_inputs[input_idx]
139+
if isinstance(input, torch.SymInt):
140+
input = torch.tensor(4)
141+
params[node.target] = input
142+
input_idx += 1
143+
assert input_idx == len(sample_inputs)
144+
if mut_graph_codes is not None:
145+
assert isinstance(mut_graph_codes, list)
146+
mut_graph_codes.append(gm.code)
147+
# 3. Generate and save model code
148+
base_code = gm.code
149+
# gm.graph.print_tabular()
150+
write_code = utils.apply_templates(base_code)
151+
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
152+
fp.write(write_code)
153+
154+
# 4. Save metadata
155+
metadata = {
156+
"framework": "torch",
157+
"num_devices_required": 1,
158+
"num_nodes_required": 1,
159+
"dynamic": bool(dynamic),
160+
"model_name": name,
161+
}
162+
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
163+
json.dump(metadata, f, indent=4)
164+
165+
# 5. Save tensor metadata
166+
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
167+
converted = utils.convert_state_and_inputs(params, [])
168+
utils.save_converted_to_text(converted, file_path=subgraph_path)
169+
utils.save_constraints_text(
170+
converted,
171+
file_path=os.path.join(
172+
subgraph_path, "input_tensor_constraints.py"
173+
),
89174
)
90-
model_path = os.path.join(workspace_path, name)
91-
os.makedirs(model_path, exist_ok=True)
92-
93-
# 2. Get full params
94-
params = {}
95-
input_idx = 0
96-
unique_id = 0
97-
98-
def try_rename_placeholder(node):
99-
assert node.op == "placeholder"
100-
if not placeholder_auto_rename:
101-
return
102-
nonlocal unique_id
103-
node.target = f"v{unique_id}"
104-
unique_id += 1
105-
node.name = f"v{unique_id}"
106-
unique_id += 1
107-
108-
for node in gm.graph.nodes:
109-
if node.op == "placeholder":
110-
try_rename_placeholder(node)
111-
input = sample_inputs[input_idx]
112-
if isinstance(input, torch.SymInt):
113-
input = torch.tensor(4)
114-
params[node.target] = input
115-
input_idx += 1
116-
assert input_idx == len(sample_inputs)
117-
if mut_graph_codes is not None:
118-
assert isinstance(mut_graph_codes, list)
119-
mut_graph_codes.append(gm.code)
120-
# 3. Generate and save model code
121-
base_code = gm.code
122-
# gm.graph.print_tabular()
123-
write_code = utils.apply_templates(base_code)
124-
with open(os.path.join(model_path, "model.py"), "w") as fp:
125-
fp.write(write_code)
126-
127-
# 4. Save metadata
128-
metadata = {
129-
"framework": "torch",
130-
"num_devices_required": 1,
131-
"num_nodes_required": 1,
132-
"dynamic": bool(dynamic),
133-
"model_name": name,
134-
}
135-
with open(os.path.join(model_path, "graph_net.json"), "w") as f:
136-
json.dump(metadata, f, indent=4)
137-
138-
# 5. Save tensor metadata
139-
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
140-
converted = utils.convert_state_and_inputs(params, [])
141-
utils.save_converted_to_text(converted, file_path=model_path)
142-
utils.save_constraints_text(
143-
converted,
144-
file_path=os.path.join(model_path, "input_tensor_constraints.py"),
145-
)
146175

147-
print(
148-
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
149-
)
176+
print(
177+
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
178+
)
150179

151-
return gm.forward
180+
return gm.forward
152181

182+
extractor = GraphExtractor()
153183
# return torch.compile(backend=extractor, dynamic=dynamic)
154184
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)
155185

graph_net/torch/validate.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ def temp_workspace():
1515
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = old
1616

1717

18-
def main(args):
19-
model_path = args.model_path
18+
def validate(args, model_path):
2019
with temp_workspace() as tmp_dir_name:
2120
print("Check extractability ...")
2221
cmd = f"{sys.executable} -m graph_net.torch.single_device_runner --model-path {model_path}"
@@ -36,16 +35,46 @@ def main(args):
3635
if args.graph_net_samples_path is None
3736
else args.graph_net_samples_path
3837
)
39-
cmd = f"{sys.executable} -m graph_net.torch.check_redundant_incrementally --model-path {args.model_path} --graph-net-samples-path {graph_net_samples_path}"
38+
cmd = f"{sys.executable} -m graph_net.torch.check_redundant_incrementally --model-path {model_path} --graph-net-samples-path {graph_net_samples_path}"
4039
cmd_ret = os.system(cmd)
41-
rm_cmd = f"{sys.executable} -m graph_net.torch.remove_redundant_incrementally --model-path {args.model_path} --graph-net-samples-path {graph_net_samples_path}"
40+
rm_cmd = f"{sys.executable} -m graph_net.torch.remove_redundant_incrementally --model-path {model_path} --graph-net-samples-path {graph_net_samples_path}"
4241
assert (
4342
cmd_ret == 0
4443
), f"\nPlease use the following command to remove redundant model directories:\n\n{rm_cmd}\n"
4544

4645
print(f"Validation success, {model_path=}")
4746

4847

48+
def get_recursively_model_path(root_dir):
49+
for sub_dir in get_immediate_subdirectory_paths(root_dir):
50+
if is_single_model_dir(sub_dir):
51+
yield sub_dir
52+
else:
53+
yield from get_recursively_model_path(sub_dir)
54+
55+
56+
def get_immediate_subdirectory_paths(parent_dir):
57+
return [
58+
sub_dir
59+
for name in os.listdir(parent_dir)
60+
for sub_dir in [os.path.join(parent_dir, name)]
61+
if os.path.isdir(sub_dir)
62+
]
63+
64+
65+
def is_single_model_dir(model_dir):
66+
return os.path.isfile(f"{model_dir}/graph_net.json")
67+
68+
69+
def main(args):
70+
model_path = args.model_path
71+
if is_single_model_dir(args.model_path):
72+
validate(args, model_path)
73+
else:
74+
for model_path in get_recursively_model_path(args.model_path):
75+
validate(args, model_path)
76+
77+
4978
if __name__ == "__main__":
5079
parser = argparse.ArgumentParser(
5180
description="Validate a computation graph sample. return 0 if success"

0 commit comments

Comments
 (0)