Skip to content

Commit 7479e64

Browse files
committed
Override sample_handled the check whether the unittest file exists.
1 parent 1c34917 commit 7479e64

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

graph_net/torch/sample_passes/agent_unittest_generator.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class AgentUnittestGenerator:
129129
def __init__(
130130
self,
131131
model_path: str,
132+
output_name: str,
132133
output_dir: str,
133134
device: Literal["auto", "cpu", "cuda"] = "auto",
134135
generate_main: bool = False,
@@ -137,6 +138,7 @@ def __init__(
137138
data_input_predicator_class_name: str = None,
138139
):
139140
self.model_path = Path(model_path).resolve()
141+
self.output_name = output_name
140142
self.output_dir = Path(output_dir)
141143
self.device = self._choose_device(device)
142144
self.generate_main = generate_main
@@ -191,7 +193,7 @@ def _make_data_input_predicator(
191193
return lambda *args, **kwargs: True
192194

193195
def _write_to_file(self, unittest, output_dir):
194-
output_path = Path(output_dir) / f"{self.model_path.name}_test.py"
196+
output_path = Path(output_dir) / self.output_name
195197
output_path.parent.mkdir(parents=True, exist_ok=True)
196198
output_path.write_text(unittest, encoding="utf-8")
197199
print(
@@ -326,11 +328,24 @@ def declare_config(
326328
def __call__(self, rel_model_path: str):
327329
self.resumable_handle_sample(rel_model_path)
328330

331+
def sample_handled(self, rel_model_path: str) -> bool:
332+
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
333+
if not dst_model_path.exists():
334+
return False
335+
output_name = self._get_output_name(rel_model_path)
336+
num_model_py_files = len(list(dst_model_path.rglob(output_name)))
337+
assert num_model_py_files <= 1
338+
return num_model_py_files == 1
339+
340+
def _get_output_name(self, rel_model_path: str):
341+
return f"{Path(rel_model_path).name}_test.py"
342+
329343
def resume(self, rel_model_path: str):
330344
model_path_prefix = Path(self.config["model_path_prefix"])
331345
output_dir = Path(self.config["output_dir"])
332346
generator = AgentUnittestGenerator(
333347
model_path=str(model_path_prefix / rel_model_path),
348+
output_name=self._get_output_name(rel_model_path),
334349
output_dir=str(output_dir / rel_model_path),
335350
device=self.config["device"],
336351
generate_main=self.config["generate_main"],

0 commit comments

Comments
 (0)