
Commit 32f9105

Enhance the performance of the input_tensor_constraints.py file generation process.

1 parent 29dbd62 · commit 32f9105
11 files changed: +170 −53 lines
Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
+samples/timm/davit_giant
 # samples/timm/resnetaa50d.d_in12k
 # samples/transformers-auto-model/opus-mt-en-gmw
-samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
+# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py
Lines changed: 46 additions & 15 deletions

@@ -1,3 +1,4 @@
+import logging
 from graph_net.dynamic_dim_constraints import DynamicDimConstraints
 from contextlib import AbstractContextManager
 from graph_net.imp_util import load_module
@@ -88,14 +89,21 @@ def __call__(self, model_path):
             return
 
         tensor_metas = self._get_tensor_metas(model_path)
+        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
+        logging.warning(f"before create_inputs_by_metas")
+        inputs = self.get_dimension_generalizer().create_inputs_by_metas(
+            module=self.get_model(model_path),
+            tensor_meta_attrs_list=tensor_meta_attrs_list,
+        )
+        logging.warning(f"after create_inputs_by_metas")
         dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
 
         def data_input_predicator(input_var_name):
             return self.data_input_predicator(model_path, input_var_name)
 
         def get_tmp_model_path_ctx_mgr(dim_axes_pairs):
             return self._try_dimension_generalization(
-                dim_axes_pairs, model_path, tensor_metas
+                dim_axes_pairs, model_path, inputs
             )
 
         def get_predicator_is_dyn_dim_cstr_feasible(tmp_model_path):
@@ -128,28 +136,44 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
             )
             sys.exit(0)
 
-    @contextmanager
-    def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas):
-        if self.config["dimension_generalizer_filepath"] is None:
-            yield model_path, ()
-            return
-        py_module = load_module(os.path.join(model_path, "model.py"))
-        GraphModule = getattr(py_module, "GraphModule")
-        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
-        model = GraphModule()
+    def get_dimension_generalizer(self):
+        if hasattr(self, "_dim_generalizer"):
+            return self._dim_generalizer
+        assert self.config["dimension_generalizer_filepath"] is not None
         decorator_cls = getattr(
             load_module(self.config["dimension_generalizer_filepath"]),
             self.config["dimension_generalizer_class_name"],
         )
-        dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
+        self._dim_generalizer = decorator_cls(
+            self.config["dimension_generalizer_config"]
+        )
+        return self._dim_generalizer
+
+    def get_model(self, model_path):
+        py_module = load_module(os.path.join(model_path, "model.py"))
+        GraphModule = getattr(py_module, "GraphModule")
+        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
+        return GraphModule()
+
+    @contextmanager
+    def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
+        logging.warning(f"enter _try_dimension_generalization")
+        if self.config["dimension_generalizer_filepath"] is None:
+            yield model_path, ()
+            return
+        model = self.get_model(model_path)
+        dim_generalizer = self.get_dimension_generalizer()
         dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
-        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
-        inputs = dim_gen_pass.create_inputs_by_metas(tensor_meta_attrs_list)
-        if not dim_gen_pass.need_rewrite(inputs):
+        logging.warning(f"before need_rewrite")
+        need_rewrite = dim_gen_pass.need_rewrite(inputs)
+        logging.warning(f"after need_rewrite")
+        if not need_rewrite:
             yield model_path, ()
             return
 
+        logging.warning(f"before rewrite")
         graph_module = dim_gen_pass.rewrite(inputs)
+        logging.warning(f"after rewrite")
         with tempfile.TemporaryDirectory() as tmp_dir:
             shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
             dim_gen_pass.save_graph_module(graph_module, tmp_dir)
@@ -300,6 +324,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
         )
 
     for i, picked_dim in enumerate(unqiue_dims):
+        logging.warning(f"{i=} {picked_dim=}")
         cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
 
         def filter_fn(input_name, input_idx, axis, dim):
@@ -319,11 +344,17 @@ def filter_fn(input_name, input_idx, axis, dim):
             (dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
         )
         ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
+        logging.warning(f"before dyn_dim_cstr_feasibility_ctx_mgr")
         with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
+            logging.warning(f"enter dyn_dim_cstr_feasibility_ctx_mgr")
             tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
             tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
-            if not dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr):
+            logging.warning(f"before dyn_dim_cstr_feasibility")
+            is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr)
+            logging.warning(f"after dyn_dim_cstr_feasibility")
+            if not is_dyn_dim_cstr_feasible:
                 continue
             dyn_dim_cstr = cur_dyn_dim_cstr
             append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
+        logging.warning(f"leave dyn_dim_cstr_feasibility_ctx_mgr")
     return dyn_dim_cstr, total_dim_gen_pass_names
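
The get_dimension_generalizer/get_model split above replaces per-call construction with instance-level caching: the generalizer is built once and reused, and inputs are created once in __call__ and threaded through _try_dimension_generalization instead of being rebuilt from tensor_metas on every attempt. A minimal, self-contained sketch of the memoization pattern (all names here are hypothetical):

    # Build the expensive object on first access, cache it as an attribute,
    # and return the cached instance on every later call.
    class Handler:
        def _build_generalizer(self):
            print("built once")
            return object()

        def get_generalizer(self):
            if hasattr(self, "_generalizer"):
                return self._generalizer
            self._generalizer = self._build_generalizer()
            return self._generalizer

    handler = Handler()
    assert handler.get_generalizer() is handler.get_generalizer()  # single build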

graph_net/dynamic_dim_constraints.py
Lines changed: 2 additions & 0 deletions

@@ -44,6 +44,8 @@ def symbolize(
         ]
         Returns created symbol.
         """
+        import logging
+
         InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
         input_dims = [
             InputDim(input_idx, axis, dim)
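
The context lines above show the InputDim bookkeeping that symbolize operates on: one record per (input, axis) with its concrete dim value. A small, self-contained illustration (the example shapes are made up):

    from collections import namedtuple

    InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])

    shapes = [(2, 3, 224, 224), (2, 77)]  # hypothetical input tensor shapes
    input_dims = [
        InputDim(input_idx, axis, dim)
        for input_idx, shape in enumerate(shapes)
        for axis, dim in enumerate(shape)
    ]
    # e.g. collect every (input, axis) whose dim equals the batch size 2
    batch_axes = [d for d in input_dims if d.dim == 2]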

graph_net/model_path_handler.py
Lines changed: 5 additions & 0 deletions

@@ -10,6 +10,11 @@
 import json
 import base64
 from contextlib import contextmanager
+import logging
+
+logging.basicConfig(
+    level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"
+)
 
 
 def _convert_to_dict(config_str):
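
The added basicConfig call configures the root logger once at import time, so the logging.warning probes added across the other files share one format. A standalone illustration (the message is arbitrary; note that basicConfig is a no-op when a root handler is already installed):

    import logging

    logging.basicConfig(
        level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"
    )
    logging.warning("before create_inputs_by_metas")
    # prints e.g.: 2024-01-01 12:00:00,000 [WARNING] before create_inputs_by_metas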

graph_net/test/shape_prop_batch_init_input_tensor_constraints_test.sh
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ config_json_str=$(cat <<EOF
   "handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
   "handler_class_name": "UpdateInputTensorConstraints",
   "handler_config": {
-    "resume": true,
+    "resume": false,
     "model_path_prefix": "$GRAPH_NET_ROOT/../",
     "data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
     "data_input_predicator_class_name": "NaiveDataInputPredicator",
@@ -26,4 +26,4 @@ EOF
 )
 CONFIG=$(echo $config_json_str | base64 -w 0)
 
-python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
+python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG
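
The script passes its JSON config base64-encoded; base64 -w 0 disables line wrapping so the value survives as a single CLI argument. A sketch of the round-trip, assuming the handler side decodes roughly like this (the real decoding lives in _convert_to_dict, whose exact behavior is not shown here):

    import base64
    import json

    config = {"handler_config": {"resume": False}}
    encoded = base64.b64encode(json.dumps(config).encode()).decode()  # shell's $CONFIG
    decoded = json.loads(base64.b64decode(encoded))
    assert decoded == config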

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#!/bin/bash
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
+os.path.dirname(graph_net.__file__))")
+
+# input model path
+# model_runnable_predicator=ShapePropagatablePredicator
+model_runnable_predicator=ModelRunnablePredicator
+config_json_str=$(cat <<EOF
+{
+  "handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
+  "handler_class_name": "UpdateInputTensorConstraints",
+  "handler_config": {
+    "resume": true,
+    "model_path_prefix": "$GRAPH_NET_ROOT/../",
+    "data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
+    "data_input_predicator_class_name": "NaiveDataInputPredicator",
+    "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
+    "model_runnable_predicator_class_name": "$model_runnable_predicator",
+    "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
+    "dimension_generalizer_class_name": "StaticToDynamic",
+    "last_model_log_file": "/tmp/a.py"
+  }
+}
+EOF
+)
+CONFIG=$(echo $config_json_str | base64 -w 0)
+
+python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
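
Each *_filepath / *_class_name pair in the config names a module file and a class to load from it, mirroring the decorator_cls = getattr(load_module(...), class_name) pattern visible in the constraint_util.py diff. A sketch using only the standard library (graph_net's own load_module may differ):

    import importlib.util

    def load_class(filepath, class_name):
        spec = importlib.util.spec_from_file_location("loaded_module", filepath)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, class_name)

    # e.g. Predicator = load_class(f"{root}/torch/constraint_util.py",
    #                              "ModelRunnablePredicator")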

graph_net/torch/constraint_util.py
Lines changed: 9 additions & 4 deletions

@@ -1,5 +1,6 @@
 import sys
 import os
+import graph_net
 
 
 class NaiveDataInputPredicator:
@@ -12,23 +13,27 @@ def __call__(self, model_path, input_var_name: str) -> bool:
 
 class ModelRunnablePredicator:
     def __init__(self, config):
-        self.config = config
+        if config is None:
+            config = {}
+
+        graph_net_root = os.path.dirname(graph_net.__file__)
+        decorator_config = {"use_dummy_inputs": True}
+        self.predicator = RunModelPredicator(decorator_config)
 
     def __call__(self, model_path):
-        cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path}"
-        return os.system(cmd) == 0
+        return self.predicator(model_path)
 
 
 class ShapePropagatablePredicator:
     def __init__(self, config=None):
         if config is None:
             config = {}
-        import graph_net
 
         graph_net_root = os.path.dirname(graph_net.__file__)
         decorator_config = {
             "decorator_path": f"{graph_net_root}/torch/shape_prop.py",
             "decorator_class_name": "ShapePropagate",
+            "use_dummy_inputs": True,
         }
         self.predicator = RunModelPredicator(decorator_config)
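
ModelRunnablePredicator no longer builds a shell command and calls os.system per model; like ShapePropagatablePredicator, it now delegates to a pre-configured RunModelPredicator, which is presumably a large part of the speedup the commit title claims. A stubbed sketch of the shared delegation shape (RunModelPredicator is faked here):

    class RunModelPredicator:  # stand-in for the real class
        def __init__(self, decorator_config):
            self.decorator_config = decorator_config

        def __call__(self, model_path):
            return True  # the real one runs the model and reports success

    class ModelRunnablePredicator:
        def __init__(self, config=None):
            self.predicator = RunModelPredicator({"use_dummy_inputs": True})

        def __call__(self, model_path):
            return self.predicator(model_path)

    assert ModelRunnablePredicator()("samples/any/model")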

graph_net/torch/fx_graph_cache_util.py
Lines changed: 7 additions & 4 deletions

@@ -1,3 +1,4 @@
+import logging
 import torch
 import copy
 import os
@@ -13,8 +14,12 @@ def parse_immutable_model_path_into_sole_graph_module(model_path):
     if model_path not in g_model_path2graph_module:
         module = _get_torch_module(model_path)
         tensor_metas = _get_tensor_metas(model_path)
+        logging.warning("before _create_inputs_by_metas")
        inputs = _create_inputs_by_metas(module, tensor_metas)
+        logging.warning("after _create_inputs_by_metas")
+        logging.warning("before parse_sole_graph_module")
        g_model_path2graph_module[model_path] = parse_sole_graph_module(module, inputs)
+        logging.warning("after parse_sole_graph_module")
     return copy.deepcopy(g_model_path2graph_module[model_path])
@@ -34,11 +39,9 @@ def _get_tensor_metas(model_path):
 
 def _create_inputs_by_metas(module, tensor_metas):
     tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
-    from graph_net.torch.utils import convert_tensor_meta_attrs_list_to_named_tensors
+    from graph_net.torch.utils import get_dummy_named_tensors
 
-    named_tensors = convert_tensor_meta_attrs_list_to_named_tensors(
-        tensor_meta_attrs_list
-    )
+    named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
     name2tensor = {k: v for k, v in named_tensors}
     return tuple(
         name2tensor[name] for name in inspect.signature(module.forward).parameters
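
parse_immutable_model_path_into_sole_graph_module keeps a module-level cache keyed by model path and returns deep copies, so each caller can mutate its graph without poisoning the cache. A minimal sketch of that scheme (expensive_parse is hypothetical):

    import copy

    _cache = {}

    def expensive_parse(model_path):
        return {"path": model_path, "nodes": []}  # stand-in for real parsing

    def get_parsed(model_path):
        if model_path not in _cache:
            _cache[model_path] = expensive_parse(model_path)  # parse once per path
        return copy.deepcopy(_cache[model_path])  # callers get private copies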

graph_net/torch/run_model.py
Lines changed: 18 additions & 6 deletions

@@ -30,10 +30,7 @@ def _convert_to_dict(config_str):
     return config
 
 
-def _get_decorator(args):
-    if args.decorator_config is None:
-        return lambda model: model
-    decorator_config = _convert_to_dict(args.decorator_config)
+def _get_decorator(decorator_config):
     if "decorator_path" not in decorator_config:
         return lambda model: model
     class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
@@ -44,6 +41,17 @@ def _get_decorator(args):
     return decorator_class(decorator_config.get("decorator_config", {}))
 
 
+def get_flag_use_dummy_inputs(decorator_config):
+    return "use_dummy_inputs" in decorator_config
+
+
+def replay_tensor(info, use_dummy_inputs):
+    if use_dummy_inputs:
+        return utils.get_dummy_tensor(info)
+    else:
+        return utils.replay_tensor(info)
+
+
 def main(args):
     model_path = args.model_path
     model_class = load_class_from_file(
@@ -53,11 +61,15 @@ def main(args):
     model = model_class()
     print(f"{model_path=}")
 
-    model = _get_decorator(args)(model)
+    decorator_config = _convert_to_dict(args.decorator_config)
+    if "decorator_path" in args.decorator_config:
+        model = _get_decorator(decorator_config)(model)
 
     inputs_params = utils.load_converted_from_text(f"{model_path}")
     params = inputs_params["weight_info"]
-    state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
+    use_dummy_inputs = get_flag_use_dummy_inputs(decorator_config)
+    print(f"{use_dummy_inputs=}")
+    state_dict = {k: replay_tensor(v, use_dummy_inputs) for k, v in params.items()}
 
     model(**state_dict)
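
The use_dummy_inputs flag lets run_model synthesize tensors from recorded metadata instead of replaying saved values, avoiding the cost of reconstructing real tensor data. An illustrative sketch of the dummy-tensor idea; the meta format and utils.get_dummy_tensor's actual behavior are assumptions:

    import torch

    def get_dummy_tensor(info):
        # build a placeholder tensor from shape/dtype metadata only
        dtype = getattr(torch, info["dtype"])
        if dtype.is_floating_point:
            return torch.zeros(info["shape"], dtype=dtype)
        return torch.ones(info["shape"], dtype=dtype)

    t = get_dummy_tensor({"shape": [2, 3], "dtype": "float32"})
    assert t.shape == (2, 3)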

graph_net/torch/static_to_dynamic.py
Lines changed: 11 additions & 11 deletions

@@ -1,6 +1,7 @@
+import logging
 import torch
 import torch.fx as fx
-from graph_net.torch.utils import convert_tensor_meta_attrs_list_to_named_tensors
+from graph_net.torch.utils import get_dummy_named_tensors
 from torch.fx.passes.shape_prop import ShapeProp
 from graph_net.torch.utils import apply_templates
 from pathlib import Path
@@ -26,6 +27,13 @@ def __init__(self, config=None):
     def __call__(self, module, dim_axes_pairs):
         return StaticToDynamicModulePass(self.config, module, dim_axes_pairs)
 
+    def create_inputs_by_metas(self, module, tensor_meta_attrs_list):
+        named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
+        name2tensor = {k: v for k, v in named_tensors}
+        return tuple(
+            name2tensor[name] for name in inspect.signature(module.forward).parameters
+        )
+
 
 class StaticToDynamicModulePass(torch.nn.Module):
     def __init__(self, config, module, dim_axes_pairs):
@@ -56,19 +64,11 @@ def make_config(self, pass_names=()):
     def get_pass_names(self):
         return self.config["pass_names"]
 
-    def create_inputs_by_metas(self, tensor_meta_attrs_list):
-        named_tensors = convert_tensor_meta_attrs_list_to_named_tensors(
-            tensor_meta_attrs_list
-        )
-        name2tensor = {k: v for k, v in named_tensors}
-        return tuple(
-            name2tensor[name]
-            for name in inspect.signature(self.module.forward).parameters
-        )
-
     def need_rewrite(self, inputs):
         try:
+            logging.warning("before _create_fx_graph_module")
             traced_module = self._create_fx_graph_module(inputs)
+            logging.warning("after _create_fx_graph_module")
             ShapeProp(traced_module).propagate(*inputs)
         except:
             return False
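
create_inputs_by_metas moved from the per-pass class to StaticToDynamic itself and now takes the module explicitly; its job is to order the generated tensors to match the positional parameters of module.forward. A self-contained sketch of that signature-ordering trick (tensors faked as strings):

    import inspect

    class Toy:
        def forward(self, pixel_values, attention_mask):
            return pixel_values, attention_mask

    name2tensor = {"attention_mask": "mask-tensor", "pixel_values": "pixel-tensor"}
    inputs = tuple(
        name2tensor[name]
        for name in inspect.signature(Toy().forward).parameters
    )
    assert inputs == ("pixel-tensor", "mask-tensor")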
