Skip to content

Commit 53b0e4f

Browse files
committed
Merge branch 'develop' into binary_decomposer_paddle
2 parents 7f7366a + a762155 commit 53b0e4f

15 files changed

+1197
-57
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
samples/timm/davit_giant
12
# samples/timm/resnetaa50d.d_in12k
23
# samples/transformers-auto-model/opus-mt-en-gmw
3-
samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
4+
# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/config/todo_torch_samples_list.txt

Lines changed: 577 additions & 0 deletions
Large diffs are not rendered by default.

graph_net/constraint_util.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
23
from contextlib import AbstractContextManager
34
from graph_net.imp_util import load_module
@@ -88,14 +89,21 @@ def __call__(self, model_path):
8889
return
8990

9091
tensor_metas = self._get_tensor_metas(model_path)
92+
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
93+
logging.warning(f"before create_inputs_by_metas")
94+
inputs = self.get_dimension_generalizer().create_inputs_by_metas(
95+
module=self.get_model(model_path),
96+
tensor_meta_attrs_list=tensor_meta_attrs_list,
97+
)
98+
logging.warning(f"after create_inputs_by_metas")
9199
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
92100

93101
def data_input_predicator(input_var_name):
94102
return self.data_input_predicator(model_path, input_var_name)
95103

96104
def get_tmp_model_path_ctx_mgr(dim_axes_pairs):
97105
return self._try_dimension_generalization(
98-
dim_axes_pairs, model_path, tensor_metas
106+
dim_axes_pairs, model_path, inputs
99107
)
100108

101109
def get_predicator_is_dyn_dim_cstr_feasible(tmp_model_path):
@@ -128,28 +136,44 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
128136
)
129137
sys.exit(0)
130138

131-
@contextmanager
132-
def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas):
133-
if self.config["dimension_generalizer_filepath"] is None:
134-
yield model_path, ()
135-
return
136-
py_module = load_module(os.path.join(model_path, "model.py"))
137-
GraphModule = getattr(py_module, "GraphModule")
138-
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
139-
model = GraphModule()
139+
def get_dimension_generalizer(self):
140+
if hasattr(self, "_dim_generalizer"):
141+
return self._dim_generalizer
142+
assert self.config["dimension_generalizer_filepath"] is not None
140143
decorator_cls = getattr(
141144
load_module(self.config["dimension_generalizer_filepath"]),
142145
self.config["dimension_generalizer_class_name"],
143146
)
144-
dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
147+
self._dim_generalizer = decorator_cls(
148+
self.config["dimension_generalizer_config"]
149+
)
150+
return self._dim_generalizer
151+
152+
def get_model(self, model_path):
153+
py_module = load_module(os.path.join(model_path, "model.py"))
154+
GraphModule = getattr(py_module, "GraphModule")
155+
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
156+
return GraphModule()
157+
158+
@contextmanager
159+
def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
160+
logging.warning(f"enter _try_dimension_generalization")
161+
if self.config["dimension_generalizer_filepath"] is None:
162+
yield model_path, ()
163+
return
164+
model = self.get_model(model_path)
165+
dim_generalizer = self.get_dimension_generalizer()
145166
dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
146-
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
147-
inputs = dim_gen_pass.create_inputs_by_metas(tensor_meta_attrs_list)
148-
if not dim_gen_pass.need_rewrite(inputs):
167+
logging.warning(f"before need_rewrite")
168+
need_rewrite = dim_gen_pass.need_rewrite(inputs)
169+
logging.warning(f"after need_rewrite")
170+
if not need_rewrite:
149171
yield model_path, ()
150172
return
151173

174+
logging.warning(f"before rewrite")
152175
graph_module = dim_gen_pass.rewrite(inputs)
176+
logging.warning(f"after rewrite")
153177
with tempfile.TemporaryDirectory() as tmp_dir:
154178
shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
155179
dim_gen_pass.save_graph_module(graph_module, tmp_dir)
@@ -300,6 +324,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
300324
)
301325

302326
for i, picked_dim in enumerate(unqiue_dims):
327+
logging.warning(f"{i=} {picked_dim=}")
303328
cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
304329

305330
def filter_fn(input_name, input_idx, axis, dim):
@@ -319,11 +344,17 @@ def filter_fn(input_name, input_idx, axis, dim):
319344
(dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
320345
)
321346
ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347+
logging.warning(f"before dyn_dim_cstr_feasibility_ctx_mgr")
322348
with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
349+
logging.warning(f"enter dyn_dim_cstr_feasibility_ctx_mgr")
323350
tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
324351
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
325-
if not dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr):
352+
logging.warning(f"before dyn_dim_cstr_feasibility")
353+
is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr)
354+
logging.warning(f"after dyn_dim_cstr_feasibility")
355+
if not is_dyn_dim_cstr_feasible:
326356
continue
327357
dyn_dim_cstr = cur_dyn_dim_cstr
328358
append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
359+
logging.warning(f"leave dyn_dim_cstr_feasibility_ctx_mgr")
329360
return dyn_dim_cstr, total_dim_gen_pass_names

