Skip to content

Commit 05c4cd3

Browse files
committed
feat: refactor code to improve style, use a seperate method to inject AgentUnittestGenerator, adopt jinja2 as render engine
1 parent 675631e commit 05c4cd3

File tree

7 files changed

+594
-231
lines changed

7 files changed

+594
-231
lines changed

graph_net/paddle/run_model.py

Lines changed: 23 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import base64
3+
import importlib
34
import importlib.util
45
import json
56
import os
@@ -10,7 +11,11 @@
1011
import paddle
1112
from graph_net import imp_util
1213
from graph_net.paddle import utils
13-
from jinja2 import Template
14+
15+
16+
BUILTIN_DECORATORS = {
17+
"AgentUnittestGenerator": "graph_net.paddle.sample_passes.agent_unittest_generator",
18+
}
1419

1520

1621
def load_class_from_file(file_path: str, class_name: str):
@@ -20,6 +25,17 @@ def load_class_from_file(file_path: str, class_name: str):
2025
return model_class
2126

2227

28+
def _load_builtin_decorator(class_name: str):
29+
module_path = BUILTIN_DECORATORS.get(class_name)
30+
if not module_path:
31+
return None
32+
try:
33+
module = importlib.import_module(module_path)
34+
except ModuleNotFoundError:
35+
return None
36+
return getattr(module, class_name, None)
37+
38+
2339
def get_input_dict(model_path):
2440
inputs_params = utils.load_converted_from_text(f"{model_path}")
2541
params = inputs_params["weight_info"]
@@ -43,7 +59,7 @@ def _convert_to_dict(config_str):
4359

4460

4561
def _get_decorator(arg):
46-
"""兼容旧接口:既接受 argparse.Namespace,也接受已解析的 dict"""
62+
"""Accept argparse.Namespace or already-parsed dict configs."""
4763
if arg is None:
4864
return lambda model: model
4965

@@ -64,105 +80,17 @@ def _get_decorator(arg):
6480
)
6581
return decorator_class(decorator_kwargs)
6682

83+
builtin_decorator = _load_builtin_decorator(class_name)
84+
if builtin_decorator:
85+
return builtin_decorator(decorator_kwargs)
86+
6787
if hasattr(sys.modules[__name__], class_name):
6888
decorator_class = getattr(sys.modules[__name__], class_name)
6989
return decorator_class(decorator_kwargs)
7090

7191
return lambda model: model
7292

7393

