11from . import utils
22import argparse
3- import base64
4- import importlib
53import importlib .util
6- import json
7- import sys
8- from typing import Type
9-
104import torch
11-
12-
13- BUILTIN_DECORATORS = {
14- "AgentUnittestGenerator" : "graph_net.torch.sample_passes.agent_unittest_generator" ,
15- }
5+ from typing import Type
6+ import json
7+ import base64
168
179
1810def load_class_from_file (file_path : str , class_name : str ) -> Type [torch .nn .Module ]:
@@ -24,17 +16,6 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
2416 return model_class
2517
2618
def _load_builtin_decorator(class_name: str):
    """Resolve a built-in decorator class by name.

    Looks up *class_name* in the module-level BUILTIN_DECORATORS mapping,
    imports the mapped module, and returns the attribute of the same name.
    Returns None when the name is unknown, the mapped module cannot be
    imported, or the module lacks the attribute.
    """
    module_path = BUILTIN_DECORATORS.get(class_name)
    if module_path:
        try:
            # getattr default keeps a missing attribute as a soft failure.
            return getattr(importlib.import_module(module_path), class_name, None)
        except ModuleNotFoundError:
            return None
    # Not a registered built-in decorator name.
    return None
37-
3819def _convert_to_dict (config_str ):
3920 if config_str is None :
4021 return {}
@@ -44,47 +25,26 @@ def _convert_to_dict(config_str):
4425 return config
4526
4627
def _get_decorator(decorator_config):
    """Build a model-wrapping decorator from a parsed config dict.

    When the config contains a "decorator_path", the decorator class
    (name taken from "decorator_class_name", default "RunModelDecorator")
    is loaded from that file and instantiated with the nested
    "decorator_config" dict. Otherwise an identity wrapper is returned so
    callers may apply the result unconditionally.
    """
    if "decorator_path" in decorator_config:
        decorator_class = load_class_from_file(
            decorator_config["decorator_path"],
            class_name=decorator_config.get(
                "decorator_class_name", "RunModelDecorator"
            ),
        )
        return decorator_class(decorator_config.get("decorator_config", {}))
    # No decorator requested: leave the model untouched.
    return lambda model: model
7837
7938
def get_flag_use_dummy_inputs(decorator_config):
    """Return True when the decorator config requests dummy inputs.

    Robustness fix: accept None or any empty/falsy config and treat it as
    "flag absent" instead of raising TypeError on ``in None``, so callers
    need not pre-validate the config.

    NOTE(review): the flag is keyed on the *presence* of the
    "use_dummy_inputs" key, not its value — an explicit
    {"use_dummy_inputs": False} still yields True; confirm intended.
    """
    if not decorator_config:
        return False
    return "use_dummy_inputs" in decorator_config
8241
8342
def replay_tensor(info, use_dummy_inputs):
    """Materialize the tensor described by *info*.

    With ``use_dummy_inputs`` set, a dummy tensor is synthesized via
    ``utils.get_dummy_tensor``; otherwise the recorded tensor is replayed
    via ``utils.replay_tensor``.
    """
    builder = utils.get_dummy_tensor if use_dummy_inputs else utils.replay_tensor
    return builder(info)
8848
8949
9050def main (args ):
@@ -96,13 +56,8 @@ def main(args):
9656 model = model_class ()
9757 print (f"{ model_path = } " )
9858 decorator_config = _convert_to_dict (args .decorator_config )
99- if decorator_config :
100- decorator_config .setdefault ("decorator_config" , {})
101- decorator_config ["decorator_config" ].setdefault ("model_path" , model_path )
102- decorator_config ["decorator_config" ].setdefault ("output_dir" , None )
103- decorator_config ["decorator_config" ].setdefault ("use_dummy_inputs" , False )
104-
105- model = _get_decorator (decorator_config )(model )
59+ if "decorator_path" in decorator_config :
60+ model = _get_decorator (decorator_config )(model )
10661
10762 inputs_params = utils .load_converted_from_text (f"{ model_path } " )
10863 params = inputs_params ["weight_info" ]
0 commit comments