Skip to content

Commit ed394e4

Browse files

Committed: chore: revert run_model.py
1 parent 7eb8a16 commit ed394e4

File tree

2 files changed

+27
-118
lines changed

2 files changed

+27
-118
lines changed

graph_net/paddle/run_model.py

Lines changed: 12 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
import argparse
2-
import base64
3-
import importlib
4-
import importlib.util
5-
import json
61
import os
7-
import sys
2+
import json
3+
import base64
4+
import argparse
85

96
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
107

@@ -13,29 +10,13 @@
1310
from graph_net.paddle import utils
1411

1512

16-
BUILTIN_DECORATORS = {
17-
"AgentUnittestGenerator": "graph_net.paddle.sample_passes.agent_unittest_generator",
18-
}
19-
20-
2113
def load_class_from_file(file_path: str, class_name: str):
2214
print(f"Load {class_name} from {file_path}")
2315
module = imp_util.load_module(file_path, "unnamed")
2416
model_class = getattr(module, class_name, None)
2517
return model_class
2618

2719

28-
def _load_builtin_decorator(class_name: str):
29-
module_path = BUILTIN_DECORATORS.get(class_name)
30-
if not module_path:
31-
return None
32-
try:
33-
module = importlib.import_module(module_path)
34-
except ModuleNotFoundError:
35-
return None
36-
return getattr(module, class_name, None)
37-
38-
3920
def get_input_dict(model_path):
4021
inputs_params = utils.load_converted_from_text(f"{model_path}")
4122
params = inputs_params["weight_info"]
@@ -58,37 +39,16 @@ def _convert_to_dict(config_str):
5839
return config
5940

6041

61-
def _get_decorator(arg):
62-
"""Accept argparse.Namespace or already-parsed dict configs."""
63-
if arg is None:
42+
def _get_decorator(args):
43+
if args.decorator_config is None:
6444
return lambda model: model
65-
66-
decorator_config = (
67-
_convert_to_dict(arg.decorator_config)
68-
if hasattr(arg, "decorator_config")
69-
else arg
70-
)
71-
if not decorator_config:
45+
decorator_config = _convert_to_dict(args.decorator_config)
46+
if "decorator_path" not in decorator_config:
7247
return lambda model: model
73-
74-
class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
75-
decorator_kwargs = decorator_config.get("decorator_config", {})
76-
77-
if "decorator_path" in decorator_config:
78-
decorator_class = load_class_from_file(
79-
decorator_config["decorator_path"], class_name=class_name
80-
)
81-
return decorator_class(decorator_kwargs)
82-
83-
builtin_decorator = _load_builtin_decorator(class_name)
84-
if builtin_decorator:
85-
return builtin_decorator(decorator_kwargs)
86-
87-
if hasattr(sys.modules[__name__], class_name):
88-
decorator_class = getattr(sys.modules[__name__], class_name)
89-
return decorator_class(decorator_kwargs)
90-
91-
return lambda model: model
48+
decorator_class = load_class_from_file(
49+
decorator_config["decorator_path"], class_name="RunModelDecorator"
50+
)
51+
return decorator_class(decorator_config.get("decorator_config", {}))
9252

9353

9454
def main(args):
@@ -99,15 +59,9 @@ def main(args):
9959
assert model_class is not None
10060
model = model_class()
10161
print(f"{model_path=}")
102-
decorator_config = _convert_to_dict(args.decorator_config)
103-
if decorator_config:
104-
decorator_config.setdefault("decorator_config", {})
105-
decorator_config["decorator_config"].setdefault("model_path", model_path)
106-
decorator_config["decorator_config"].setdefault("output_dir", None)
107-
decorator_config["decorator_config"].setdefault("use_numpy", True)
10862

109-
model = _get_decorator(decorator_config)(model)
11063
input_dict = get_input_dict(args.model_path)
64+
model = _get_decorator(args)(model)
11165
model(**input_dict)
11266

