
Commit bf54b98

Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into subgraph_dataset
2 parents 7ae9b9b + 0b25342 commit bf54b98

16 files changed: +548 −18 lines

graph_net/constraint_util.py

Lines changed: 7 additions & 6 deletions
@@ -12,7 +12,6 @@
 import tempfile
 import shutil
 from pathlib import Path
-import json
 from dataclasses import asdict


@@ -187,12 +186,14 @@ def _save_model_to_log_file(self, model_path):
         shutil.copy(Path(model_path) / "model.py", log_file)

     def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
-        from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
+        from graph_net.graph_net_json_file_util import (
+            kDimensionGeneralizationPasses,
+            update_json,
+        )

-        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
-        graph_net_json = json.loads(graph_net_json_file_path.read_text())
-        graph_net_json[kDimensionGeneralizationPasses] = list(dim_gen_pass_names)
-        graph_net_json_file_path.write_text(json.dumps(graph_net_json))
+        update_json(
+            model_path, kDimensionGeneralizationPasses, list(dim_gen_pass_names)
+        )

     def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
         cstr_code = dyn_dim_cstr.serialize_to_py_str()

graph_net/dimension_generalizer.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
import logging
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
from graph_net.imp_util import load_module
from graph_net.tensor_meta import TensorMeta
import functools
import sys
import os
from contextlib import contextmanager
import tempfile
import shutil
from pathlib import Path
from dataclasses import asdict
import graph_net.graph_net_json_file_util as gn_json


class ApplyDimGenPasses:
    def __init__(self, config=None):
        if config is None:
            config = {}
        self.config = self._make_config(**config)
        self.num_handled_models = 0

    def _make_config(
        self,
        output_dir: str,
        dimension_generalizer_filepath=None,
        dimension_generalizer_class_name="StaticToDynamic",
        dimension_generalizer_config=None,
        model_path_prefix="",
        resume=False,
        last_model_log_file=None,
        limits_handled_models=None,
    ):
        if dimension_generalizer_config is None:
            dimension_generalizer_config = {}
        return {
            "resume": resume,
            "output_dir": output_dir,
            "model_path_prefix": model_path_prefix,
            "dimension_generalizer_filepath": dimension_generalizer_filepath,
            "dimension_generalizer_class_name": dimension_generalizer_class_name,
            "dimension_generalizer_config": dimension_generalizer_config,
            "last_model_log_file": last_model_log_file,
            "limits_handled_models": limits_handled_models,
        }

    def __call__(self, rel_model_path):
        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        output_dir = Path(self.config["output_dir"])
        output_dir.mkdir(parents=True, exist_ok=True)
        generalized_model_path = output_dir / rel_model_path
        if self.config["resume"] and (generalized_model_path / "model.py").exists():
            return
        tensor_metas = self._get_tensor_metas(model_path)
        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
        dim_gen_pass_names = self._get_dim_gen_pass_names(model_path)
        dim_generalizer = self._get_dimension_generalizer(dim_gen_pass_names)
        inputs = dim_generalizer.create_inputs_by_metas(
            module=self._get_model(model_path),
            tensor_meta_attrs_list=tensor_meta_attrs_list,
        )
        dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
            os.path.join(model_path, "input_tensor_constraints.py")
        )
        dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs)
        if len(dim_axes_pairs) == 0:
            return

        def get_generalized():
            return self._get_generalized_model_py_file_path(
                dim_generalizer=dim_generalizer,
                dim_axes_pairs=dim_axes_pairs,
                model_path=model_path,
                inputs=inputs,
            )

        with get_generalized() as generalized_model_py_path:
            self._save_generalized_model_path(rel_model_path, generalized_model_py_path)

        self._check_num_handled_models()

    def _save_generalized_model_path(self, rel_model_path, generalized_model_py_path):
        from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
        to_model_path = Path(self.config["output_dir"]) / rel_model_path
        print(f"{str(to_model_path)=}")
        to_model_path.mkdir(parents=True, exist_ok=True)
        shutil.copytree(Path(from_model_path), Path(to_model_path), dirs_exist_ok=True)
        generalized_model_py_code = Path(generalized_model_py_path).read_text()
        (to_model_path / "model.py").write_text(generalized_model_py_code)

    def _get_dim_axes_pairs(self, dyn_dim_cstrs):
        sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes()
        return [
            (dim, axes)
            for symbol in dyn_dim_cstrs.symbols
            for dim in [dyn_dim_cstrs.symbol2example_value[symbol]]
            for axes in [
                [
                    axis
                    for shape in sym_input_shapes
                    for axis, sym_or_dim in enumerate(shape)
                    if sym_or_dim == symbol
                ]
            ]
        ]

    def _get_dim_gen_pass_names(self, model_path):
        json_value = gn_json.read_json(model_path)
        return json_value.get(gn_json.kDimensionGeneralizationPasses, [])

    def _check_num_handled_models(self):
        self.num_handled_models += 1
        limits = self.config["limits_handled_models"]
        if limits is None:
            return
        if self.num_handled_models < limits:
            return
        print("`num_handled_models` exceeds config `limits_handled_models`")
        sys.exit(0)

    def _get_dimension_generalizer(self, dim_gen_pass_names):
        assert self.config["dimension_generalizer_filepath"] is not None
        decorator_cls = getattr(
            load_module(self.config["dimension_generalizer_filepath"]),
            self.config["dimension_generalizer_class_name"],
        )
        config = {"pass_names": dim_gen_pass_names}
        dim_generalizer = decorator_cls(config)
        return dim_generalizer

    def _get_model(self, model_path):
        py_module = load_module(os.path.join(model_path, "model.py"))
        GraphModule = getattr(py_module, "GraphModule")
        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
        return GraphModule()

    @contextmanager
    def _get_generalized_model_py_file_path(
        self, dim_generalizer, dim_axes_pairs, model_path, inputs
    ):
        model = self._get_model(model_path)
        dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
        logging.warning("before need_rewrite")
        need_rewrite = dim_gen_pass.need_rewrite(inputs)
        logging.warning("after need_rewrite")
        if not need_rewrite:
            yield os.path.join(model_path, "model.py")
            return
        logging.warning("before rewrite")
        graph_module = dim_gen_pass.rewrite(inputs)
        logging.warning("after rewrite")
        with tempfile.TemporaryDirectory() as tmp_dir:
            shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
            dim_gen_pass.save_graph_module(graph_module, tmp_dir)
            yield os.path.join(tmp_dir, "model.py")

    def _get_tensor_metas(self, model_path):
        make = TensorMeta.unserialize_from_py_file
        return [
            *make(os.path.join(model_path, "input_meta.py")),
            *make(os.path.join(model_path, "weight_meta.py")),
        ]


def update_tensor_metas_by_dyn_dim_cstr(
    tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
):
    input_shapes = dyn_dim_cstr.get_reified_input_shapes()
    assert len(tensor_metas) == len(input_shapes)
    for i, tensor_meta in enumerate(tensor_metas):
        tensor_meta.shape = input_shapes[i]
        if tensor_meta.data is not None:
            assert isinstance(tensor_meta.data, (list, tuple))
            size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
            doubled_data = [*tensor_meta.data, *tensor_meta.data]
            tensor_meta.data = doubled_data[:size]
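
Note: a minimal sketch of driving this handler directly from Python, based only on the config keys and file layout read by the class above; all paths and the sample model directory are assumptions, and the shell script later in this commit drives it through graph_net.model_path_handler instead.

# Hypothetical direct invocation of ApplyDimGenPasses; paths are placeholders.
from graph_net.dimension_generalizer import ApplyDimGenPasses

handler = ApplyDimGenPasses(
    config=dict(
        output_dir="/tmp/dimension_generalized_samples",
        model_path_prefix="/path/to/GraphNet/",  # assumed repo checkout
        dimension_generalizer_filepath="/path/to/GraphNet/graph_net/torch/static_to_dynamic.py",
        dimension_generalizer_class_name="StaticToDynamic",
        resume=True,
    )
)
# The relative model path is resolved under model_path_prefix; the directory is
# expected to contain model.py, input_meta.py, weight_meta.py,
# input_tensor_constraints.py and graph_net.json, as read by the handler.
handler("samples/torch/some_model")  # hypothetical sample path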

graph_net/dynamic_dim_constraints.py

Lines changed: 15 additions & 0 deletions
@@ -23,6 +23,21 @@ class DynamicDimConstraints:
     input_shapes: list[(tuple[sympy.Expr | int], str)]
     kInputShapes = "dynamic_dim_constraint_input_shapes"

+    def serialize_symbolic_input_shapes_to_str(self):
+        input_shapes = self.get_sorted_symbolic_input_shapes()
+        input_shapes_str = str(input_shapes).replace(" ", "")
+        return input_shapes_str
+
+    def get_sorted_symbolic_input_shapes(self):
+        return sorted(
+            [
+                tuple(shape)
+                for shape, name in self.input_shapes
+                if any(isinstance(dim, sympy.Expr) for dim in shape)
+            ],
+            key=str,
+        )
+
     @classmethod
     def make_by_named_inputs(cls, named_shapes):
         return cls(
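
The new accessor keeps only shapes containing at least one sympy expression and sorts them by their string form. A standalone sketch of that behaviour on plain data (the shapes and names here are made up for illustration):

import sympy

S0 = sympy.Symbol("S0")
# (shape, name) pairs in the same layout as DynamicDimConstraints.input_shapes
input_shapes = [
    ((1, 3, 224, 224), "static_input"),  # no symbolic dim -> dropped
    ((S0, 128), "tokens"),               # symbolic dim -> kept
    ((S0, 1), "mask"),
]
symbolic = sorted(
    [tuple(s) for s, name in input_shapes if any(isinstance(d, sympy.Expr) for d in s)],
    key=str,
)
print(symbolic)                        # [(S0, 1), (S0, 128)]
print(str(symbolic).replace(" ", ""))  # the compact string form used for reporting
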
graph_net/graph_net_json_file_util.py

Lines changed: 16 additions & 0 deletions

@@ -1 +1,17 @@
+from pathlib import Path
+import json
+
 kDimensionGeneralizationPasses = "dimension_generalization_passes"
+kSymbolicDimensionReifier = "symbolic_dimension_reifier"
+
+
+def read_json(model_path):
+    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+    return json.loads(graph_net_json_file_path.read_text())
+
+
+def update_json(model_path, field, value):
+    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+    graph_net_json = json.loads(graph_net_json_file_path.read_text())
+    graph_net_json[field] = value
+    graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
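
A quick usage sketch of the two helpers against a scratch model directory; the seeded JSON content is made up for illustration, only the helper calls come from this commit:

import json
import tempfile
from pathlib import Path
import graph_net.graph_net_json_file_util as gn_json

with tempfile.TemporaryDirectory() as model_path:
    # Seed a minimal graph_net.json; the helpers expect the file to already exist.
    (Path(model_path) / "graph_net.json").write_text(json.dumps({"framework": "torch"}))
    gn_json.update_json(model_path, gn_json.kDimensionGeneralizationPasses, ["pass_a"])
    print(gn_json.read_json(model_path))
    # {'framework': 'torch', 'dimension_generalization_passes': ['pass_a']}
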
graph_net/tools/_get_in_tensor_symbolic_shapes.py

Lines changed: 17 additions & 12 deletions

@@ -1,15 +1,16 @@
 from pathlib import Path
 from graph_net.dynamic_dim_constraints import DynamicDimConstraints
-import sympy
+import graph_net.graph_net_json_file_util as gn_json


 class GetInTensorSymbolicShapes:
     def __init__(self, config):
         self.config = self.make_config(**config)

-    def make_config(self, model_path_prefix):
+    def make_config(self, model_path_prefix, ignore_reified=True):
         return {
             "model_path_prefix": model_path_prefix,
+            "ignore_reified": ignore_reified,
         }

     def __call__(self, model_path):
@@ -18,17 +19,21 @@ def __call__(self, model_path):
         if not input_tensor_cstr_filepath.exists():
             print(f"get-in-tensor-symbolic-shapes None {model_path}")
             return
+        if self.config["ignore_reified"] and self._found_reified_dims(
+            str(original_model_path)
+        ):
+            print(f"get-in-tensor-symbolic-shapes <reified> {model_path}")
+            return
         dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
             str(input_tensor_cstr_filepath)
         )
-        dyn_dim_cstrs.symbol2example_value = {}
-        dyn_dim_cstrs.input_shapes = sorted(
-            [
-                tuple(shape)
-                for shape, name in dyn_dim_cstrs.input_shapes
-                if any(isinstance(dim, sympy.Expr) for dim in shape)
-            ],
-            key=str,
-        )
-        input_shapes_str = str(dyn_dim_cstrs.input_shapes).replace(" ", "")
+        input_shapes_str = str(dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str())
         print(f"get-in-tensor-symbolic-shapes {input_shapes_str} {model_path}")
+
+    def _found_reified_dims(self, model_path):
+        json = gn_json.read_json(model_path)
+
+        if gn_json.kSymbolicDimensionReifier not in json:
+            return False
+
+        return json[gn_json.kSymbolicDimensionReifier] is not None
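
With the new ignore_reified flag, a model is skipped as soon as its graph_net.json records a non-null symbolic_dimension_reifier. A hedged sketch of how that marker ends up in the file (the directory and reifier name below are invented; UpdateSymDimReifier in this commit writes the real value):

import graph_net.graph_net_json_file_util as gn_json

# Mark a model directory as already reified (hypothetical path and reifier name);
# with "ignore_reified": true the handler above then prints "<reified>" and returns.
gn_json.update_json("/path/to/model_dir", gn_json.kSymbolicDimensionReifier, "SomeReifier")
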
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
from pathlib import Path
from graph_net.imp_util import load_module
import graph_net.graph_net_json_file_util as gn_json


class UpdateSymDimReifier:
    def __init__(self, config):
        self.config = self.make_config(**config)

    def make_config(
        self,
        model_path_prefix,
        reifier_factory_path,
        reifier_factory_class_name,
        reifier_factory_config=None,
        resume=True,
    ):
        if reifier_factory_config is None:
            reifier_factory_config = {}
        return {
            "reifier_factory_path": reifier_factory_path,
            "reifier_factory_class_name": reifier_factory_class_name,
            "reifier_factory_config": reifier_factory_config,
            "model_path_prefix": model_path_prefix,
            "resume": resume,
        }

    def __call__(self, model_path):
        model_path_obj = Path(self.config["model_path_prefix"]) / model_path
        model_path = str(model_path_obj)
        input_tensor_cstr_filepath = model_path_obj / "input_tensor_constraints.py"
        if not input_tensor_cstr_filepath.exists():
            return
        if self.config["resume"] and self._found_reified_dims(model_path):
            return
        reifier_factory_class = self._get_reifier_factory_class()
        reifier_factory_instance = reifier_factory_class(
            config=self.config["reifier_factory_config"], model_path=model_path
        )
        matched_reifier_name = reifier_factory_instance.get_matched_reifier_name()
        if matched_reifier_name is None:
            return
        assert isinstance(matched_reifier_name, str), f"{type(matched_reifier_name)=}"
        gn_json.update_json(
            model_path, gn_json.kSymbolicDimensionReifier, matched_reifier_name
        )

    def _get_reifier_factory_class(self):
        py_module = load_module(self.config["reifier_factory_path"])
        return getattr(py_module, self.config["reifier_factory_class_name"])

    def _found_reified_dims(self, model_path):
        json = gn_json.read_json(model_path)

        if gn_json.kSymbolicDimensionReifier not in json:
            return False

        return json[gn_json.kSymbolicDimensionReifier] is not None
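
The handler only assumes that the class loaded from reifier_factory_path takes (config, model_path) and exposes get_matched_reifier_name(). A minimal factory sketch satisfying that contract; the class name and returned reifier name are assumptions, not part of this commit:

# Hypothetical reifier factory; only the constructor signature and
# get_matched_reifier_name() are required by UpdateSymDimReifier above.
class AlwaysBatchDimReifierFactory:
    def __init__(self, config, model_path):
        self.config = config
        self.model_path = model_path

    def get_matched_reifier_name(self):
        # Return a reifier name to record in graph_net.json, or None to skip the model.
        return "BatchDimReifier"
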
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
# model_runnable_predicator=ShapePropagatablePredicator
model_runnable_predicator=ModelRunnablePredicator
config_json_str=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/dimension_generalizer.py",
    "handler_class_name": "ApplyDimGenPasses",
    "handler_config": {
        "resume": true,
        "output_dir": "/tmp/dimension_generalized_samples",
        "model_path_prefix": "$GRAPH_NET_ROOT/../",
        "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
        "dimension_generalizer_class_name": "StaticToDynamic",
        "limits_handled_models": 9999999,
        "last_model_log_file": "/tmp/a.py"
    }
}
EOF
)
CONFIG=$(echo $config_json_str | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
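The handler config travels to graph_net.model_path_handler as base64-encoded JSON. A Python equivalent of the encoding step in the script above; the decoding inside model_path_handler is assumed to mirror this round trip, and the config values are abbreviated:

import base64
import json

config = {
    "handler_path": "graph_net/dimension_generalizer.py",
    "handler_class_name": "ApplyDimGenPasses",
    "handler_config": {"resume": True, "output_dir": "/tmp/dimension_generalized_samples"},
}
encoded = base64.b64encode(json.dumps(config).encode()).decode()  # value of $CONFIG
print(json.loads(base64.b64decode(encoded)))                      # round-trips back to the dict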

graph_net/tools/get_in_tensor_symbolic_shapes.sh

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ config_json_str=$(cat <<EOF
     "handler_path": "$GRAPH_NET_ROOT/tools/_get_in_tensor_symbolic_shapes.py",
     "handler_class_name": "GetInTensorSymbolicShapes",
     "handler_config": {
+        "ignore_reified": true,
         "model_path_prefix": "$GRAPH_NET_ROOT/../"
     }
 }
