66import copy
77import sys
88import os
9+ from contextlib import contextmanager
10+ import tempfile
11+ import shutil
12+ from pathlib import Path
913
1014
1115class UpdateInputTensorConstraints :
@@ -33,25 +37,33 @@ def _make_config(
3337 data_input_predicator_filepath ,
3438 model_runnable_predicator_filepath ,
3539 data_input_predicator_class_name = "DataInputPredicator" ,
36- data_input_predicator_config = None ,
3740 model_runnable_predicator_class_name = "ModelRunner" ,
41+ data_input_predicator_config = None ,
3842 model_runnable_predicator_config = None ,
43+ dimension_generalizer_filepath = None ,
44+ dimension_generalizer_class_name = "StaticToDynamic" ,
45+ dimension_generalizer_config = None ,
3946 model_path_prefix = "" ,
4047 resume = False ,
4148 ):
4249 if data_input_predicator_config is None :
4350 data_input_predicator_config = {}
4451 if model_runnable_predicator_config is None :
4552 model_runnable_predicator_config = {}
53+ if dimension_generalizer_config is None :
54+ dimension_generalizer_config = {}
4655 return {
56+ "resume" : resume ,
57+ "model_path_prefix" : model_path_prefix ,
4758 "data_input_predicator_filepath" : data_input_predicator_filepath ,
4859 "data_input_predicator_class_name" : data_input_predicator_class_name ,
4960 "data_input_predicator_config" : data_input_predicator_config ,
5061 "model_runnable_predicator_filepath" : model_runnable_predicator_filepath ,
5162 "model_runnable_predicator_class_name" : model_runnable_predicator_class_name ,
5263 "model_runnable_predicator_config" : model_runnable_predicator_config ,
53- "model_path_prefix" : model_path_prefix ,
54- "resume" : resume ,
64+ "dimension_generalizer_filepath" : dimension_generalizer_filepath ,
65+ "dimension_generalizer_class_name" : dimension_generalizer_class_name ,
66+ "dimension_generalizer_config" : dimension_generalizer_config ,
5567 }
5668
5769 def __call__ (self , model_path ):
@@ -74,17 +86,51 @@ def __call__(self, model_path):
7486 def data_input_predicator (input_var_name ):
7587 return self .data_input_predicator (model_path , input_var_name )
7688
77- def is_dyn_dim_cstr_feasible (dyn_dim_cstr ):
78- return self ._is_dyn_dim_cstr_feasible (
79- model_path , tensor_metas , dyn_dim_cstr
80- )
89+ with self ._try_dimension_generalization (
90+ model_path , tensor_metas
91+ ) as tmp_model_path :
92+
93+ def is_dyn_dim_cstr_feasible (dyn_dim_cstr ):
94+ return self ._is_dyn_dim_cstr_feasible (
95+ tmp_model_path , tensor_metas , dyn_dim_cstr
96+ )
8197
82- dyn_dim_cstr = symbolize_data_input_dims (
83- dyn_dim_cstr ,
84- is_data_input = data_input_predicator ,
85- is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible ,
98+ dyn_dim_cstr = symbolize_data_input_dims (
99+ dyn_dim_cstr ,
100+ is_data_input = data_input_predicator ,
101+ is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible ,
102+ )
103+ self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
104+
105+ @contextmanager
106+ def _try_dimension_generalization (self , model_path , tensor_metas ):
107+ if self .config ["dimension_generalizer_filepath" ] is None :
108+ yield model_path
109+ return
110+ py_module = load_module (os .path .join (model_path , "model.py" ))
111+ GraphModule = getattr (py_module , "GraphModule" )
112+ GraphModule .__graph_net_file_path__ = py_module .__graph_net_file_path__
113+ model = GraphModule ()
114+ decorator_cls = getattr (
115+ load_module (self .config ["dimension_generalizer_filepath" ]),
116+ self .config ["dimension_generalizer_class_name" ],
117+ )
118+ pass_obj = decorator_cls (self .config ["dimension_generalizer_config" ])(model )
119+ if not pass_obj .need_rewrite ():
120+ yield model_path
121+ return
122+ from dataclasses import asdict
123+
124+ tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
125+ graph_module = pass_obj .rewrite_with_tensor_meta_attrs_list (
126+ tensor_meta_attrs_list
86127 )
87- self ._save_dyn_dim_cstr (dyn_dim_cstr , model_path )
128+ with tempfile .TemporaryDirectory () as tmp_dir :
129+ shutil .copytree (Path (model_path ), Path (tmp_dir ), dirs_exist_ok = True )
130+ pass_obj .save_graph_module (graph_module , tmp_dir )
131+ shutil .copy (Path (tmp_dir ) / "model.py" , Path ("/tmp/a.py" ))
132+ yield tmp_dir
133+ # shutil.copytree(Path(tmp_dir), Path(model_path), dirs_exist_ok=True)
88134
89135 def _save_dyn_dim_cstr (self , dyn_dim_cstr , model_path ):
90136 cstr_code = dyn_dim_cstr .serialize_to_py_str ()
@@ -106,7 +152,6 @@ def _is_dyn_dim_cstr_feasible(
106152 weight_meta_code = "\n " .join (
107153 tensor_meta .serialize_to_py_str () for tensor_meta in tensor_metas
108154 )
109- import tempfile
110155
111156 with tempfile .TemporaryDirectory () as tmpdir :
112157 for filename in ["graph_net.json" , "model.py" ]:
0 commit comments