Skip to content

Commit 2513f17

Browse files
committed
Update
2 parents a597988 + 35d840f commit 2513f17

File tree

3,864 files changed

+515460
-40211
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

3,864 files changed

+515460
-40211
lines changed

graph_net/paddle/check_redundant_incrementally.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,5 @@ def main(args):
110110
help="Path to GraphNet samples",
111111
)
112112
args = parser.parse_args()
113-
print(args)
113+
print(f"[Check Redundancy Arguments] {args}")
114114
main(args=args)

graph_net/paddle/validate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import graph_net
1313
import os
1414
import ast
15+
import astor
1516
import paddle
1617

1718

@@ -29,7 +30,7 @@ def _get_sha_hash(content):
2930
return m.hexdigest()
3031

3132

32-
def _extract_forward_source(model_path):
33+
def _extract_forward_source(model_path, class_name):
3334
source = None
3435
with open(f"{model_path}/model.py", "r") as f:
3536
source = f.read()
@@ -38,18 +39,18 @@ def _extract_forward_source(model_path):
3839
forward_code = None
3940

4041
for node in tree.body:
41-
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
42+
if isinstance(node, ast.ClassDef) and node.name == class_name:
4243
for fn in node.body:
4344
if isinstance(fn, ast.FunctionDef) and fn.name == "forward":
44-
return ast.unparse(fn)
45+
return astor.to_source(fn)
4546
return None
4647

4748

4849
def check_graph_hash(args):
4950
model_path = args.model_path
5051
file_path = f"{model_path}/graph_hash.txt"
5152
if args.dump_graph_hash_key:
52-
model_str = _extract_forward_source(model_path)
53+
model_str = _extract_forward_source(model_path, class_name="GraphModule")
5354
assert model_str is not None, f"model_str of {args.model_path} is None."
5455
new_hash_text = _get_sha_hash(model_str)
5556

@@ -140,4 +141,5 @@ def main(args):
140141
help="Path to GraphNet samples folder. e.g '../../samples'",
141142
)
142143
args = parser.parse_args()
144+
print(f"[Validate Arguments] {args}")
143145
main(args=args)

graph_net/torch/extractor.py

Lines changed: 118 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,121 @@
77

88
torch._dynamo.config.capture_scalar_outputs = True
99
torch._dynamo.config.capture_dynamic_output_shape_ops = True
10+
torch._dynamo.config.capture_sparse_compute = True
11+
torch._dynamo.config.raise_on_ctx_manager_usage = False
12+
torch._dynamo.config.allow_rnn = True
13+
14+
15+
class GraphExtractor:
16+
def __init__(
17+
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
18+
):
19+
self.subgraph_counter = 0
20+
self.name = name
21+
self.dynamic = dynamic
22+
self.mut_graph_codes = mut_graph_codes
23+
self.placeholder_auto_rename = placeholder_auto_rename
24+
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
25+
if not self.workspace_path:
26+
raise EnvironmentError(
27+
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
28+
)
29+
30+
def move_files(self, source_dir, target_dir):
31+
os.makedirs(target_dir, exist_ok=True)
32+
for item in os.listdir(source_dir):
33+
source_path = os.path.join(source_dir, item)
34+
if os.path.isfile(source_path):
35+
target_path = os.path.join(target_dir, item)
36+
shutil.move(source_path, target_path)
37+
38+
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
39+
# 1. Get model path
40+
model_path = os.path.join(self.workspace_path, self.name)
41+
os.makedirs(model_path, exist_ok=True)
42+
43+
if self.subgraph_counter == 0:
44+
subgraph_path = model_path
45+
else:
46+
if self.subgraph_counter == 1:
47+
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
48+
self.move_files(model_path, subgraph_0_path)
49+
50+
subgraph_path = os.path.join(
51+
model_path, f"subgraph_{self.subgraph_counter}"
52+
)
53+
os.makedirs(subgraph_path, exist_ok=True)
54+
55+
self.subgraph_counter += 1
56+
57+
# 2. Get full params
58+
params = {}
59+
input_idx = 0
60+
unique_id = 0
61+
62+
def try_rename_placeholder(node):
63+
assert node.op == "placeholder"
64+
if not self.placeholder_auto_rename:
65+
return
66+
nonlocal unique_id
67+
node.target = f"v{unique_id}"
68+
unique_id += 1
69+
node.name = f"v{unique_id}"
70+
unique_id += 1
71+
72+
for node in gm.graph.nodes:
73+
if node.op == "placeholder":
74+
try_rename_placeholder(node)
75+
input = sample_inputs[input_idx]
76+
if isinstance(input, torch.SymInt):
77+
input = torch.tensor(4)
78+
params[node.target] = input
79+
input_idx += 1
80+
81+
if node.op == "call_function" and hasattr(node.target, "__name__"):
82+
if node.target.__name__ in [
83+
"_enter_autocast",
84+
"_exit_autocast",
85+
]:
86+
node.replace_all_uses_with(node.args[0])
87+
gm.graph.erase_node(node)
88+
89+
assert input_idx == len(sample_inputs)
90+
if self.mut_graph_codes is not None:
91+
assert isinstance(self.mut_graph_codes, list)
92+
self.mut_graph_codes.append(gm.code)
93+
# 3. Generate and save model code
94+
base_code = gm.code
95+
# gm.graph.print_tabular()
96+
write_code = utils.apply_templates(base_code)
97+
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
98+
fp.write(write_code)
99+
100+
# 4. Save metadata
101+
metadata = {
102+
"framework": "torch",
103+
"num_devices_required": 1,
104+
"num_nodes_required": 1,
105+
"dynamic": bool(self.dynamic),
106+
"model_name": self.name,
107+
}
108+
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
109+
json.dump(metadata, f, indent=4)
110+
111+
# 5. Save tensor metadata
112+
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
113+
converted = utils.convert_state_and_inputs(params, [])
114+
utils.save_converted_to_text(converted, file_path=subgraph_path)
115+
utils.save_constraints_text(
116+
converted,
117+
file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"),
118+
)
119+
120+
print(
121+
f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
122+
)
123+
124+
return gm.forward
10125

