Skip to content

Commit 8d2874c

Browse files
committed
refactor parse_sole_graph_module
1 parent bc5a23d commit 8d2874c

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

graph_net/test/batch_init_input_tensor_constraints_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1212
"handler_class_name": "UpdateInputTensorConstraints",
1313
"handler_config": {
14-
"resume": false,
14+
"resume": true,
1515
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1616
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1717
"data_input_predicator_class_name": "NaiveDataInputPredicator",
@@ -26,4 +26,4 @@ EOF
2626
)
2727
CONFIG=$(echo $config_json_str | base64 -w 0)
2828

29-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG
29+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/test/shape_prop_batch_init_input_tensor_constraints_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1212
"handler_class_name": "UpdateInputTensorConstraints",
1313
"handler_config": {
14-
"resume": false,
14+
"resume": true,
1515
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1616
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1717
"data_input_predicator_class_name": "NaiveDataInputPredicator",
@@ -26,4 +26,4 @@ EOF
2626
)
2727
CONFIG=$(echo $config_json_str | base64 -w 0)
2828

29-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG
29+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/torch/fx_graph_parse_util.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import inspect
23

34

45
def parse_sole_graph_module(module, inputs):
@@ -14,12 +15,46 @@ def my_backend(gm, sample_inputs):
1415

1516
torch.compile(module, backend=my_backend)(*inputs)
1617
assert traced_module is not None
17-
assert all(id(a) == id(b) for a, b in zip(inputs, traced_sample_inputs))
1818
for node in traced_module.graph.nodes:
1919
if node.op != "placeholder":
2020
continue
2121
assert node.target[:2] == "L_" or node.target[:2] == "l_", f"{node.target=}"
2222
node.target = node.target[2:]
23+
if node.target[0] == "l":
24+
node.target = "L" + node.target[1:]
2325
assert node.name[:2] == "L_" or node.name[:2] == "l_", f"{node.name=}"
2426
node.name = node.name[2:]
27+
if node.name[0] == "l":
28+
node.name = "L" + node.name[1:]
29+
30+
def get_input_names_from_signature():
31+
return inspect.signature(module.forward).parameters
32+
33+
def get_input_names_from_placeholder():
34+
return [
35+
node.name for node in traced_module.graph.nodes if node.op == "placeholder"
36+
]
37+
38+
def get_diff_input_names():
39+
placeholder_names = set(get_input_names_from_placeholder())
40+
return [
41+
(i, name)
42+
for i, name in enumerate(get_input_names_from_signature())
43+
if name not in placeholder_names
44+
]
45+
46+
if len(inputs) == len(traced_sample_inputs) + 1:
47+
diff_input_names = get_diff_input_names()
48+
assert len(diff_input_names) == 1, f"{diff_input_names=}"
49+
pos, name = diff_input_names[0]
50+
for i, node in enumerate(traced_module.graph.nodes):
51+
if i < pos:
52+
assert node.op == "placeholder"
53+
elif i == pos:
54+
with traced_module.graph.inserting_before(node):
55+
traced_module.graph.placeholder(name)
56+
else:
57+
break
58+
traced_module.recompile()
59+
assert len(get_diff_input_names()) == 0, f"{get_diff_input_names()=}"
2560
return traced_module

0 commit comments

Comments
 (0)