11from graph_net .dynamic_dim_constraints import DynamicDimConstraints
2+ from contextlib import AbstractContextManager
23from graph_net .imp_util import load_module
34from graph_net .tensor_meta import TensorMeta
45from typing import Callable
@@ -21,6 +22,7 @@ def __init__(self, config=None):
2122 self .model_runnable_predicator = self ._make_model_runnable_predicator (
2223 self .config
2324 )
25+ self .num_successful_handled_models = 0
2426
2527 def _make_data_input_predicator (self , config ):
2628 module = load_module (config ["data_input_predicator_filepath" ])
@@ -45,6 +47,8 @@ def _make_config(
4547 dimension_generalizer_config = None ,
4648 model_path_prefix = "" ,
4749 resume = False ,
50+ last_model_log_file = None ,
51+ limits_successfully_handled_models = None ,
4852 ):
4953 if data_input_predicator_config is None :
5054 data_input_predicator_config = {}
@@ -64,6 +68,8 @@ def _make_config(
6468 "dimension_generalizer_filepath" : dimension_generalizer_filepath ,
6569 "dimension_generalizer_class_name" : dimension_generalizer_class_name ,
6670 "dimension_generalizer_config" : dimension_generalizer_config ,
71+ "last_model_log_file" : last_model_log_file ,
72+ "limits_successfully_handled_models" : limits_successfully_handled_models ,
6773 }
6874
6975 def __call__ (self , model_path ):
@@ -86,24 +92,42 @@ def __call__(self, model_path):
8692 def data_input_predicator (input_var_name ):
8793 return self .data_input_predicator (model_path , input_var_name )
8894
89- with self ._try_dimension_generalization (
90- model_path , tensor_metas
91- ) as tmp_model_path :
95+ def get_tmp_model_path_ctx_mgr (dim_axes_pairs ):
96+ return self ._try_dimension_generalization (
97+ dim_axes_pairs , model_path , tensor_metas
98+ )
9299
100+ def get_predicator_is_dyn_dim_cstr_feasible (tmp_model_path ):
93101 def is_dyn_dim_cstr_feasible (dyn_dim_cstr ):
94102 return self ._is_dyn_dim_cstr_feasible (
95103 tmp_model_path , tensor_metas , dyn_dim_cstr
96104 )
97105
98- dyn_dim_cstr = symbolize_data_input_dims (
99- dyn_dim_cstr ,
100- is_data_input = data_input_predicator ,
101- is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible ,
102- )
103- self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
106+ return is_dyn_dim_cstr_feasible
107+
108+ dyn_dim_cstr_feasibility_ctx_mgr = DynDimCstrFeasibilityContextManager (
109+ get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr ,
110+ get_predicator_is_dyn_dim_cstr_feasible = get_predicator_is_dyn_dim_cstr_feasible ,
111+ )
112+ dyn_dim_cstr = symbolize_data_input_dims (
113+ dyn_dim_cstr ,
114+ is_data_input = data_input_predicator ,
115+ dyn_dim_cstr_feasibility_ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr ,
116+ )
117+ self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
118+ if len (dyn_dim_cstr .symbols ) > 0 :
119+ self .num_successful_handled_models += 1
120+ limits = self .config ["limits_successfully_handled_models" ]
121+ if limits is not None :
122+ if self .num_successful_handled_models > limits :
123+ print (
124+ "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`" ,
125+ file = sys .stderr ,
126+ )
127+ sys .exit (0 )
104128
105129 @contextmanager
106- def _try_dimension_generalization (self , model_path , tensor_metas ):
130+ def _try_dimension_generalization (self , dim_axes_pairs , model_path , tensor_metas ):
107131 if self .config ["dimension_generalizer_filepath" ] is None :
108132 yield model_path
109133 return
@@ -115,20 +139,23 @@ def _try_dimension_generalization(self, model_path, tensor_metas):
115139 load_module (self .config ["dimension_generalizer_filepath" ]),
116140 self .config ["dimension_generalizer_class_name" ],
117141 )
118- pass_obj = decorator_cls (self .config ["dimension_generalizer_config" ])(model )
119- if not pass_obj .need_rewrite ():
142+ dim_generalizer = decorator_cls (self .config ["dimension_generalizer_config" ])
143+ dim_gen_pass = dim_generalizer (model , dim_axes_pairs )
144+ if not dim_gen_pass .need_rewrite ():
120145 yield model_path
121146 return
122147 from dataclasses import asdict
123148
124149 tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
125- graph_module = pass_obj .rewrite_with_tensor_meta_attrs_list (
126- tensor_meta_attrs_list
150+ graph_module = dim_gen_pass .rewrite_with_tensor_meta_attrs_list (
151+ tensor_meta_attrs_list = tensor_meta_attrs_list ,
127152 )
128153 with tempfile .TemporaryDirectory () as tmp_dir :
129154 shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
130- pass_obj .save_graph_module (graph_module , tmp_dir )
131- shutil .copy (Path (tmp_dir ) / "model.py" , Path ("/tmp/a.py" ))
155+ dim_gen_pass .save_graph_module (graph_module , tmp_dir )
156+ if self .config ["last_model_log_file" ] is not None :
157+ log_file = Path (self .config ["last_model_log_file" ])
158+ shutil .copy (Path (tmp_dir ) / "model.py" , log_file )
132159 yield tmp_dir
133160 # shutil.copytree(Path(tmp_dir), Path(model_path), dirs_exist_ok=True)
134161
@@ -190,10 +217,40 @@ def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
190217 )
191218
192219
220+ class DynDimCstrFeasibilityPredicator :
221+ def __init__ (
222+ self , is_dyn_dim_cstr_feasible : Callable [[DynamicDimConstraints ], bool ]
223+ ):
224+ self .is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible
225+
226+ def __call__ (self , dyn_dim_cstr : DynamicDimConstraints ) -> bool :
227+ return self .is_dyn_dim_cstr_feasible (dyn_dim_cstr )
228+
229+
230+ class DynDimCstrFeasibilityContextManager :
231+ def __init__ (
232+ self ,
233+ get_tmp_model_path_ctx_mgr ,
234+ get_predicator_is_dyn_dim_cstr_feasible ,
235+ ):
236+ self .get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr
237+ self .get_predicator_is_dyn_dim_cstr_feasible = (
238+ get_predicator_is_dyn_dim_cstr_feasible
239+ )
240+
241+ @contextmanager
242+ def __call__ (
243+ self , dim_axes_pairs
244+ ) -> AbstractContextManager [DynDimCstrFeasibilityPredicator ]:
245+ with self .get_tmp_model_path_ctx_mgr (dim_axes_pairs ) as tmp_model_apth :
246+ predicator = self .get_predicator_is_dyn_dim_cstr_feasible (tmp_model_apth )
247+ yield DynDimCstrFeasibilityPredicator (predicator )
248+
249+
193250def symbolize_data_input_dims (
194251 dyn_dim_cstr : DynamicDimConstraints ,
195252 is_data_input : Callable [[str ], bool ],
196- is_dyn_dim_cstr_feasible : Callable [[ DynamicDimConstraints ], bool ] ,
253+ dyn_dim_cstr_feasibility_ctx_mgr : DynDimCstrFeasibilityContextManager ,
197254) -> DynamicDimConstraints | None :
198255 """
199256 is_data_input: Callable[["input_var_name:str"], bool]
@@ -202,18 +259,21 @@ def symbolize_data_input_dims(
202259 Returns None if no symbolicable dim .
203260 """
204261 unqiue_dims = []
262+ dim2axes = {}
205263
206264 def dumpy_filter_fn (input_name , input_idx , axis , dim ):
207265 if is_data_input (input_name ):
208266 print ("data_input" , input_name , input_idx , axis , dim )
209267 if dim not in unqiue_dims :
210268 unqiue_dims .append (dim )
211- # No symbolization because of returning True
269+ dim2axes [dim ] = []
270+ dim2axes [dim ].append (axis )
271+ # No symbolization by returning False
212272 return False
213273
214274 # Collect input dimensions into `unqiue_dims`
215275 assert dyn_dim_cstr .symbolize (dumpy_filter_fn ) is None
216- for picked_dim in unqiue_dims :
276+ for i , picked_dim in enumerate ( unqiue_dims ) :
217277 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
218278
219279 def filter_fn (input_name , input_idx , axis , dim ):
@@ -229,9 +289,15 @@ def filter_fn(input_name, input_idx, axis, dim):
229289 sym2example_value = {symbol : picked_dim + 1 }
230290 if not cur_dyn_dim_cstr .check_delta_symbol2example_value (sym2example_value ):
231291 continue
232- tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
233- tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
234- if not is_dyn_dim_cstr_feasible (tmp_dyn_dim_cstr ):
235- continue
236- dyn_dim_cstr = cur_dyn_dim_cstr
292+ dim_axes_pairs = tuple (
293+ (dim , axes ) for dim in unqiue_dims [: i + 1 ] for axes in [dim2axes [dim ]]
294+ )
295+ with dyn_dim_cstr_feasibility_ctx_mgr (
296+ dim_axes_pairs
297+ ) as is_dyn_dim_cstr_feasible :
298+ tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
299+ tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
300+ if not is_dyn_dim_cstr_feasible (tmp_dyn_dim_cstr ):
301+ continue
302+ dyn_dim_cstr = cur_dyn_dim_cstr
237303 return dyn_dim_cstr
0 commit comments