@@ -90,12 +90,12 @@ def __call__(self, model_path):
9090
9191 tensor_metas = self ._get_tensor_metas (model_path )
9292 tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
93- logging .warning (f "before create_inputs_by_metas" )
93+ logging .warning ("before create_inputs_by_metas" )
9494 inputs = self .get_dimension_generalizer ().create_inputs_by_metas (
9595 module = self .get_model (model_path ),
9696 tensor_meta_attrs_list = tensor_meta_attrs_list ,
9797 )
98- logging .warning (f "after create_inputs_by_metas" )
98+ logging .warning ("after create_inputs_by_metas" )
9999 dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas (tensor_metas )
100100
101101 def data_input_predicator (input_var_name ):
@@ -157,23 +157,23 @@ def get_model(self, model_path):
157157
158158 @contextmanager
159159 def _try_dimension_generalization (self , dim_axes_pairs , model_path , inputs ):
160- logging .warning (f "enter _try_dimension_generalization" )
160+ logging .warning ("enter _try_dimension_generalization" )
161161 if self .config ["dimension_generalizer_filepath" ] is None :
162162 yield model_path , ()
163163 return
164164 model = self .get_model (model_path )
165165 dim_generalizer = self .get_dimension_generalizer ()
166166 dim_gen_pass = dim_generalizer (model , dim_axes_pairs )
167- logging .warning (f "before need_rewrite" )
167+ logging .warning ("before need_rewrite" )
168168 need_rewrite = dim_gen_pass .need_rewrite (inputs )
169- logging .warning (f "after need_rewrite" )
169+ logging .warning ("after need_rewrite" )
170170 if not need_rewrite :
171171 yield model_path , ()
172172 return
173173
174- logging .warning (f "before rewrite" )
174+ logging .warning ("before rewrite" )
175175 graph_module = dim_gen_pass .rewrite (inputs )
176- logging .warning (f "after rewrite" )
176+ logging .warning ("after rewrite" )
177177 with tempfile .TemporaryDirectory () as tmp_dir :
178178 shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
179179 dim_gen_pass .save_graph_module (graph_module , tmp_dir )
@@ -293,20 +293,20 @@ def symbolize_data_input_dims(
293293 Returns new DynamicDimConstraints if success.
294294 Returns None if no symbolicable dim .
295295 """
296- unqiue_dims = []
296+ unique_dims = []
297297 dim2axes = {}
298298
299299 def dumpy_filter_fn (input_name , input_idx , axis , dim ):
300300 if is_data_input (input_name ):
301301 print ("data_input" , input_name , input_idx , axis , dim )
302- if dim not in unqiue_dims :
303- unqiue_dims .append (dim )
302+ if dim not in unique_dims :
303+ unique_dims .append (dim )
304304 dim2axes [dim ] = []
305305 dim2axes [dim ].append (axis )
306306 # No symbolization by returning False
307307 return False
308308
309- # Collect input dimensions into `unqiue_dims `
309+ # Collect input dimensions into `unique_dims `
310310 assert dyn_dim_cstr .symbolize (dumpy_filter_fn ) is None
311311 total_dim_gen_pass_names = ()
312312
@@ -323,7 +323,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
323323 ]
324324 )
325325
326- for i , picked_dim in enumerate (unqiue_dims ):
326+ for i , picked_dim in enumerate (unique_dims ):
327327 logging .warning (f"{ i = } { picked_dim = } " )
328328 cur_dyn_dim_cstr = copy .deepcopy (dyn_dim_cstr )
329329
@@ -341,20 +341,20 @@ def filter_fn(input_name, input_idx, axis, dim):
341341 if not cur_dyn_dim_cstr .check_delta_symbol2example_value (sym2example_value ):
342342 continue
343343 dim_axes_pairs = tuple (
344- (dim , axes ) for dim in unqiue_dims [: i + 1 ] for axes in [dim2axes [dim ]]
344+ (dim , axes ) for dim in unique_dims [: i + 1 ] for axes in [dim2axes [dim ]]
345345 )
346346 ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347- logging .warning (f "before dyn_dim_cstr_feasibility_ctx_mgr" )
347+ logging .warning ("before dyn_dim_cstr_feasibility_ctx_mgr" )
348348 with ctx_mgr (dim_axes_pairs ) as dyn_dim_cstr_feasibility :
349- logging .warning (f "enter dyn_dim_cstr_feasibility_ctx_mgr" )
349+ logging .warning ("enter dyn_dim_cstr_feasibility_ctx_mgr" )
350350 tmp_dyn_dim_cstr = copy .deepcopy (cur_dyn_dim_cstr )
351351 tmp_dyn_dim_cstr .update_symbol2example_value (sym2example_value )
352- logging .warning (f "before dyn_dim_cstr_feasibility" )
352+ logging .warning ("before dyn_dim_cstr_feasibility" )
353353 is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility (tmp_dyn_dim_cstr )
354- logging .warning (f "after dyn_dim_cstr_feasibility" )
354+ logging .warning ("after dyn_dim_cstr_feasibility" )
355355 if not is_dyn_dim_cstr_feasible :
356356 continue
357357 dyn_dim_cstr = cur_dyn_dim_cstr
358358 append_dim_gen_pass_names (dyn_dim_cstr_feasibility .dim_gen_pass_names )
359- logging .warning (f "leave dyn_dim_cstr_feasibility_ctx_mgr" )
359+ logging .warning ("leave dyn_dim_cstr_feasibility_ctx_mgr" )
360360 return dyn_dim_cstr , total_dim_gen_pass_names
0 commit comments