graph_net/dynamic_dim_constraints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def symbolize(
4444
]
4545
Returns created symbol.
4646
"""
47+
import logging
48+
4749
InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
4850
input_dims = [
4951
InputDim(input_idx, axis, dim)

graph_net/model_path_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import json
1111
import base64
1212
from contextlib import contextmanager
13+
import logging
14+
15+
logging.basicConfig(
16+
level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"
17+
)
1318

1419

1520
def _convert_to_dict(config_str):

graph_net/test/batch_init_input_tensor_constraints_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1212
"handler_class_name": "UpdateInputTensorConstraints",
1313
"handler_config": {
14-
"resume": false,
14+
"resume": true,
1515
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1616
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1717
"data_input_predicator_class_name": "NaiveDataInputPredicator",
@@ -26,4 +26,4 @@ EOF
2626
)
2727
CONFIG=$(echo $config_json_str | base64 -w 0)
2828

29-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG
29+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/test/subgraph_decompose_and_evaluation_step_test.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ python3 -m graph_net.subgraph_decompose_and_evaluation_step \
7979
--tolerance="$TOLERANCE" \
8080
--max-subgraph-size="$INITIAL_MAX_SIZE"
8181

82-
echo ""
83-
echo ">>> Pass execution finished."
84-
echo ">>> Run this script again to execute the NEXT pass if needed."
82+
if [ $? -ne 0 ]; then
83+
echo ""
84+
echo "[ERROR] Task failed! Please check logs and fix bugs before proceeding."
85+
else
86+
echo ""
87+
echo ">>> Pass execution finished."
88+
echo ">>> Run this script again to execute the NEXT pass if needed."
89+
fi
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
# model_runnable_predicator=ShapePropagatablePredicator
8+
model_runnable_predicator=ModelRunnablePredicator
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
12+
"handler_class_name": "UpdateInputTensorConstraints",
13+
"handler_config": {
14+
"resume": true,
15+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
16+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
17+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
18+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
19+
"model_runnable_predicator_class_name": "$model_runnable_predicator",
20+
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
21+
"dimension_generalizer_class_name": "StaticToDynamic",
22+
"last_model_log_file": "/tmp/a.py"
23+
}
24+
}
25+
EOF
26+
)
27+
CONFIG=$(echo $config_json_str | base64 -w 0)
28+
29+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG

graph_net/torch/constraint_util.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import os
3+
import graph_net
34

45

56
class NaiveDataInputPredicator:
@@ -12,23 +13,27 @@ def __call__(self, model_path, input_var_name: str) -> bool:
1213

1314
class ModelRunnablePredicator:
1415
def __init__(self, config):
15-
self.config = config
16+
if config is None:
17+
config = {}
18+
19+
graph_net_root = os.path.dirname(graph_net.__file__)
20+
decorator_config = {"use_dummy_inputs": True}
21+
self.predicator = RunModelPredicator(decorator_config)
1622

1723
def __call__(self, model_path):
18-
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path}"
19-
return os.system(cmd) == 0
24+
return self.predicator(model_path)
2025

2126

2227
class ShapePropagatablePredicator:
2328
def __init__(self, config=None):
2429
if config is None:
2530
config = {}
26-
import graph_net
2731

2832
graph_net_root = os.path.dirname(graph_net.__file__)
2933
decorator_config = {
3034
"decorator_path": f"{graph_net_root}/torch/shape_prop.py",
3135
"decorator_class_name": "ShapePropagate",
36+
"use_dummy_inputs": True,
3237
}
3338
self.predicator = RunModelPredicator(decorator_config)
3439

graph_net/torch/fx_graph_cache_util.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import torch
23
import copy
34
import os
@@ -13,8 +14,12 @@ def parse_immutable_model_path_into_sole_graph_module(model_path):
1314
if model_path not in g_model_path2graph_module:
1415
module = _get_torch_module(model_path)
1516
tensor_metas = _get_tensor_metas(model_path)
17+
logging.warning("before _create_inputs_by_metas")
1618
inputs = _create_inputs_by_metas(module, tensor_metas)
19+
logging.warning("after _create_inputs_by_metas")
20+
logging.warning("before parse_sole_graph_module")
1721
g_model_path2graph_module[model_path] = parse_sole_graph_module(module, inputs)
22+
logging.warning("after parse_sole_graph_module")
1823
return copy.deepcopy(g_model_path2graph_module[model_path])
1924

2025

@@ -34,11 +39,9 @@ def _get_tensor_metas(model_path):
3439

3540
def _create_inputs_by_metas(module, tensor_metas):
3641
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
37-
from graph_net.torch.utils import convert_tensor_meta_attrs_list_to_named_tensors
42+
from graph_net.torch.utils import get_dummy_named_tensors
3843

39-
named_tensors = convert_tensor_meta_attrs_list_to_named_tensors(
40-
tensor_meta_attrs_list
41-
)
44+
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
4245
name2tensor = {k: v for k, v in named_tensors}
4346
return tuple(
4447
name2tensor[name] for name in inspect.signature(module.forward).parameters

0 commit comments

Comments
 (0)