Skip to content

Commit 479e342

Browse files
committed
fix bug
1 parent 51d0d7e commit 479e342

File tree

7 files changed

+1234
-1518
lines changed

7 files changed

+1234
-1518
lines changed

graph_net/torch/extractor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def try_rename_placeholder(node):
5959
mut_graph_codes.append(gm.code)
6060
# 3. Generate and save model code
6161
base_code = gm.code
62+
base_code = utils.remove_amp_pass(base_code)
63+
6264
# gm.graph.print_tabular()
6365
write_code = utils.apply_templates(base_code)
6466
with open(os.path.join(model_path, "model.py"), "w") as fp:
@@ -90,6 +92,7 @@ def try_rename_placeholder(node):
9092
return gm.forward
9193

9294
# return torch.compile(backend=extractor, dynamic=dynamic)
95+
print(model)
9396
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)
9497

9598
return compiled_model

graph_net/torch/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,9 @@ def replay_tensor(info):
268268
if dtype is torch.bool:
269269
return (torch.randn(size=shape) > 0.5).to(dtype).to(device)
270270
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
271+
272+
273+
def remove_amp_pass(code):
274+
lines = code.split("\n")
275+
filtered_lines = [line for line in lines if "amp" not in line]
276+
return "\n".join(filtered_lines)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4bd91a3dcc08517dc4ac39602b20245aba8452907e56b5111bccfe53a5c1c0bf
1+
ad1c8cd8d22e7d12ed4a58199fd907bf24abb8b89e644651597e95a8d1102d86
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"framework": "torch",
33
"num_devices_required": 1,
4-
"num_nodes_required": 1
4+
"num_nodes_required": 1,
5+
"dynamic": false
56
}

0 commit comments

Comments
 (0)