Skip to content

Commit 28c4074

Browse files
committed
fix regression bugs in tools/apply_dim_gen_passes.sh, tools/get_in_tensor_symbolic_shapes.sh and tools/init_input_tensor_constraints.sh
1 parent 54244cd commit 28c4074

File tree

5 files changed

+47
-27
lines changed

5 files changed

+47
-27
lines changed

graph_net/tools/apply_dim_gen_passes.sh

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

6-
# input model path
7-
# model_runnable_predicator=ShapePropagatablePredicator
8-
model_runnable_predicator=ModelRunnablePredicator
9-
config_json_str=$(cat <<EOF
6+
python3 -m graph_net.model_path_handler \
7+
--model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt \
8+
--handler-config=$(base64 -w 0 <<EOF
109
{
1110
"handler_path": "$GRAPH_NET_ROOT/dimension_generalizer.py",
1211
"handler_class_name": "ApplyDimGenPasses",
@@ -22,6 +21,3 @@ config_json_str=$(cat <<EOF
2221
}
2322
EOF
2423
)
25-
CONFIG=$(echo $config_json_str | base64 -w 0)
26-
27-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/tools/get_in_tensor_symbolic_shapes.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ os.path.dirname(graph_net.__file__))")
66
# input model path
77
# model_runnable_predicator=ShapePropagatablePredicator
88
model_runnable_predicator=ModelRunnablePredicator
9-
config_json_str=$(cat <<EOF
9+
10+
python3 -m graph_net.model_path_handler \
11+
--model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt \
12+
--handler-config=$(base64 -w 0 <<EOF
1013
{
1114
"handler_path": "$GRAPH_NET_ROOT/tools/_get_in_tensor_symbolic_shapes.py",
1215
"handler_class_name": "GetInTensorSymbolicShapes",
@@ -17,6 +20,3 @@ config_json_str=$(cat <<EOF
1720
}
1821
EOF
1922
)
20-
CONFIG=$(echo $config_json_str | base64 -w 0)
21-
22-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/tools/batch_init_input_tensor_constraints.sh renamed to graph_net/tools/init_input_tensor_constraints.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ os.path.dirname(graph_net.__file__))")
66
# input model path
77
# model_runnable_predicator=ShapePropagatablePredicator
88
model_runnable_predicator=ModelRunnablePredicator
9-
config_json_str=$(cat <<EOF
9+
10+
python3 -m graph_net.model_path_handler \
11+
--use-subprocess \
12+
--model-path-list $GRAPH_NET_ROOT/config/small100_torch_samples_list.txt \
13+
--handler-config=$(base64 -w 0 <<EOF
1014
{
1115
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1216
"handler_class_name": "UpdateInputTensorConstraints",
@@ -40,6 +44,3 @@ config_json_str=$(cat <<EOF
4044
}
4145
EOF
4246
)
43-
CONFIG=$(echo $config_json_str | base64 -w 0)
44-
45-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG --use-subprocess

graph_net/torch/fx_graph_parse_util.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,14 @@ def get_input_names_from_placeholder():
157157
names_from_placeholder=get_input_names_from_placeholder(),
158158
)
159159

160-
for node in traced_module.graph.nodes:
161-
if node.op != "placeholder":
162-
continue
163-
node.target = _rename_placeholder(node.target, pattern2replacement)
164-
node.name = node.target
160+
def handle_placeholder_name(pattern2replacement):
161+
for node in traced_module.graph.nodes:
162+
if node.op != "placeholder":
163+
continue
164+
node.target = _rename_placeholder(node.target, pattern2replacement)
165+
node.name = node.target
166+
167+
handle_placeholder_name(pattern2replacement)
165168

166169
def get_diff_input_names():
167170
placeholder_names = set(get_input_names_from_placeholder())
@@ -226,14 +229,35 @@ def handle_underscore_suffix_difference():
226229
for node in traced_module.graph.nodes:
227230
if not (node.op == "placeholder"):
228231
continue
229-
if node.target not in names:
232+
if node.name not in names:
230233
continue
231234
node.target = node.target[:-1]
232235
node.name = node.name[:-1]
233236
traced_module.recompile()
234237

235238
handle_underscore_suffix_difference()
236239

240+
def handle_prefix_difference():
241+
zip_filter_names = get_zip_filter_names()
242+
if not (len(zip_filter_names) > 0):
243+
return
244+
names = set(
245+
name_in_placeholder
246+
for _0, name_in_signature, name_in_placeholder in zip_filter_names
247+
if (f"l_{name_in_signature[2:]}" == name_in_placeholder)
248+
if (name_in_signature == f"L_{name_in_placeholder[2:]}")
249+
)
250+
for node in traced_module.graph.nodes:
251+
if not (node.op == "placeholder"):
252+
continue
253+
if node.name not in names:
254+
continue
255+
node.target = f"L_{node.target[2:]}"
256+
node.name = f"L_{node.name[2:]}"
257+
traced_module.recompile()
258+
259+
handle_prefix_difference()
260+
237261
if len(get_zip_filter_names()) > 0 and set(get_input_names_from_signature()) == set(
238262
get_input_names_from_placeholder()
239263
):
@@ -243,15 +267,14 @@ def handle_underscore_suffix_difference():
243267

244268
zip_filter_names = get_zip_filter_names()
245269

246-
def zip_filter_names_error_str():
270+
def get_error_model_path():
247271
for triple in zip_filter_names:
248272
print(triple)
249-
error_model_path = module.__graph_net_file_path__
250-
return f"{error_model_path=}"
273+
return module.__graph_net_file_path__
251274

252275
# from pathlib import Path
253276
# Path("/tmp/a.py").write_text(traced_module.code)
254-
assert len(zip_filter_names) == 0, f"{zip_filter_names_error_str()=}"
277+
assert len(zip_filter_names) == 0, f"{get_error_model_path()=}"
255278
return traced_module
256279

257280

graph_net/torch/static_to_dynamic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import traceback
22
import logging
33
import torch
4-
from graph_net.torch.utils import get_dummy_named_tensors
4+
from graph_net.torch.utils import get_named_tensors
55
from torch.fx.passes.shape_prop import ShapeProp
66
from graph_net.torch.utils import apply_templates
77
from pathlib import Path
@@ -24,7 +24,7 @@ def __call__(self, module, dim_axes_pairs):
2424
return StaticToDynamicModulePass(self.config, module, dim_axes_pairs)
2525

2626
def create_inputs_by_metas(self, module, tensor_meta_attrs_list):
27-
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
27+
named_tensors = get_named_tensors(tensor_meta_attrs_list, use_dummy_inputs=True)
2828
name2tensor = {k: v for k, v in named_tensors}
2929
return tuple(
3030
name2tensor[name] for name in inspect.signature(module.forward).parameters

0 commit comments

Comments
 (0)