@@ -90,12 +90,12 @@ def __call__(self, model_path):
9090
9191 tensor_metas = self ._get_tensor_metas (model_path )
9292 tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
93- logging .warning (f "before create_inputs_by_metas" )
93+ logging .warning ("before create_inputs_by_metas" )
9494 inputs = self .get_dimension_generalizer ().create_inputs_by_metas (
9595 module = self .get_model (model_path ),
9696 tensor_meta_attrs_list = tensor_meta_attrs_list ,
9797 )
98- logging .warning (f "after create_inputs_by_metas" )
98+ logging .warning ("after create_inputs_by_metas" )
9999 dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas (tensor_metas )
100100
101101 def data_input_predicator (input_var_name ):
@@ -157,23 +157,23 @@ def get_model(self, model_path):
157157
158158 @contextmanager
159159 def _try_dimension_generalization (self , dim_axes_pairs , model_path , inputs ):
160- logging .warning (f "enter _try_dimension_generalization" )
160+ logging .warning ("enter _try_dimension_generalization" )
161161 if self .config ["dimension_generalizer_filepath" ] is None :
162162 yield model_path , ()
163163 return
164164 model = self .get_model (model_path )
165165 dim_generalizer = self .get_dimension_generalizer ()
166166 dim_gen_pass = dim_generalizer (model , dim_axes_pairs )
167- logging .warning (f "before need_rewrite" )
167+ logging .warning ("before need_rewrite" )
168168 need_rewrite = dim_gen_pass .need_rewrite (inputs )
169- logging .warning (f "after need_rewrite" )
169+ logging .warning ("after need_rewrite" )
170170 if not need_rewrite :
171171 yield model_path , ()
172172 return
173173
174- logging .warning (f "before rewrite" )
174+ logging .warning ("before rewrite" )
175175 graph_module = dim_gen_pass .rewrite (inputs )
176- logging .warning (f "after rewrite" )
176+ logging .warning ("after rewrite" )
177177 with tempfile .TemporaryDirectory () as tmp_dir :
178178 shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
179179 dim_gen_pass .save_graph_module (graph_module , tmp_dir )
@@ -344,17 +344,17 @@ def filter_fn(input_name, input_idx, axis, dim):
344344 (dim , axes ) for dim in unqiue_dims [: i + 1 ] for axes in [dim2axes [dim ]]
345345 )
346346 ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347- logging .warning (f "before dyn_dim_cstr_feasibility_ctx_mgr" )
347+ logging .warning ("before dyn_dim_cstr_feasibility_ctx_mgr" )
348348 with ctx_mgr (dim_axes_pairs ) as dyn_dim_cstr_feasibility :
349- logging .warning (f "enter dyn_dim_cstr_feasibility_ctx_mgr" )
349+ logging .warning ("enter dyn_dim_cstr_feasibility_ctx_mgr" )
350350 tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
351351 tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
352- logging .warning (f "before dyn_dim_cstr_feasibility" )
352+ logging .warning ("before dyn_dim_cstr_feasibility" )
353353 is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility (tmp_dyn_dim_cstr )
354- logging .warning (f "after dyn_dim_cstr_feasibility" )
354+ logging .warning ("after dyn_dim_cstr_feasibility" )
355355 if not is_dyn_dim_cstr_feasible :
356356 continue
357357 dyn_dim_cstr = cur_dyn_dim_cstr
358358 append_dim_gen_pass_names (dyn_dim_cstr_feasibility .dim_gen_pass_names )
359- logging .warning (f "leave dyn_dim_cstr_feasibility_ctx_mgr" )
359+ logging .warning ("leave dyn_dim_cstr_feasibility_ctx_mgr" )
360360 return dyn_dim_cstr , total_dim_gen_pass_names
0 commit comments