11from graph_net .dynamic_dim_constraints import DynamicDimConstraints
2+ from contextlib import AbstractContextManager
23from graph_net .imp_util import load_module
34from graph_net .tensor_meta import TensorMeta
45from typing import Callable
56import functools
67import copy
78import sys
89import os
10+ from contextlib import contextmanager
11+ import tempfile
12+ import shutil
13+ from pathlib import Path
14+ import json
915
1016
1117class UpdateInputTensorConstraints :
@@ -17,6 +23,7 @@ def __init__(self, config=None):
1723 self .model_runnable_predicator = self ._make_model_runnable_predicator (
1824 self .config
1925 )
26+ self .num_successful_handled_models = 0
2027
2128 def _make_data_input_predicator (self , config ):
2229 module = load_module (config ["data_input_predicator_filepath" ])
@@ -33,25 +40,37 @@ def _make_config(
3340 data_input_predicator_filepath ,
3441 model_runnable_predicator_filepath ,
3542 data_input_predicator_class_name = "DataInputPredicator" ,
36- data_input_predicator_config = None ,
3743 model_runnable_predicator_class_name = "ModelRunner" ,
44+ data_input_predicator_config = None ,
3845 model_runnable_predicator_config = None ,
46+ dimension_generalizer_filepath = None ,
47+ dimension_generalizer_class_name = "StaticToDynamic" ,
48+ dimension_generalizer_config = None ,
3949 model_path_prefix = "" ,
4050 resume = False ,
51+ last_model_log_file = None ,
52+ limits_successfully_handled_models = None ,
4153 ):
4254 if data_input_predicator_config is None :
4355 data_input_predicator_config = {}
4456 if model_runnable_predicator_config is None :
4557 model_runnable_predicator_config = {}
58+ if dimension_generalizer_config is None :
59+ dimension_generalizer_config = {}
4660 return {
61+ "resume" : resume ,
62+ "model_path_prefix" : model_path_prefix ,
4763 "data_input_predicator_filepath" : data_input_predicator_filepath ,
4864 "data_input_predicator_class_name" : data_input_predicator_class_name ,
4965 "data_input_predicator_config" : data_input_predicator_config ,
5066 "model_runnable_predicator_filepath" : model_runnable_predicator_filepath ,
5167 "model_runnable_predicator_class_name" : model_runnable_predicator_class_name ,
5268 "model_runnable_predicator_config" : model_runnable_predicator_config ,
53- "model_path_prefix" : model_path_prefix ,
54- "resume" : resume ,
69+ "dimension_generalizer_filepath" : dimension_generalizer_filepath ,
70+ "dimension_generalizer_class_name" : dimension_generalizer_class_name ,
71+ "dimension_generalizer_config" : dimension_generalizer_config ,
72+ "last_model_log_file" : last_model_log_file ,
73+ "limits_successfully_handled_models" : limits_successfully_handled_models ,
5574 }
5675
5776 def __call__ (self , model_path ):
@@ -74,17 +93,80 @@ def __call__(self, model_path):
7493 def data_input_predicator (input_var_name ):
7594 return self .data_input_predicator (model_path , input_var_name )
7695
77- def is_dyn_dim_cstr_feasible ( dyn_dim_cstr ):
78- return self ._is_dyn_dim_cstr_feasible (
79- model_path , tensor_metas , dyn_dim_cstr
96+ def get_tmp_model_path_ctx_mgr ( dim_axes_pairs ):
97+ return self ._try_dimension_generalization (
98+ dim_axes_pairs , model_path , tensor_metas
8099 )
81100
82- dyn_dim_cstr = symbolize_data_input_dims (
101+ def get_predicator_is_dyn_dim_cstr_feasible (tmp_model_path ):
102+ def is_dyn_dim_cstr_feasible (dyn_dim_cstr ):
103+ return self ._is_dyn_dim_cstr_feasible (
104+ tmp_model_path , tensor_metas , dyn_dim_cstr
105+ )
106+
107+ return is_dyn_dim_cstr_feasible
108+
109+ dyn_dim_cstr_feasibility_ctx_mgr = DynDimCstrFeasibilityContextManager (
110+ get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr ,
111+ get_predicator_is_dyn_dim_cstr_feasible = get_predicator_is_dyn_dim_cstr_feasible ,
112+ )
113+ dyn_dim_cstr , dim_gen_pass_names = symbolize_data_input_dims (
83114 dyn_dim_cstr ,
84115 is_data_input = data_input_predicator ,
85- is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible ,
116+ dyn_dim_cstr_feasibility_ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr ,
86117 )
87118 self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
119+ self ._save_dim_gen_pass_names (dim_gen_pass_names , model_path )
120+ if len (dyn_dim_cstr .symbols ) > 0 :
121+ self .num_successful_handled_models += 1
122+ limits = self .config ["limits_successfully_handled_models" ]
123+ if limits is not None :
124+ if self .num_successful_handled_models > limits :
125+ print (
126+ "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`" ,
127+ file = sys .stderr ,
128+ )
129+ sys .exit (0 )
130+
131+ @contextmanager
132+ def _try_dimension_generalization (self , dim_axes_pairs , model_path , tensor_metas ):
133+ if self .config ["dimension_generalizer_filepath" ] is None :
134+ yield model_path , ()
135+ return
136+ py_module = load_module (os .path .join (model_path , "model.py" ))
137+ GraphModule = getattr (py_module , "GraphModule" )
138+ GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
139+ model = GraphModule ()
140+ decorator_cls = getattr (
141+ load_module (self .config ["dimension_generalizer_filepath" ]),
142+ self .config ["dimension_generalizer_class_name" ],
143+ )
144+ dim_generalizer = decorator_cls (self .config ["dimension_generalizer_config" ])
145+ dim_gen_pass = dim_generalizer (model , dim_axes_pairs )
146+ if not dim_gen_pass .need_rewrite ():
147+ yield model_path , ()
148+ return
149+ from dataclasses import asdict
150+
151+ tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
152+ graph_module = dim_gen_pass .rewrite_with_tensor_meta_attrs_list (
153+ tensor_meta_attrs_list = tensor_meta_attrs_list ,
154+ )
155+ with tempfile .TemporaryDirectory () as tmp_dir :
156+ shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
157+ dim_gen_pass .save_graph_module (graph_module , tmp_dir )
158+ if self .config ["last_model_log_file" ] is not None :
159+ log_file = Path (self .config ["last_model_log_file" ])
160+ shutil .copy (Path (tmp_dir ) / "model.py" , log_file )
161+ yield tmp_dir , dim_gen_pass .get_pass_names ()
162+
163+ def _save_dim_gen_pass_names (self , dim_gen_pass_names , model_path ):
164+ from graph_net .graph_net_json_file_util import kDimensionGeneralizationPasses
165+
166+ graph_net_json_file_path = Path (f"{ model_path } /graph_net.json" )
167+ graph_net_json = json .loads (graph_net_json_file_path .read_text ())
168+ graph_net_json [kDimensionGeneralizationPasses ] = list (dim_gen_pass_names )
169+ graph_net_json_file_path .write_text (json .dumps (graph_net_json ))
88170
89171 def _save_dyn_dim_cstr (self , dyn_dim_cstr , model_path ):
90172 cstr_code = dyn_dim_cstr .serialize_to_py_str ()
@@ -106,7 +188,6 @@ def _is_dyn_dim_cstr_feasible(
106188 weight_meta_code = "\n " .join (
107189 tensor_meta .serialize_to_py_str () for tensor_meta in tensor_metas
108190 )
109- import tempfile
110191
111192 with tempfile .TemporaryDirectory () as tmpdir :
112193 for filename in ["graph_net.json" , "model.py" ]:
@@ -145,30 +226,82 @@ def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
145226 )
146227
147228
229+ class DynDimCstrFeasibilityPredicator :
230+ def __init__ (
231+ self ,
232+ is_dyn_dim_cstr_feasible : Callable [[DynamicDimConstraints ], bool ],
233+ dim_gen_pass_names : tuple [str ],
234+ ):
235+ self .is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible
236+ self .dim_gen_pass_names = dim_gen_pass_names
237+
238+ def __call__ (self , dyn_dim_cstr : DynamicDimConstraints ) -> bool :
239+ return self .is_dyn_dim_cstr_feasible (dyn_dim_cstr )
240+
241+
242+ class DynDimCstrFeasibilityContextManager :
243+ def __init__ (
244+ self ,
245+ get_tmp_model_path_ctx_mgr ,
246+ get_predicator_is_dyn_dim_cstr_feasible ,
247+ ):
248+ self .get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr
249+ self .get_predicator_is_dyn_dim_cstr_feasible = (
250+ get_predicator_is_dyn_dim_cstr_feasible
251+ )
252+
253+ @contextmanager
254+ def __call__ (
255+ self , dim_axes_pairs
256+ ) -> AbstractContextManager [DynDimCstrFeasibilityPredicator ]:
257+ ctx_mgr = self .get_tmp_model_path_ctx_mgr
258+ with ctx_mgr (dim_axes_pairs ) as (tmp_model_apth , dg_pass_names ):
259+ predicator = self .get_predicator_is_dyn_dim_cstr_feasible (tmp_model_apth )
260+ yield DynDimCstrFeasibilityPredicator (predicator , dg_pass_names )
261+
262+
148263def symbolize_data_input_dims (
149264 dyn_dim_cstr : DynamicDimConstraints ,
150265 is_data_input : Callable [[str ], bool ],
151- is_dyn_dim_cstr_feasible : Callable [[ DynamicDimConstraints ], bool ] ,
152- ) -> DynamicDimConstraints | None :
266+ dyn_dim_cstr_feasibility_ctx_mgr : DynDimCstrFeasibilityContextManager ,
267+ ) -> ( DynamicDimConstraints | None , tuple [ str ]) :
153268 """
154269 is_data_input: Callable[["input_var_name:str"], bool]
155270 Symbolizes data input dimensions as much as possible.
156271 Returns new DynamicDimConstraints if success.
157272 Returns None if no symbolicable dim .
158273 """
159274 unqiue_dims = []
275+ dim2axes = {}
160276
161277 def dumpy_filter_fn (input_name , input_idx , axis , dim ):
162278 if is_data_input (input_name ):
163279 print ("data_input" , input_name , input_idx , axis , dim )
164280 if dim not in unqiue_dims :
165281 unqiue_dims .append (dim )
166- # No symbolization because of returning True
282+ dim2axes [dim ] = []
283+ dim2axes [dim ].append (axis )
284+ # No symbolization by returning False
167285 return False
168286
169287 # Collect input dimensions into `unqiue_dims`
170288 assert dyn_dim_cstr .symbolize (dumpy_filter_fn ) is None
171- for picked_dim in unqiue_dims :
289+ total_dim_gen_pass_names = ()
290+
291+ def append_dim_gen_pass_names (dim_gen_pass_names ):
292+ nonlocal total_dim_gen_pass_names
293+ total_dim_gen_pass_names = tuple (
294+ [
295+ * total_dim_gen_pass_names ,
296+ * (
297+ pass_name
298+ for pass_name in dim_gen_pass_names
299+ if pass_name not in total_dim_gen_pass_names
300+ ),
301+ ]
302+ )
303+
304+ for i , picked_dim in enumerate (unqiue_dims ):
172305 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
173306
174307 def filter_fn (input_name , input_idx , axis , dim ):
@@ -184,9 +317,15 @@ def filter_fn(input_name, input_idx, axis, dim):
184317 sym2example_value = {symbol : picked_dim + 1 }
185318 if not cur_dyn_dim_cstr .check_delta_symbol2example_value (sym2example_value ):
186319 continue
187- tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
188- tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
189- if not is_dyn_dim_cstr_feasible (tmp_dyn_dim_cstr ):
190- continue
191- dyn_dim_cstr = cur_dyn_dim_cstr
192- return dyn_dim_cstr
320+ dim_axes_pairs = tuple (
321+ (dim , axes ) for dim in unqiue_dims [: i + 1 ] for axes in [dim2axes [dim ]]
322+ )
323+ ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
324+ with ctx_mgr (dim_axes_pairs ) as dyn_dim_cstr_feasibility :
325+ tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
326+ tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
327+ if not dyn_dim_cstr_feasibility (tmp_dyn_dim_cstr ):
328+ continue
329+ dyn_dim_cstr = cur_dyn_dim_cstr
330+ append_dim_gen_pass_names (dyn_dim_cstr_feasibility .dim_gen_pass_names )
331+ return dyn_dim_cstr , total_dim_gen_pass_names
0 commit comments