11126

12127
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
@@ -80,109 +195,11 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
80195

81196
def wrapper(model: torch.nn.Module):
82197
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
83-
84-
class GraphExtractor:
85-
def __init__(self):
86-
self.subgraph_counter = 0
87-
88-
def move_files(self, source_dir, target_dir):
89-
os.makedirs(target_dir, exist_ok=True)
90-
for item in os.listdir(source_dir):
91-
source_path = os.path.join(source_dir, item)
92-
if os.path.isfile(source_path):
93-
target_path = os.path.join(target_dir, item)
94-
shutil.move(source_path, target_path)
95-
96-
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
97-
# 1. Get workspace path
98-
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
99-
if not workspace_path:
100-
raise EnvironmentError(
101-
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
102-
)
103-
model_path = os.path.join(workspace_path, name)
104-
os.makedirs(model_path, exist_ok=True)
105-
106-
if self.subgraph_counter == 0:
107-
subgraph_path = model_path
108-
else:
109-
if self.subgraph_counter == 1:
110-
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
111-
self.move_files(model_path, subgraph_0_path)
112-
113-
subgraph_path = os.path.join(
114-
model_path, f"subgraph_{self.subgraph_counter}"
115-
)
116-
os.makedirs(subgraph_path, exist_ok=True)
117-
118-
self.subgraph_counter += 1
119-
120-
# 2. Get full params
121-
params = {}
122-
input_idx = 0
123-
unique_id = 0
124-
125-
def try_rename_placeholder(node):
126-
assert node.op == "placeholder"
127-
if not placeholder_auto_rename:
128-
return
129-
nonlocal unique_id
130-
node.target = f"v{unique_id}"
131-
unique_id += 1
132-
node.name = f"v{unique_id}"
133-
unique_id += 1
134-
135-
for node in gm.graph.nodes:
136-
if node.op == "placeholder":
137-
try_rename_placeholder(node)
138-
input = sample_inputs[input_idx]
139-
if isinstance(input, torch.SymInt):
140-
input = torch.tensor(4)
141-
params[node.target] = input
142-
input_idx += 1
143-
assert input_idx == len(sample_inputs)
144-
if mut_graph_codes is not None:
145-
assert isinstance(mut_graph_codes, list)
146-
mut_graph_codes.append(gm.code)
147-
# 3. Generate and save model code
148-
base_code = gm.code
149-
# gm.graph.print_tabular()
150-
write_code = utils.apply_templates(base_code)
151-
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
152-
fp.write(write_code)
153-
154-
# 4. Save metadata
155-
metadata = {
156-
"framework": "torch",
157-
"num_devices_required": 1,
158-
"num_nodes_required": 1,
159-
"dynamic": bool(dynamic),
160-
"model_name": name,
161-
}
162-
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
163-
json.dump(metadata, f, indent=4)
164-
165-
# 5. Save tensor metadata
166-
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
167-
converted = utils.convert_state_and_inputs(params, [])
168-
utils.save_converted_to_text(converted, file_path=subgraph_path)
169-
utils.save_constraints_text(
170-
converted,
171-
file_path=os.path.join(
172-
subgraph_path, "input_tensor_constraints.py"
173-
),
174-
)
175-
176-
print(
177-
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
178-
)
179-
180-
return gm.forward
181-
182-
extractor = GraphExtractor()
198+
extractor = GraphExtractor(
199+
name, dynamic, mut_graph_codes, placeholder_auto_rename
200+
)
183201
# return torch.compile(backend=extractor, dynamic=dynamic)
184202
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)
185-
186203
return compiled_model
187204

188205
def decorator(module_class):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
f2b5a332b1b19703e7ccfb450de96c9c12244144c7b9d305d20587f772fb6672
1+
517608d4d2699e09c6171648da38a4f924556cf25abd97875599acfdda5807e4
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
02fa10efca360c8ba7818c367cdeb9979e2af8c72cf489913396a1f241bbad07
1+
2a46a550da3ca0bd5aa6157a26aff525a3bc69ff8f67fe35b4424303c12e2820
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
482fd9e9f201b45c2ce0b22b3037878aa3d139cc203fb35c781fd470140ec962
1+
e13d4b5e10e7aadcf05e891979bb73813fb3c4c1407b2688fb6ac8f849cdcee0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e013c0a1d9173f7db5ed91398ad65fa43154e3bc8ce2e15c2d5a6637ddec61d8
1+
2511edee7164b3327d5efcce7879c5a19a19aec8a86e74e233ae83db0807ed46
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
55b1fcce22aee360f71154396a1f528446cae70ebd991927c0abf6c06016d201
1+
c1e76a465ae2ac6d1cb568acb5f17db4bca92d6d0239061cd319f2d591ba82b9
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2ea1f7f9bb52a294ff9fb5fd9876b9e9ed8b4af2fdb6cce93985eedfe50c7a94
1+
b23ce390b79f214cdbd74ea52c32d6dc141d93b179a7bf75f94bb12e8bd91561
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
40fde6163a995d989050cf8b78b44132b4b62ce218f604dad67aff1f4f5a56f0
1+
94a3256e834ecd7e836da57b44da751d75ef9e095b04ac00abc37a5e18a01390

0 commit comments

Comments
 (0)