Skip to content

Commit 2ed0c1a

Browse files
committed
Enable try-run the unittest when main is generated.
1 parent 7065b30 commit 2ed0c1a

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

graph_net/test/agent_unittest_generator_test.sh

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22

33
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
44
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
5-
6-
MODEL_PATH_PREFIX="$ROOT_DIR"
75
OUTPUT_DIR="/tmp/agent_unittests"
86

97
HANDLER_CONFIG=$(base64 -w 0 <<EOF
108
{
119
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py",
1210
"handler_class_name": "AgentUnittestGeneratorPass",
1311
"handler_config": {
14-
"model_path_prefix": "$MODEL_PATH_PREFIX",
12+
"model_path_prefix": "${GRAPH_NET_ROOT}",
1513
"output_dir": "$OUTPUT_DIR",
1614
"device": "auto",
1715
"generate_main": true,
16+
"try_run": true,
1817
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
1918
"data_input_predicator_class_name": "NaiveDataInputPredicator"
2019
}
@@ -24,14 +23,13 @@ EOF
2423

2524
run_case() {
2625
local rel_sample_path="$1"
27-
local name="$2"
28-
echo "[AgentTest] running $name sample at $rel_sample_path"
26+
echo "[AgentTest] running $rel_sample_path"
2927
python -m graph_net.model_path_handler \
3028
--model-path "$rel_sample_path" \
3129
--handler-config "$HANDLER_CONFIG"
3230
}
3331

34-
run_case "samples/torchvision/resnet18" "CV (torchvision/resnet18)"
35-
run_case "samples/transformers-auto-model/albert-base-v2" "NLP (transformers-auto-model/albert-base-v2)"
32+
run_case "samples/torchvision/resnet18"
33+
run_case "samples/transformers-auto-model/albert-base-v2"
3634

3735
echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."

graph_net/torch/sample_passes/agent_unittest_generator.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import re
2+
import os
3+
import sys
24
import ast
35
import inspect
46
import jinja2
57
import textwrap
8+
import tempfile
69
from pathlib import Path
710
from typing import Literal
811
from collections import namedtuple
@@ -127,18 +130,21 @@ def __init__(
127130
output_dir: str,
128131
device: Literal["auto", "cpu", "cuda"] = "auto",
129132
generate_main: bool = False,
133+
try_run: bool = False,
130134
data_input_predicator_filepath: str = None,
131135
data_input_predicator_class_name: str = None,
132136
):
133137
self.model_path = Path(model_path).resolve()
134138
self.output_dir = Path(output_dir)
135139
self.device = self._choose_device(device)
136140
self.generate_main = generate_main
141+
self.try_run = try_run and generate_main
137142
self.data_input_predicator = self._make_data_input_predicator(
138143
data_input_predicator_filepath, data_input_predicator_class_name
139144
)
140145

141146
def generate(self):
147+
print(f"[AgentUnittestGenerator] Generate unittest for {self.model_path}")
142148
model_name = "".join(
143149
word.capitalize() for word in re.split(r"[_.-]", self.model_path.name)
144150
)
@@ -152,7 +158,6 @@ def generate(self):
152158
input_tensor_metas,
153159
weight_tensor_metas,
154160
) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names)
155-
print(f"{input_arg_names=}")
156161
graph_module_desc = GraphModuleDescriptor(
157162
device=self.device,
158163
generate_main=self.generate_main,
@@ -166,7 +171,8 @@ def generate(self):
166171
),
167172
)
168173
unittest = self._render_template(graph_module_desc)
169-
self._write_to_file(unittest)
174+
if self._try_to_run_unittest(unittest):
175+
self._write_to_file(unittest, self.output_dir)
170176

171177
def _choose_device(self, device) -> str:
172178
if device in ["cpu", "cuda"]:
@@ -180,15 +186,25 @@ def _make_data_input_predicator(
180186
module = imp_util.load_module(data_input_predicator_filepath)
181187
cls = getattr(module, data_input_predicator_class_name)
182188
return cls(config={})
183-
return lambda *args, **kwargs: False
189+
return lambda *args, **kwargs: True
184190

185-
def _write_to_file(self, unittest):
186-
output_path = Path(self.output_dir) / f"{self.model_path.name}_test.py"
191+
def _write_to_file(self, unittest, output_dir):
192+
output_path = Path(output_dir) / f"{self.model_path.name}_test.py"
187193
output_path.parent.mkdir(parents=True, exist_ok=True)
188194
output_path.write_text(unittest, encoding="utf-8")
189195
print(
190196
f"[AgentUnittestGenerator] Generate unittest: {output_path} (device={self.device})"
191197
)
198+
return output_path
199+
200+
def _try_to_run_unittest(self, unittest):
201+
if not self.try_run:
202+
return True
203+
204+
with tempfile.TemporaryDirectory(prefix="unittest_") as temp_dir:
205+
output_path = self._write_to_file(unittest, temp_dir)
206+
cmd = f"{sys.executable} {output_path}"
207+
return os.system(cmd) == 0
192208

193209
def _get_input_and_weight_arg_names(self, graph_module):
194210
input_arg_names = []
@@ -287,28 +303,28 @@ class AgentUnittestGeneratorPass(SamplePass):
287303

288304
def __init__(self, config=None):
289305
super().__init__(config)
290-
print(f"[AgentUnittestGeneratorPass] {self.config=}")
291306

292307
def declare_config(
293308
self,
294309
model_path_prefix: str,
295310
output_dir: str,
296311
device: str = "auto",
297312
generate_main: bool = False,
313+
try_run: bool = False,
298314
data_input_predicator_filepath: str = None,
299315
data_input_predicator_class_name: str = None,
300316
):
301317
pass
302318

303319
def __call__(self, rel_model_path: str):
304-
print(f"[AgentUnittestGeneratorPass] {rel_model_path=}")
305320
model_path_prefix = Path(self.config["model_path_prefix"])
306321
output_dir = Path(self.config["output_dir"])
307322
generator = AgentUnittestGenerator(
308323
model_path=str(model_path_prefix / rel_model_path),
309324
output_dir=str(output_dir / rel_model_path),
310325
device=self.config["device"],
311326
generate_main=self.config["generate_main"],
327+
try_run=self.config["try_run"],
312328
data_input_predicator_filepath=self.config[
313329
"data_input_predicator_filepath"
314330
],

0 commit comments

Comments
 (0)