Skip to content

Commit d5c9232

Browse files
committed
add torch/sample_passes/dimension_generalizer.py
1 parent 28c4074 commit d5c9232

File tree

6 files changed

+164
-8
lines changed

6 files changed

+164
-8
lines changed

graph_net/dimension_generalizer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ def __init__(self, config=None):
2626
def _make_config(
2727
self,
2828
output_dir: str,
29-
dimension_generalizer_filepath=None,
30-
dimension_generalizer_class_name="StaticToDynamic",
31-
dimension_generalizer_config=None,
32-
model_path_prefix="",
33-
resume=False,
34-
last_model_log_file=None,
35-
limits_handled_models=None,
29+
dimension_generalizer_filepath: str = None,
30+
dimension_generalizer_class_name: str = "StaticToDynamic",
31+
dimension_generalizer_config: dict = None,
32+
model_path_prefix: str = "",
33+
resume: bool = False,
34+
last_model_log_file: str = None,
35+
limits_handled_models: int = None,
3636
):
3737
if dimension_generalizer_config is None:
3838
dimension_generalizer_config = {}
@@ -118,7 +118,7 @@ def _get_symbols_and_reified_dims(self, from_model_path, dyn_dim_cstrs):
118118

119119
reifier_class = get_reifier(reifier_name)
120120
reifier_instance = reifier_class(str(from_model_path))
121-
assert reifier_instance.match
121+
assert reifier_instance.match()
122122
symbols2reified_dims = reifier_instance.reify()
123123
assert len(symbols2reified_dims) == 1
124124
symbols, reified_dims = next(iter(symbols2reified_dims.items()))

graph_net/sample_pass/resumable_sample_pass_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class ResumableSamplePassMixin(SamplePassMixin):
99
def __init__(self, *args, **kwargs):
1010
self.num_handled_models = 0
11+
super().__init__()
1112

