1+ import logging
12from graph_net .dynamic_dim_constraints import DynamicDimConstraints
23from contextlib import AbstractContextManager
34from graph_net .imp_util import load_module
@@ -88,14 +89,21 @@ def __call__(self, model_path):
8889 return
8990
9091 tensor_metas = self ._get_tensor_metas (model_path )
92+ tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
93+ logging .warning (f"before create_inputs_by_metas" )
94+ inputs = self .get_dimension_generalizer ().create_inputs_by_metas (
95+ module = self .get_model (model_path ),
96+ tensor_meta_attrs_list = tensor_meta_attrs_list ,
97+ )
98+ logging .warning (f"after create_inputs_by_metas" )
9199 dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas (tensor_metas )
92100
93101 def data_input_predicator (input_var_name ):
94102 return self .data_input_predicator (model_path , input_var_name )
95103
96104 def get_tmp_model_path_ctx_mgr (dim_axes_pairs ):
97105 return self ._try_dimension_generalization (
98- dim_axes_pairs , model_path , tensor_metas
106+ dim_axes_pairs , model_path , inputs
99107 )
100108
101109 def get_predicator_is_dyn_dim_cstr_feasible (tmp_model_path ):
@@ -128,28 +136,44 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
128136 )
129137 sys .exit (0 )
130138
131- @contextmanager
132- def _try_dimension_generalization (self , dim_axes_pairs , model_path , tensor_metas ):
133- if self .config ["dimension_generalizer_filepath" ] is None :
134- yield model_path , ()
135- return
136- py_module = load_module (os .path .join (model_path , "model.py" ))
137- GraphModule = getattr (py_module , "GraphModule" )
138- GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
139- model = GraphModule ()
139+ def get_dimension_generalizer (self ):
140+ if hasattr (self , "_dim_generalizer" ):
141+ return self ._dim_generalizer
142+ assert self .config ["dimension_generalizer_filepath" ] is not None
140143 decorator_cls = getattr (
141144 load_module (self .config ["dimension_generalizer_filepath" ]),
142145 self .config ["dimension_generalizer_class_name" ],
143146 )
144- dim_generalizer = decorator_cls (self .config ["dimension_generalizer_config" ])
147+ self ._dim_generalizer = decorator_cls (
148+ self .config ["dimension_generalizer_config" ]
149+ )
150+ return self ._dim_generalizer
151+
152+ def get_model (self , model_path ):
153+ py_module = load_module (os .path .join (model_path , "model.py" ))
154+ GraphModule = getattr (py_module , "GraphModule" )
155+ GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
156+ return GraphModule ()
157+
158+ @contextmanager
159+ def _try_dimension_generalization (self , dim_axes_pairs , model_path , inputs ):
160+ logging .warning (f"enter _try_dimension_generalization" )
161+ if self .config ["dimension_generalizer_filepath" ] is None :
162+ yield model_path , ()
163+ return
164+ model = self .get_model (model_path )
165+ dim_generalizer = self .get_dimension_generalizer ()
145166 dim_gen_pass = dim_generalizer (model , dim_axes_pairs )
146- tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
147- inputs = dim_gen_pass .create_inputs_by_metas (tensor_meta_attrs_list )
148- if not dim_gen_pass .need_rewrite (inputs ):
167+ logging .warning (f"before need_rewrite" )
168+ need_rewrite = dim_gen_pass .need_rewrite (inputs )
169+ logging .warning (f"after need_rewrite" )
170+ if not need_rewrite :
149171 yield model_path , ()
150172 return
151173
174+ logging .warning (f"before rewrite" )
152175 graph_module = dim_gen_pass .rewrite (inputs )
176+ logging .warning (f"after rewrite" )
153177 with tempfile .TemporaryDirectory () as tmp_dir :
154178 shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
155179 dim_gen_pass .save_graph_module (graph_module , tmp_dir )
@@ -300,6 +324,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
300324 )
301325
302326 for i , picked_dim in enumerate (unqiue_dims ):
327+ logging .warning (f"{ i = } { picked_dim = } " )
303328 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
304329
305330 def filter_fn (input_name , input_idx , axis , dim ):
@@ -319,11 +344,17 @@ def filter_fn(input_name, input_idx, axis, dim):
319344 (dim , axes ) for dim in unqiue_dims [: i + 1 ] for axes in [dim2axes [dim ]]
320345 )
321346 ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347+ logging .warning (f"before dyn_dim_cstr_feasibility_ctx_mgr" )
322348 with ctx_mgr (dim_axes_pairs ) as dyn_dim_cstr_feasibility :
349+ logging .warning (f"enter dyn_dim_cstr_feasibility_ctx_mgr" )
323350 tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
324351 tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
325- if not dyn_dim_cstr_feasibility (tmp_dyn_dim_cstr ):
352+ logging .warning (f"before dyn_dim_cstr_feasibility" )
353+ is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility (tmp_dyn_dim_cstr )
354+ logging .warning (f"after dyn_dim_cstr_feasibility" )
355+ if not is_dyn_dim_cstr_feasible :
326356 continue
327357 dyn_dim_cstr = cur_dyn_dim_cstr
328358 append_dim_gen_pass_names (dyn_dim_cstr_feasibility .dim_gen_pass_names )
359+ logging .warning (f"leave dyn_dim_cstr_feasibility_ctx_mgr" )
329360 return dyn_dim_cstr , total_dim_gen_pass_names
0 commit comments