74-
class AgentUnittestGenerator:
75-
"""生成 Paddle 子图的独立 unittest 脚本,验证前向可运行。"""
76-
77-
def __init__(self, config):
78-
defaults = {
79-
"model_path": None,
80-
"output_path": None,
81-
"force_device": "auto", # auto / cpu / gpu
82-
"use_numpy": True,
83-
}
84-
merged = {**defaults, **(config or {})}
85-
if merged["model_path"] is None:
86-
raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
87-
self.model_path = merged["model_path"]
88-
self.output_path = merged["output_path"] or self._default_output_path()
89-
self.force_device = merged["force_device"]
90-
self.use_numpy = merged["use_numpy"]
91-
92-
def __call__(self, model):
93-
self._generate_unittest_file()
94-
return model
95-
96-
def _default_output_path(self):
97-
base = os.path.basename(os.path.normpath(self.model_path))
98-
return os.path.join(self.model_path, f"{base}_test.py")
99-
100-
def _choose_device(self):
101-
if self.force_device == "cpu":
102-
return "cpu"
103-
if self.force_device == "gpu":
104-
return "gpu"
105-
return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"
106-
107-
def _generate_unittest_file(self):
108-
target_device = self._choose_device()
109-
template_str = """
110-
import importlib.util
111-
import os
112-
import unittest
113-
114-
import paddle
115-
from graph_net.paddle import utils
116-
117-
118-
def _load_graph_module(model_path: str):
119-
source_path = os.path.join(model_path, "model.py")
120-
spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
121-
module = importlib.util.module_from_spec(spec)
122-
spec.loader.exec_module(module)
123-
return module.GraphModule
124-
125-
126-
class AgentGraphTest(unittest.TestCase):
127-
def setUp(self):
128-
self.model_path = os.path.dirname(__file__)
129-
self.target_device = "{{ target_device }}"
130-
paddle.set_device(self.target_device)
131-
self.GraphModule = _load_graph_module(self.model_path)
132-
self.meta = utils.load_converted_from_text(self.model_path)
133-
self.use_numpy = {{ use_numpy_flag }}
134-
135-
def _with_device(self, info):
136-
cloned = {"info": dict(info["info"]), "data": info.get("data")}
137-
cloned["info"]["device"] = self.target_device
138-
return cloned
139-
140-
def _build_tensor(self, meta):
141-
return utils.replay_tensor(self._with_device(meta), use_numpy=self.use_numpy)
142-
143-
def test_forward_runs(self):
144-
model = self.GraphModule()
145-
inputs = {k: self._build_tensor(v) for k, v in self.meta["input_info"].items()}
146-
params = {k: self._build_tensor(v) for k, v in self.meta["weight_info"].items()}
147-
model.__graph_net_file_path__ = self.model_path
148-
output = model(**params, **inputs)
149-
self.assertIsNotNone(output)
150-
151-
152-
if __name__ == "__main__":
153-
unittest.main()
154-
"""
155-
156-
rendered = Template(template_str).render(
157-
target_device=target_device, use_numpy_flag=self.use_numpy
158-
)
159-
160-
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
161-
with open(self.output_path, "w", encoding="utf-8") as f:
162-
f.write(rendered)
163-
print(f"[Agent] unittest 已生成: {self.output_path} (device={target_device})")
164-
165-
16694
def main(args):
16795
model_path = args.model_path
16896
model_class = load_class_from_file(
@@ -175,6 +103,7 @@ def main(args):
175103
if decorator_config:
176104
decorator_config.setdefault("decorator_config", {})
177105
decorator_config["decorator_config"].setdefault("model_path", model_path)
106+
decorator_config["decorator_config"].setdefault("output_dir", None)
178107
decorator_config["decorator_config"].setdefault("use_numpy", True)
179108

180109
model = _get_decorator(decorator_config)(model)

graph_net/paddle/sample_passes/__init__.py

Whitespace-only changes.
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from pathlib import Path
2+
from typing import Any, Dict
3+
4+
import paddle
5+
from jinja2 import Template
6+
7+
from graph_net.sample_pass.sample_pass import SamplePass
8+
9+
10+
PADDLE_UNITTEST_TEMPLATE = r"""
11+
import importlib.util
12+
import os
13+
import unittest
14+
from typing import Any, Dict
15+
16+
import numpy as np
17+
import paddle
18+
19+
20+
def _get_classes(file_path: str):
21+
spec = importlib.util.spec_from_file_location("agent_meta", file_path)
22+
module = importlib.util.module_from_spec(spec)
23+
spec.loader.exec_module(module)
24+
return [
25+
(name, cls)
26+
for name, cls in vars(module).items()
27+
if isinstance(cls, type)
28+
]
29+
30+
31+
def _convert_meta_classes_to_wrappers(file_path: str):
32+
current_device = paddle.device.get_device()
33+
for _, cls in _get_classes(file_path):
34+
attrs = {
35+
k: v for k, v in vars(cls).items() if not k.startswith("__") and not callable(v)
36+
}
37+
dtype_attr = attrs.get("dtype", "float32")
38+
dtype = getattr(paddle, str(dtype_attr).split(".")[-1])
39+
shape = [1 if dim is None else dim for dim in attrs.get("shape", [])]
40+
info = {
41+
"shape": shape,
42+
"dtype": dtype,
43+
"device": attrs.get("device", current_device),
44+
"mean": attrs.get("mean"),
45+
"std": attrs.get("std"),
46+
"min_val": attrs.get("min_val", 0),
47+
"max_val": attrs.get("max_val", 2),
48+
}
49+
data = attrs.get("data")
50+
if data is not None and not isinstance(data, paddle.Tensor):
51+
data = paddle.to_tensor(data, dtype=dtype).reshape(info["shape"])
52+
yield {"info": info, "data": data, "name": attrs.get("name")}
53+
54+
55+
def _convert_meta_to_tensors(model_path: str):
56+
weight_meta = os.path.join(model_path, "weight_meta.py")
57+
input_meta = os.path.join(model_path, "input_meta.py")
58+
weight_info = {
59+
item["name"]: item for item in _convert_meta_classes_to_wrappers(weight_meta)
60+
}
61+
input_info = {
62+
item["name"]: item for item in _convert_meta_classes_to_wrappers(input_meta)
63+
}
64+
return {"weight_info": weight_info, "input_info": input_info}
65+
66+
67+
def _init_integer_tensor(dtype, shape, min_val, max_val, use_numpy: bool):
68+
if use_numpy:
69+
array = np.random.randint(low=min_val, high=max_val + 1, size=shape, dtype=dtype)
70+
return paddle.to_tensor(array)
71+
return paddle.randint(low=min_val, high=max_val + 1, shape=shape, dtype=dtype)
72+
73+
74+
def _init_float_tensor(shape, mean, std, min_val, max_val, use_numpy: bool):
75+
if use_numpy:
76+
if mean is not None and std is not None:
77+
array = np.random.normal(0, 1, shape) * std * 0.2 + mean
78+
array = np.clip(array, min_val, max_val)
79+
else:
80+
array = np.random.uniform(low=min_val, high=max_val, size=shape)
81+
return paddle.to_tensor(array)
82+
if mean is not None and std is not None:
83+
tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean
84+
tensor = paddle.clip(tensor, min=min_val, max=max_val)
85+
return tensor
86+
return paddle.uniform(shape=shape, dtype="float32", min=min_val, max=max_val)
87+
88+
89+
def _replay_tensor(info: Dict[str, Any], use_numpy: bool):
90+
device = info["info"].get("device", paddle.device.get_device())
91+
dtype = info["info"].get("dtype", paddle.float32)
92+
shape = [1 if dim is None else dim for dim in info["info"].get("shape", [])]
93+
mean = info["info"].get("mean")
94+
std = info["info"].get("std")
95+
min_val = info["info"].get("min_val", 0)
96+
max_val = info["info"].get("max_val", 2)
97+
if info.get("data") is not None:
98+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
99+
if dtype in [paddle.int32, paddle.int64, paddle.bool]:
100+
init_dtype = "int32" if dtype == paddle.bool else "int64"
101+
if dtype == paddle.bool:
102+
min_val, max_val = 0, 1
103+
return _init_integer_tensor(init_dtype, shape, min_val, max_val, use_numpy).to(dtype).to(device)
104+
tensor = _init_float_tensor(shape, mean, std, min_val, max_val, use_numpy)
105+
return tensor.to(dtype).to(device)
106+
107+
108+
def _get_dummy_tensor(info: Dict[str, Any]):
109+
device = info["info"].get("device", paddle.device.get_device())
110+
dtype = info["info"].get("dtype", paddle.float32)
111+
shape = [1 if dim is None else dim for dim in info["info"].get("shape", [])]
112+
if info.get("data") is not None:
113+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
114+
return paddle.empty(shape=shape, dtype=dtype, device=device)
115+
116+
117+
def _load_graph_module(model_path: str):
118+
source_path = os.path.join(model_path, "model.py")
119+
spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
120+
module = importlib.util.module_from_spec(spec)
121+
spec.loader.exec_module(module)
122+
return module.GraphModule
123+
124+
125+
class AgentGraphTest(unittest.TestCase):
126+
def setUp(self):
127+
self.model_path = os.path.dirname(__file__)
128+
self.target_device = "{{ target_device }}"
129+
self.use_numpy = {{ use_numpy_flag }}
130+
paddle.set_device(self.target_device)
131+
self.GraphModule = _load_graph_module(self.model_path)
132+
self.meta = _convert_meta_to_tensors(self.model_path)
133+
134+
def _with_device(self, info: Dict[str, Any]):
135+
cloned = {"info": dict(info["info"]), "data": info.get("data")}
136+
cloned["info"]["device"] = self.target_device
137+
return cloned
138+
139+
def test_forward_runs(self):
140+
model = self.GraphModule()
141+
inputs = {k: _replay_tensor(self._with_device(v), self.use_numpy) for k, v in self.meta["input_info"].items()}
142+
params = {k: _replay_tensor(self._with_device(v), self.use_numpy) for k, v in self.meta["weight_info"].items()}
143+
model.__graph_net_file_path__ = self.model_path
144+
output = model(**params, **inputs)
145+
self.assertIsNotNone(output)
146+
147+
148+
if __name__ == "__main__":
149+
unittest.main()
150+
"""
151+
152+
153+
class AgentUnittestGenerator:
154+
"""Generate standalone unittest scripts for Paddle samples."""
155+
156+
def __init__(self, config: Dict[str, Any]):
157+
defaults = {
158+
"model_path": None,
159+
"output_path": None,
160+
"output_dir": None,
161+
"force_device": "auto", # auto / cpu / gpu
162+
"use_numpy": True,
163+
}
164+
merged = {**defaults, **(config or {})}
165+
if not merged["model_path"]:
166+
raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
167+
168+
self.model_path = Path(merged["model_path"]).resolve()
169+
self.output_path = Path(merged["output_path"]) if merged.get("output_path") else None
170+
self.output_dir = Path(merged["output_dir"]) if merged.get("output_dir") else None
171+
self.force_device = merged["force_device"]
172+
self.use_numpy = merged["use_numpy"]
173+
174+
def __call__(self, model):
175+
self.generate()
176+
return model
177+
178+
def generate(self):
179+
output_path = self._resolve_output_path()
180+
target_device = self._choose_device()
181+
rendered = Template(PADDLE_UNITTEST_TEMPLATE).render(
182+
target_device=target_device, use_numpy_flag=self.use_numpy
183+
)
184+
output_path.parent.mkdir(parents=True, exist_ok=True)
185+
output_path.write_text(rendered, encoding="utf-8")
186+
print(f"[Agent] unittest generated: {output_path} (device={target_device})")
187+
188+
def _resolve_output_path(self) -> Path:
189+
if self.output_path:
190+
return self.output_path
191+
target_dir = self.output_dir or self.model_path
192+
return Path(target_dir) / f"{self.model_path.name}_test.py"
193+
194+
def _choose_device(self) -> str:
195+
if self.force_device == "cpu":
196+
return "cpu"
197+
if self.force_device == "gpu":
198+
return "gpu"
199+
return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"
200+
201+
202+
class AgentUnittestGeneratorPass(SamplePass):
203+
"""SamplePass wrapper to generate Paddle unittests via model_path_handler."""
204+
205+
def __init__(self, config=None):
206+
super().__init__(config)
207+
208+
def declare_config(
209+
self,
210+
model_path_prefix: str,
211+
output_dir: str = None,
212+
force_device: str = "auto",
213+
use_numpy: bool = True,
214+
):
215+
pass
216+
217+
def __call__(self, rel_model_path: str):
218+
model_path_prefix = Path(self.config["model_path_prefix"])
219+
target_root = Path(self.config.get("output_dir") or model_path_prefix)
220+
model_path = model_path_prefix / rel_model_path
221+
generator = AgentUnittestGenerator(
222+
{
223+
"model_path": str(model_path),
224+
"output_dir": str(target_root / rel_model_path),
225+
"force_device": self.config["force_device"],
226+
"use_numpy": self.config["use_numpy"],
227+
}
228+
)
229+
generator.generate()

0 commit comments

Comments
 (0)