11367

graph_net/torch/run_model.py

Lines changed: 15 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
from . import utils
22
import argparse
3-
import base64
4-
import importlib
53
import importlib.util
6-
import json
7-
import sys
8-
from typing import Type
9-
104
import torch
11-
12-
13-
BUILTIN_DECORATORS = {
14-
"AgentUnittestGenerator": "graph_net.torch.sample_passes.agent_unittest_generator",
15-
}
5+
from typing import Type
6+
import json
7+
import base64
168

179

1810
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
@@ -24,17 +16,6 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
2416
return model_class
2517

2618

27-
def _load_builtin_decorator(class_name: str):
28-
module_path = BUILTIN_DECORATORS.get(class_name)
29-
if not module_path:
30-
return None
31-
try:
32-
module = importlib.import_module(module_path)
33-
except ModuleNotFoundError:
34-
return None
35-
return getattr(module, class_name, None)
36-
37-
3819
def _convert_to_dict(config_str):
3920
if config_str is None:
4021
return {}
@@ -44,47 +25,26 @@ def _convert_to_dict(config_str):
4425
return config
4526

4627

47-
def _get_decorator(arg):
48-
"""Accept argparse.Namespace or already-parsed dict configs."""
49-
if arg is None:
28+
def _get_decorator(decorator_config):
29+
if "decorator_path" not in decorator_config:
5030
return lambda model: model
51-
52-
decorator_config = (
53-
_convert_to_dict(arg.decorator_config)
54-
if hasattr(arg, "decorator_config")
55-
else arg
56-
)
57-
if not decorator_config:
58-
return lambda model: model
59-
6031
class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
61-
decorator_kwargs = decorator_config.get("decorator_config", {})
62-
63-
if "decorator_path" in decorator_config:
64-
decorator_class = load_class_from_file(
65-
decorator_config["decorator_path"], class_name=class_name
66-
)
67-
return decorator_class(decorator_kwargs)
68-
69-
builtin_decorator = _load_builtin_decorator(class_name)
70-
if builtin_decorator:
71-
return builtin_decorator(decorator_kwargs)
72-
73-
if hasattr(sys.modules[__name__], class_name):
74-
decorator_class = getattr(sys.modules[__name__], class_name)
75-
return decorator_class(decorator_kwargs)
76-
77-
return lambda model: model
32+
decorator_class = load_class_from_file(
33+
decorator_config["decorator_path"],
34+
class_name=class_name,
35+
)
36+
return decorator_class(decorator_config.get("decorator_config", {}))
7837

7938

8039
def get_flag_use_dummy_inputs(decorator_config):
81-
return "use_dummy_inputs" in decorator_config if decorator_config else False
40+
return "use_dummy_inputs" in decorator_config
8241

8342

8443
def replay_tensor(info, use_dummy_inputs):
8544
if use_dummy_inputs:
8645
return utils.get_dummy_tensor(info)
87-
return utils.replay_tensor(info)
46+
else:
47+
return utils.replay_tensor(info)
8848

8949

9050
def main(args):
@@ -96,13 +56,8 @@ def main(args):
9656
model = model_class()
9757
print(f"{model_path=}")
9858
decorator_config = _convert_to_dict(args.decorator_config)
99-
if decorator_config:
100-
decorator_config.setdefault("decorator_config", {})
101-
decorator_config["decorator_config"].setdefault("model_path", model_path)
102-
decorator_config["decorator_config"].setdefault("output_dir", None)
103-
decorator_config["decorator_config"].setdefault("use_dummy_inputs", False)
104-
105-
model = _get_decorator(decorator_config)(model)
59+
if "decorator_path" in decorator_config:
60+
model = _get_decorator(decorator_config)(model)
10661

10762
inputs_params = utils.load_converted_from_text(f"{model_path}")
10863
params = inputs_params["weight_info"]

0 commit comments

Comments (0)