1213
def declare_config(
1314
self,

graph_net/sample_pass/sample_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def __init__(self, config=None):
1010

1111
self._check_config_declaration_valid()
1212
self.config = self._make_config_by_config_declare(config)
13+
super().__init__()
1314

1415
@abc.abstractmethod
1516
def declare_config(self):

graph_net/sample_pass/sample_pass_mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33

44
class SamplePassMixin(abc.ABC):
5+
def __init__(self):
6+
super().__init__()
7+
58
@abc.abstractmethod
69
def declare_config(self):
710
raise NotImplementedError()
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(os.path.dirname(graph_net.__file__)))")
5+
6+
python3 -m graph_net.model_path_handler \
7+
--model-path-list $GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt \
8+
--handler-config=$(base64 -w 0 <<EOF
9+
{
10+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/dimension_generalizer.py",
11+
"handler_class_name": "DimensionGeneralizer",
12+
"handler_config": {
13+
"resume": false,
14+
"output_dir": "/tmp/workspace_dimension_generalizer",
15+
"model_path_prefix": "$GRAPH_NET_ROOT",
16+
"limits_handled_models": 10,
17+
"last_model_log_file": "/tmp/a.py"
18+
}
19+
}
20+
EOF
21+
)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import logging
2+
from graph_net.sample_pass.sample_pass import SamplePass
3+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
4+
from graph_net.sample_pass.only_model_file_rewrite_sample_pass_mixin import (
5+
OnlyModelFileRewriteSamplePassMixin,
6+
)
7+
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
8+
from graph_net.imp_util import load_module
9+
from graph_net.tensor_meta import TensorMeta
10+
from graph_net.torch.static_to_dynamic import StaticToDynamic
11+
import os
12+
from contextlib import contextmanager
13+
import tempfile
14+
import shutil
15+
from pathlib import Path
16+
from dataclasses import asdict
17+
import graph_net.graph_net_json_file_util as gn_json
18+
19+
20+
class DimensionGeneralizer(
21+
SamplePass, ResumableSamplePassMixin, OnlyModelFileRewriteSamplePassMixin
22+
):
23+
def __init__(self, config):
24+
super().__init__(config)
25+
26+
def declare_config(
27+
self,
28+
model_path_prefix: str,
29+
output_dir: str,
30+
resume: bool = False,
31+
limits_handled_models: int = None,
32+
last_model_log_file: str = None,
33+
):
34+
pass
35+
36+
def __call__(self, rel_model_path: str):
37+
self.resumable_handle_sample(rel_model_path)
38+
39+
def sample_handled(self, rel_model_path: str) -> bool:
40+
return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
41+
42+
def resume(self, rel_model_path: str):
43+
return self.copy_sample_and_handle_model_py_file(rel_model_path)
44+
45+
def handle_model_py_file(self, rel_model_path: str) -> str:
46+
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
47+
output_dir = Path(self.config["output_dir"])
48+
generalized_model_path = output_dir / rel_model_path
49+
generalized_model_path.mkdir(parents=True, exist_ok=True)
50+
tensor_metas = self._get_tensor_metas(model_path)
51+
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
52+
dim_gen_pass_names = self._get_dim_gen_pass_names(model_path)
53+
dim_generalizer = self._get_dimension_generalizer(dim_gen_pass_names)
54+
inputs = dim_generalizer.create_inputs_by_metas(
55+
module=self._get_model(model_path),
56+
tensor_meta_attrs_list=tensor_meta_attrs_list,
57+
)
58+
dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
59+
os.path.join(model_path, "input_tensor_constraints.py")
60+
)
61+
dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs)
62+
assert len(dim_axes_pairs) > 0, f"No symbolic dims found. {model_path=}"
63+
64+
def get_generalized():
65+
return self._get_generalized_model_py_file_path(
66+
dim_generalizer=dim_generalizer,
67+
dim_axes_pairs=dim_axes_pairs,
68+
model_path=model_path,
69+
inputs=inputs,
70+
)
71+
72+
with get_generalized() as tmp_model_py_path:
73+
return Path(tmp_model_py_path).read_text()
74+
75+
def _get_dim_axes_pairs(self, dyn_dim_cstrs):
76+
sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes()
77+
return [
78+
(dim, axes)
79+
for symbol in dyn_dim_cstrs.symbols
80+
for dim in [dyn_dim_cstrs.symbol2example_value[symbol]]
81+
for axes in [
82+
[
83+
axis
84+
for shape in sym_input_shapes
85+
for axis, sym_or_dim in enumerate(shape)
86+
if sym_or_dim == symbol
87+
]
88+
]
89+
]
90+
91+
def _get_dim_gen_pass_names(self, model_path):
92+
json_value = gn_json.read_json(model_path)
93+
return json_value.get(gn_json.kDimensionGeneralizationPasses, [])
94+
95+
def _get_dimension_generalizer(self, dim_gen_pass_names):
96+
dim_generalizer = StaticToDynamic({"pass_names": dim_gen_pass_names})
97+
return dim_generalizer
98+
99+
def _get_model(self, model_path):
100+
py_module = load_module(os.path.join(model_path, "model.py"))
101+
GraphModule = getattr(py_module, "GraphModule")
102+
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
103+
return GraphModule()
104+
105+
@contextmanager
106+
def _get_generalized_model_py_file_path(
107+
self, dim_generalizer, dim_axes_pairs, model_path, inputs
108+
):
109+
model = self._get_model(model_path)
110+
dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
111+
logging.warning("before need_rewrite")
112+
need_rewrite = dim_gen_pass.need_rewrite(inputs)
113+
logging.warning("after need_rewrite")
114+
if not need_rewrite:
115+
yield os.path.join(model_path, "model.py")
116+
return
117+
logging.warning("before rewrite")
118+
graph_module = dim_gen_pass.rewrite(inputs)
119+
logging.warning("after rewrite")
120+
with tempfile.TemporaryDirectory() as tmp_dir:
121+
shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
122+
dim_gen_pass.save_graph_module(graph_module, tmp_dir)
123+
yield os.path.join(tmp_dir, "model.py")
124+
125+
def _get_tensor_metas(self, model_path):
126+
make = TensorMeta.unserialize_from_py_file
127+
return [
128+
*make(os.path.join(model_path, "input_meta.py")),
129+
*make(os.path.join(model_path, "weight_meta.py")),
130+
]

0 commit comments

Comments
 (0)