
Commit adff744

Merge branch 'PaddlePaddle:develop' into develop
2 parents 9363023 + 7ee8b99 commit adff744

30 files changed: +2320, -137 lines changed

docs/torch_to_paddle_conversion_design.md

Lines changed: 685 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-samples/timm/resnetaa50d.d_in12k
+# samples/timm/resnetaa50d.d_in12k
 samples/transformers-auto-model/opus-mt-en-gmw
-samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
+# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py

Lines changed: 158 additions & 19 deletions
@@ -1,11 +1,17 @@
 from graph_net.dynamic_dim_constraints import DynamicDimConstraints
+from contextlib import AbstractContextManager
 from graph_net.imp_util import load_module
 from graph_net.tensor_meta import TensorMeta
 from typing import Callable
 import functools
 import copy
 import sys
 import os
+from contextlib import contextmanager
+import tempfile
+import shutil
+from pathlib import Path
+import json


 class UpdateInputTensorConstraints:
@@ -17,6 +23,7 @@ def __init__(self, config=None):
         self.model_runnable_predicator = self._make_model_runnable_predicator(
             self.config
         )
+        self.num_successful_handled_models = 0

     def _make_data_input_predicator(self, config):
         module = load_module(config["data_input_predicator_filepath"])
@@ -33,25 +40,37 @@ def _make_config(
         data_input_predicator_filepath,
         model_runnable_predicator_filepath,
         data_input_predicator_class_name="DataInputPredicator",
-        data_input_predicator_config=None,
         model_runnable_predicator_class_name="ModelRunner",
+        data_input_predicator_config=None,
         model_runnable_predicator_config=None,
+        dimension_generalizer_filepath=None,
+        dimension_generalizer_class_name="StaticToDynamic",
+        dimension_generalizer_config=None,
         model_path_prefix="",
         resume=False,
+        last_model_log_file=None,
+        limits_successfully_handled_models=None,
     ):
         if data_input_predicator_config is None:
             data_input_predicator_config = {}
         if model_runnable_predicator_config is None:
             model_runnable_predicator_config = {}
+        if dimension_generalizer_config is None:
+            dimension_generalizer_config = {}
         return {
+            "resume": resume,
+            "model_path_prefix": model_path_prefix,
             "data_input_predicator_filepath": data_input_predicator_filepath,
             "data_input_predicator_class_name": data_input_predicator_class_name,
             "data_input_predicator_config": data_input_predicator_config,
             "model_runnable_predicator_filepath": model_runnable_predicator_filepath,
             "model_runnable_predicator_class_name": model_runnable_predicator_class_name,
             "model_runnable_predicator_config": model_runnable_predicator_config,
-            "model_path_prefix": model_path_prefix,
-            "resume": resume,
+            "dimension_generalizer_filepath": dimension_generalizer_filepath,
+            "dimension_generalizer_class_name": dimension_generalizer_class_name,
+            "dimension_generalizer_config": dimension_generalizer_config,
+            "last_model_log_file": last_model_log_file,
+            "limits_successfully_handled_models": limits_successfully_handled_models,
         }

     def __call__(self, model_path):
@@ -74,17 +93,80 @@ def __call__(self, model_path):
         def data_input_predicator(input_var_name):
             return self.data_input_predicator(model_path, input_var_name)

-        def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
-            return self._is_dyn_dim_cstr_feasible(
-                model_path, tensor_metas, dyn_dim_cstr
+        def get_tmp_model_path_ctx_mgr(dim_axes_pairs):
+            return self._try_dimension_generalization(
+                dim_axes_pairs, model_path, tensor_metas
             )

-        dyn_dim_cstr = symbolize_data_input_dims(
+        def get_predicator_is_dyn_dim_cstr_feasible(tmp_model_path):
+            def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
+                return self._is_dyn_dim_cstr_feasible(
+                    tmp_model_path, tensor_metas, dyn_dim_cstr
+                )
+
+            return is_dyn_dim_cstr_feasible
+
+        dyn_dim_cstr_feasibility_ctx_mgr = DynDimCstrFeasibilityContextManager(
+            get_tmp_model_path_ctx_mgr=get_tmp_model_path_ctx_mgr,
+            get_predicator_is_dyn_dim_cstr_feasible=get_predicator_is_dyn_dim_cstr_feasible,
+        )
+        dyn_dim_cstr, dim_gen_pass_names = symbolize_data_input_dims(
             dyn_dim_cstr,
             is_data_input=data_input_predicator,
-            is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
+            dyn_dim_cstr_feasibility_ctx_mgr=dyn_dim_cstr_feasibility_ctx_mgr,
         )
         self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
+        self._save_dim_gen_pass_names(dim_gen_pass_names, model_path)
+        if len(dyn_dim_cstr.symbols) > 0:
+            self.num_successful_handled_models += 1
+        limits = self.config["limits_successfully_handled_models"]
+        if limits is not None:
+            if self.num_successful_handled_models > limits:
+                print(
+                    "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`",
+                    file=sys.stderr,
+                )
+                sys.exit(0)
+
+    @contextmanager
+    def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas):
+        if self.config["dimension_generalizer_filepath"] is None:
+            yield model_path, ()
+            return
+        py_module = load_module(os.path.join(model_path, "model.py"))
+        GraphModule = getattr(py_module, "GraphModule")
+        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
+        model = GraphModule()
+        decorator_cls = getattr(
+            load_module(self.config["dimension_generalizer_filepath"]),
+            self.config["dimension_generalizer_class_name"],
+        )
+        dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
+        dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
+        if not dim_gen_pass.need_rewrite():
+            yield model_path, ()
+            return
+        from dataclasses import asdict
+
+        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
+        graph_module = dim_gen_pass.rewrite_with_tensor_meta_attrs_list(
+            tensor_meta_attrs_list=tensor_meta_attrs_list,
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
+            dim_gen_pass.save_graph_module(graph_module, tmp_dir)
+            if self.config["last_model_log_file"] is not None:
+                log_file = Path(self.config["last_model_log_file"])
+                shutil.copy(Path(tmp_dir) / "model.py", log_file)
+            yield tmp_dir, dim_gen_pass.get_pass_names()
+
+    def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
+        from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
+
+        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+        graph_net_json = json.loads(graph_net_json_file_path.read_text())
+        graph_net_json[kDimensionGeneralizationPasses] = list(dim_gen_pass_names)
+        graph_net_json_file_path.write_text(json.dumps(graph_net_json))

     def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
         cstr_code = dyn_dim_cstr.serialize_to_py_str()
@@ -106,7 +188,6 @@ def _is_dyn_dim_cstr_feasible(
         weight_meta_code = "\n".join(
             tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
         )
-        import tempfile

         with tempfile.TemporaryDirectory() as tmpdir:
             for filename in ["graph_net.json", "model.py"]:
@@ -145,30 +226,82 @@ def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
     )


+class DynDimCstrFeasibilityPredicator:
+    def __init__(
+        self,
+        is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
+        dim_gen_pass_names: tuple[str],
+    ):
+        self.is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible
+        self.dim_gen_pass_names = dim_gen_pass_names
+
+    def __call__(self, dyn_dim_cstr: DynamicDimConstraints) -> bool:
+        return self.is_dyn_dim_cstr_feasible(dyn_dim_cstr)
+
+
+class DynDimCstrFeasibilityContextManager:
+    def __init__(
+        self,
+        get_tmp_model_path_ctx_mgr,
+        get_predicator_is_dyn_dim_cstr_feasible,
+    ):
+        self.get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr
+        self.get_predicator_is_dyn_dim_cstr_feasible = (
+            get_predicator_is_dyn_dim_cstr_feasible
+        )
+
+    @contextmanager
+    def __call__(
+        self, dim_axes_pairs
+    ) -> AbstractContextManager[DynDimCstrFeasibilityPredicator]:
+        ctx_mgr = self.get_tmp_model_path_ctx_mgr
+        with ctx_mgr(dim_axes_pairs) as (tmp_model_apth, dg_pass_names):
+            predicator = self.get_predicator_is_dyn_dim_cstr_feasible(tmp_model_apth)
+            yield DynDimCstrFeasibilityPredicator(predicator, dg_pass_names)
+
+
 def symbolize_data_input_dims(
     dyn_dim_cstr: DynamicDimConstraints,
     is_data_input: Callable[[str], bool],
-    is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
-) -> DynamicDimConstraints | None:
+    dyn_dim_cstr_feasibility_ctx_mgr: DynDimCstrFeasibilityContextManager,
+) -> (DynamicDimConstraints | None, tuple[str]):
     """
     is_data_input: Callable[["input_var_name:str"], bool]
     Symbolizes data input dimensions as much as possible.
     Returns new DynamicDimConstraints if success.
     Returns None if no symbolicable dim .
     """
     unqiue_dims = []
+    dim2axes = {}

     def dumpy_filter_fn(input_name, input_idx, axis, dim):
         if is_data_input(input_name):
             print("data_input", input_name, input_idx, axis, dim)
             if dim not in unqiue_dims:
                 unqiue_dims.append(dim)
-        # No symbolization because of returning True
+                dim2axes[dim] = []
+            dim2axes[dim].append(axis)
+        # No symbolization by returning False
         return False

     # Collect input dimensions into `unqiue_dims`
     assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
-    for picked_dim in unqiue_dims:
+    total_dim_gen_pass_names = ()
+
+    def append_dim_gen_pass_names(dim_gen_pass_names):
+        nonlocal total_dim_gen_pass_names
+        total_dim_gen_pass_names = tuple(
+            [
+                *total_dim_gen_pass_names,
+                *(
+                    pass_name
+                    for pass_name in dim_gen_pass_names
+                    if pass_name not in total_dim_gen_pass_names
+                ),
+            ]
+        )
+
+    for i, picked_dim in enumerate(unqiue_dims):
         cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)

         def filter_fn(input_name, input_idx, axis, dim):
@@ -184,9 +317,15 @@ def filter_fn(input_name, input_idx, axis, dim):
         sym2example_value = {symbol: picked_dim + 1}
         if not cur_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
             continue
-        tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
-        tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
-        if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
-            continue
-        dyn_dim_cstr = cur_dyn_dim_cstr
-    return dyn_dim_cstr
+        dim_axes_pairs = tuple(
+            (dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
+        )
+        ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
+        with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
+            tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
+            tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
+            if not dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr):
+                continue
+            dyn_dim_cstr = cur_dyn_dim_cstr
+            append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
+    return dyn_dim_cstr, total_dim_gen_pass_names
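For orientation, here is a minimal sketch, not part of the commit, of how a caller drives the reworked `symbolize_data_input_dims`: instead of passing a bare feasibility predicate, it now passes a `DynDimCstrFeasibilityContextManager` and gets back the constraints together with the dimension-generalization pass names. The stand-in callables and the pre-built `dyn_dim_cstr` below are assumptions; in `UpdateInputTensorConstraints.__call__` above they are built from the model path and its tensor metas.

```python
# Hypothetical usage sketch of the reworked constraint_util API (stand-in callables).
from contextlib import contextmanager
from graph_net.constraint_util import (
    DynDimCstrFeasibilityContextManager,
    symbolize_data_input_dims,
)

@contextmanager
def get_tmp_model_path_ctx_mgr(dim_axes_pairs):
    # Stand-in for _try_dimension_generalization: yields a (possibly rewritten)
    # model path plus the names of the generalization passes it applied.
    yield "/tmp/generalized_model", ()

def get_predicator_is_dyn_dim_cstr_feasible(tmp_model_path):
    # Stand-in feasibility check; the real _is_dyn_dim_cstr_feasible exercises
    # the model under tmp_model_path against the candidate constraints.
    return lambda dyn_dim_cstr: True

ctx_mgr = DynDimCstrFeasibilityContextManager(
    get_tmp_model_path_ctx_mgr=get_tmp_model_path_ctx_mgr,
    get_predicator_is_dyn_dim_cstr_feasible=get_predicator_is_dyn_dim_cstr_feasible,
)
# dyn_dim_cstr is assumed to exist already, e.g. built via
# make_dyn_dim_cstr_from_tensor_metas(tensor_metas).
dyn_dim_cstr, dim_gen_pass_names = symbolize_data_input_dims(
    dyn_dim_cstr,
    is_data_input=lambda input_var_name: True,  # treat every input as a data input
    dyn_dim_cstr_feasibility_ctx_mgr=ctx_mgr,
)
```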
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+kDimensionGeneralizationPasses = "dimension_generalization_passes"
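This constant is the JSON key that `_save_dim_gen_pass_names` above writes into each model's `graph_net.json`. A small hypothetical sketch of reading it back (the model path is a placeholder):

```python
# Hypothetical sketch: inspect which dimension-generalization passes were
# recorded for a model by _save_dim_gen_pass_names (placeholder model path).
import json
from pathlib import Path

graph_net_json = json.loads(Path("samples/some-model/graph_net.json").read_text())
print(graph_net_json.get("dimension_generalization_passes", []))
```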

graph_net/imp_util.py

Lines changed: 1 addition & 0 deletions
@@ -5,4 +5,5 @@ def load_module(path, name="unamed"):
     spec = imp.spec_from_file_location(name, path)
     module = imp.module_from_spec(spec)
     spec.loader.exec_module(module)
+    module.__graph_net_file_path__ = path
     return module
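A small usage sketch, not from the commit: every module loaded through `load_module` now records the path it was loaded from, which `_try_dimension_generalization` in `constraint_util.py` copies onto the `GraphModule` class before instantiating it. The path below is a placeholder.

```python
# Hypothetical sketch: the loaded module now carries its source path.
from graph_net.imp_util import load_module

py_module = load_module("samples/some-model/model.py")  # placeholder path
print(py_module.__graph_net_file_path__)  # -> samples/some-model/model.py
```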

graph_net/model_path_handler.py

Lines changed: 6 additions & 1 deletion
@@ -1,3 +1,4 @@
+import traceback
 import argparse
 import importlib.util
 from graph_net.imp_util import load_module
@@ -44,7 +45,11 @@ def main(args):
     except KeyboardInterrupt:
         sys.exit(-1)
     except Exception as e:
-        pass
+        print("--- Concise Error Message ---")
+        print(e)
+
+        print("\n--- Full Traceback ---")
+        traceback.print_exc()


 def _get_model_paths(args):
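The handler previously swallowed per-model exceptions with `pass`; it now reports them instead of hiding them, while still not aborting the run. The same pattern in isolation, as a standalone sketch (the raised error is a placeholder):

```python
# Standalone sketch of the error-reporting pattern now used in main():
# print a concise message first, then the full traceback.
import sys
import traceback

try:
    raise ValueError("example failure")  # placeholder for per-model work
except KeyboardInterrupt:
    sys.exit(-1)
except Exception as e:
    print("--- Concise Error Message ---")
    print(e)

    print("\n--- Full Traceback ---")
    traceback.print_exc()
```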
