Skip to content

Commit e1dee5b

Browse files
authored
[Feature Enhancement] Enable try-run for generated unittest when main is generated and support resume. (#460)
* Enable try-run the unittest when main is generated. * Enable resume. * Use subprocess.run instead of os.system to speedup. * Add import. * Override sample_handled the check whether the unittest file exists.
1 parent 7ad078a commit e1dee5b

File tree

2 files changed

+57
-16
lines changed

2 files changed

+57
-16
lines changed
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
#!/usr/bin/env bash
22

3-
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
4-
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
53

6-
MODEL_PATH_PREFIX="$ROOT_DIR"
4+
5+
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
76
OUTPUT_DIR="/tmp/agent_unittests"
87

98
HANDLER_CONFIG=$(base64 -w 0 <<EOF
109
{
1110
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py",
1211
"handler_class_name": "AgentUnittestGeneratorPass",
1312
"handler_config": {
14-
"model_path_prefix": "$MODEL_PATH_PREFIX",
13+
"model_path_prefix": "${GRAPH_NET_ROOT}",
1514
"output_dir": "$OUTPUT_DIR",
1615
"device": "auto",
1716
"generate_main": true,
17+
"try_run": true,
18+
"resume": false,
1819
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
1920
"data_input_predicator_class_name": "NaiveDataInputPredicator"
2021
}
@@ -24,14 +25,13 @@ EOF
2425

2526
run_case() {
2627
local rel_sample_path="$1"
27-
local name="$2"
28-
echo "[AgentTest] running $name sample at $rel_sample_path"
28+
echo "[AgentTest] running $rel_sample_path"
2929
python -m graph_net.model_path_handler \
3030
--model-path "$rel_sample_path" \
3131
--handler-config "$HANDLER_CONFIG"
3232
}
3333

34-
run_case "samples/torchvision/resnet18" "CV (torchvision/resnet18)"
35-
run_case "samples/transformers-auto-model/albert-base-v2" "NLP (transformers-auto-model/albert-base-v2)"
34+
run_case "samples/torchvision/resnet18"
35+
run_case "samples/transformers-auto-model/albert-base-v2"
3636

3737
echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."

graph_net/torch/sample_passes/agent_unittest_generator.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import re
2+
import sys
3+
import subprocess
24
import ast
35
import inspect
46
import jinja2
57
import textwrap
8+
import tempfile
69
from pathlib import Path
710
from typing import Literal
811
from collections import namedtuple
@@ -11,6 +14,7 @@
1114

1215
from graph_net import imp_util
1316
from graph_net.sample_pass.sample_pass import SamplePass
17+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
1418
from graph_net.tensor_meta import TensorMeta
1519

1620

@@ -20,6 +24,7 @@
2024
{%- endif -%}
2125
{{"\n"}}
2226
import torch
27+
from torch import device
2328
2429
2530
{% macro get_input_tensor_instance(tensor_meta, device) -%}
@@ -124,21 +129,26 @@ class AgentUnittestGenerator:
124129
def __init__(
125130
self,
126131
model_path: str,
132+
output_name: str,
127133
output_dir: str,
128134
device: Literal["auto", "cpu", "cuda"] = "auto",
129135
generate_main: bool = False,
136+
try_run: bool = False,
130137
data_input_predicator_filepath: str = None,
131138
data_input_predicator_class_name: str = None,
132139
):
133140
self.model_path = Path(model_path).resolve()
141+
self.output_name = output_name
134142
self.output_dir = Path(output_dir)
135143
self.device = self._choose_device(device)
136144
self.generate_main = generate_main
145+
self.try_run = try_run and generate_main
137146
self.data_input_predicator = self._make_data_input_predicator(
138147
data_input_predicator_filepath, data_input_predicator_class_name
139148
)
140149

141150
def generate(self):
151+
print(f"[AgentUnittestGenerator] Generate unittest for {self.model_path}")
142152
model_name = "".join(
143153
word.capitalize() for word in re.split(r"[_.-]", self.model_path.name)
144154
)
@@ -152,7 +162,6 @@ def generate(self):
152162
input_tensor_metas,
153163
weight_tensor_metas,
154164
) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names)
155-
print(f"{input_arg_names=}")
156165
graph_module_desc = GraphModuleDescriptor(
157166
device=self.device,
158167
generate_main=self.generate_main,
@@ -166,7 +175,8 @@ def generate(self):
166175
),
167176
)
168177
unittest = self._render_template(graph_module_desc)
169-
self._write_to_file(unittest)
178+
if self._try_to_run_unittest(unittest):
179+
self._write_to_file(unittest, self.output_dir)
170180

