File tree Expand file tree Collapse file tree 3 files changed +12
-9
lines changed
Expand file tree Collapse file tree 3 files changed +12
-9
lines changed Original file line number Diff line number Diff line change 55from dataclasses import asdict
66
77
8- def get_torch_module_and_inputs (model_path ):
8+ def get_torch_module_and_inputs (model_path , use_dummy_inputs = True ):
99 module = _get_torch_module (model_path )
1010 tensor_metas = _get_tensor_metas (model_path )
11- inputs = _create_inputs_by_metas (module , tensor_metas )
11+ inputs = _create_inputs_by_metas (module , tensor_metas , use_dummy_inputs )
1212 return module , inputs
1313
1414
@@ -27,11 +27,11 @@ def _get_tensor_metas(model_path):
2727 ]
2828
2929
30- def _create_inputs_by_metas (module , tensor_metas ):
30+ def _create_inputs_by_metas (module , tensor_metas , use_dummy_inputs ):
3131 tensor_meta_attrs_list = [asdict (tensor_meta ) for tensor_meta in tensor_metas ]
32- from graph_net .torch .utils import get_dummy_named_tensors
32+ from graph_net .torch .utils import get_named_tensors
3333
34- named_tensors = get_dummy_named_tensors (tensor_meta_attrs_list )
34+ named_tensors = get_named_tensors (tensor_meta_attrs_list , use_dummy_inputs )
3535 name2tensor = {k : v for k , v in named_tensors }
3636 return tuple (
3737 name2tensor [name ] for name in inspect .signature (module .forward ).parameters
Original file line number Diff line number Diff line change @@ -130,7 +130,7 @@ def __call__(self, rel_model_path):
130130 for k , v in self .config .items ()
131131 if k in {"split_positions" , "group_head_and_tail" , "chain_style" }
132132 }
133- module , inputs = get_torch_module_and_inputs (model_path )
133+ module , inputs = get_torch_module_and_inputs (model_path , use_dummy_inputs = False )
134134 gm = parse_immutable_model_path_into_sole_graph_module (model_path )
135135 try :
136136 # logger.warning("convert_to_submodules_graph-call-begin")
@@ -227,7 +227,7 @@ def __call__(self, rel_model_path):
227227 "group_head_and_tail" : self .config .get ("group_head_and_tail" , False ),
228228 "chain_style" : self .config .get ("chain_style" , False ),
229229 }
230- module , inputs = get_torch_module_and_inputs (model_path )
230+ module , inputs = get_torch_module_and_inputs (model_path , use_dummy_inputs = False )
231231 gm = parse_sole_graph_module (module , inputs )
232232 rewrited_gm : torch .fx .GraphModule = convert_to_submodules_graph (
233233 gm ,
Original file line number Diff line number Diff line change @@ -236,15 +236,18 @@ def convert_tensor_meta_attrs_list_to_named_tensors(tensor_meta_attrs_list):
236236 return ret
237237
238238
239- def get_dummy_named_tensors (tensor_meta_attrs_list ):
239+ def get_named_tensors (tensor_meta_attrs_list , use_dummy_inputs ):
240240 tensors_wrappers = convert_tensor_meta_attrs_list_to_tensors_wrappers (
241241 tensor_meta_attrs_list
242242 )
243243 ret = []
244244 for i , tensors_wrapper in enumerate (tensors_wrappers ):
245245 name = tensors_wrapper ["name" ]
246246 # shape = tensors_wrapper["info"]['shape']
247- tensor = get_dummy_tensor (tensors_wrapper )
247+ if use_dummy_inputs :
248+ tensor = get_dummy_tensor (tensors_wrapper )
249+ else :
250+ tensor = replay_tensor (tensors_wrapper )
248251 ret .append ((name , tensor ))
249252 return ret
250253
You can’t perform that action at this time.
0 commit comments