171181
def _choose_device(self, device) -> str:
172182
if device in ["cpu", "cuda"]:
@@ -180,15 +190,28 @@ def _make_data_input_predicator(
180190
module = imp_util.load_module(data_input_predicator_filepath)
181191
cls = getattr(module, data_input_predicator_class_name)
182192
return cls(config={})
183-
return lambda *args, **kwargs: False
193+
return lambda *args, **kwargs: True
184194

185-
def _write_to_file(self, unittest):
186-
output_path = Path(self.output_dir) / f"{self.model_path.name}_test.py"
195+
def _write_to_file(self, unittest, output_dir):
196+
output_path = Path(output_dir) / self.output_name
187197
output_path.parent.mkdir(parents=True, exist_ok=True)
188198
output_path.write_text(unittest, encoding="utf-8")
189199
print(
190200
f"[AgentUnittestGenerator] Generate unittest: {output_path} (device={self.device})"
191201
)
202+
return output_path
203+
204+
def _try_to_run_unittest(self, unittest):
205+
if not self.try_run:
206+
return True
207+
208+
with tempfile.TemporaryDirectory(prefix="unittest_") as temp_dir:
209+
output_path = self._write_to_file(unittest, temp_dir)
210+
result = subprocess.run(
211+
[sys.executable, output_path],
212+
check=True,
213+
)
214+
return result.returncode == 0
192215

193216
def _get_input_and_weight_arg_names(self, graph_module):
194217
input_arg_names = []
@@ -282,33 +305,51 @@ def _render_template(self, graph_module_desc):
282305
return template.render(graph_module_desc=graph_module_desc)
283306

284307

285-
class AgentUnittestGeneratorPass(SamplePass):
308+
class AgentUnittestGeneratorPass(SamplePass, ResumableSamplePassMixin):
286309
"""SamplePass wrapper to generate Torch unittests via model_path_handler."""
287310

288311
def __init__(self, config=None):
289312
super().__init__(config)
290-
print(f"[AgentUnittestGeneratorPass] {self.config=}")
291313

292314
def declare_config(
293315
self,
294316
model_path_prefix: str,
295317
output_dir: str,
296318
device: str = "auto",
297319
generate_main: bool = False,
320+
try_run: bool = False,
298321
data_input_predicator_filepath: str = None,
299322
data_input_predicator_class_name: str = None,
323+
resume: bool = False,
324+
limits_handled_models: int = None,
300325
):
301326
pass
302327

303328
def __call__(self, rel_model_path: str):
304-
print(f"[AgentUnittestGeneratorPass] {rel_model_path=}")
329+
self.resumable_handle_sample(rel_model_path)
330+
331+
def sample_handled(self, rel_model_path: str) -> bool:
332+
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
333+
if not dst_model_path.exists():
334+
return False
335+
output_name = self._get_output_name(rel_model_path)
336+
num_model_py_files = len(list(dst_model_path.rglob(output_name)))
337+
assert num_model_py_files <= 1
338+
return num_model_py_files == 1
339+
340+
def _get_output_name(self, rel_model_path: str):
341+
return f"{Path(rel_model_path).name}_test.py"
342+
343+
def resume(self, rel_model_path: str):
305344
model_path_prefix = Path(self.config["model_path_prefix"])
306345
output_dir = Path(self.config["output_dir"])
307346
generator = AgentUnittestGenerator(
308347
model_path=str(model_path_prefix / rel_model_path),
348+
output_name=self._get_output_name(rel_model_path),
309349
output_dir=str(output_dir / rel_model_path),
310350
device=self.config["device"],
311351
generate_main=self.config["generate_main"],
352+
try_run=self.config["try_run"],
312353
data_input_predicator_filepath=self.config[
313354
"data_input_predicator_filepath"
314355
],

0 commit comments

Comments
 (0)