11import re
2+ import sys
3+ import subprocess
24import ast
35import inspect
46import jinja2
57import textwrap
8+ import tempfile
69from pathlib import Path
710from typing import Literal
811from collections import namedtuple
1114
1215from graph_net import imp_util
1316from graph_net .sample_pass .sample_pass import SamplePass
17+ from graph_net .sample_pass .resumable_sample_pass_mixin import ResumableSamplePassMixin
1418from graph_net .tensor_meta import TensorMeta
1519
1620
2024{%- endif -%}
2125{{"\n"}}
2226import torch
27+ from torch import device
2328
2429
2530{% macro get_input_tensor_instance(tensor_meta, device) -%}
@@ -124,21 +129,26 @@ class AgentUnittestGenerator:
124129 def __init__ (
125130 self ,
126131 model_path : str ,
132+ output_name : str ,
127133 output_dir : str ,
128134 device : Literal ["auto" , "cpu" , "cuda" ] = "auto" ,
129135 generate_main : bool = False ,
136+ try_run : bool = False ,
130137 data_input_predicator_filepath : str = None ,
131138 data_input_predicator_class_name : str = None ,
132139 ):
133140 self .model_path = Path (model_path ).resolve ()
141+ self .output_name = output_name
134142 self .output_dir = Path (output_dir )
135143 self .device = self ._choose_device (device )
136144 self .generate_main = generate_main
145+ self .try_run = try_run and generate_main
137146 self .data_input_predicator = self ._make_data_input_predicator (
138147 data_input_predicator_filepath , data_input_predicator_class_name
139148 )
140149
141150 def generate (self ):
151+ print (f"[AgentUnittestGenerator] Generate unittest for { self .model_path } " )
142152 model_name = "" .join (
143153 word .capitalize () for word in re .split (r"[_.-]" , self .model_path .name )
144154 )
@@ -152,7 +162,6 @@ def generate(self):
152162 input_tensor_metas ,
153163 weight_tensor_metas ,
154164 ) = self ._get_input_and_weight_tensor_metas (input_arg_names , weight_arg_names )
155- print (f"{ input_arg_names = } " )
156165 graph_module_desc = GraphModuleDescriptor (
157166 device = self .device ,
158167 generate_main = self .generate_main ,
@@ -166,7 +175,8 @@ def generate(self):
166175 ),
167176 )
168177 unittest = self ._render_template (graph_module_desc )
169- self ._write_to_file (unittest )
178+ if self ._try_to_run_unittest (unittest ):
179+ self ._write_to_file (unittest , self .output_dir )
170180
171181 def _choose_device (self , device ) -> str :
172182 if device in ["cpu" , "cuda" ]:
@@ -180,15 +190,28 @@ def _make_data_input_predicator(
180190 module = imp_util .load_module (data_input_predicator_filepath )
181191 cls = getattr (module , data_input_predicator_class_name )
182192 return cls (config = {})
183- return lambda * args , ** kwargs : False
193+ return lambda * args , ** kwargs : True
184194
185- def _write_to_file (self , unittest ):
186- output_path = Path (self . output_dir ) / f" { self .model_path . name } _test.py"
195+ def _write_to_file (self , unittest , output_dir ):
196+ output_path = Path (output_dir ) / self .output_name
187197 output_path .parent .mkdir (parents = True , exist_ok = True )
188198 output_path .write_text (unittest , encoding = "utf-8" )
189199 print (
190200 f"[AgentUnittestGenerator] Generate unittest: { output_path } (device={ self .device } )"
191201 )
202+ return output_path
203+
204+ def _try_to_run_unittest (self , unittest ):
205+ if not self .try_run :
206+ return True
207+
208+ with tempfile .TemporaryDirectory (prefix = "unittest_" ) as temp_dir :
209+ output_path = self ._write_to_file (unittest , temp_dir )
210+ result = subprocess .run (
211+ [sys .executable , output_path ],
212+ check = True ,
213+ )
214+ return result .returncode == 0
192215
193216 def _get_input_and_weight_arg_names (self , graph_module ):
194217 input_arg_names = []
@@ -282,33 +305,51 @@ def _render_template(self, graph_module_desc):
282305 return template .render (graph_module_desc = graph_module_desc )
283306
284307
285- class AgentUnittestGeneratorPass (SamplePass ):
308+ class AgentUnittestGeneratorPass (SamplePass , ResumableSamplePassMixin ):
286309 """SamplePass wrapper to generate Torch unittests via model_path_handler."""
287310
288311 def __init__ (self , config = None ):
289312 super ().__init__ (config )
290- print (f"[AgentUnittestGeneratorPass] { self .config = } " )
291313
292314 def declare_config (
293315 self ,
294316 model_path_prefix : str ,
295317 output_dir : str ,
296318 device : str = "auto" ,
297319 generate_main : bool = False ,
320+ try_run : bool = False ,
298321 data_input_predicator_filepath : str = None ,
299322 data_input_predicator_class_name : str = None ,
323+ resume : bool = False ,
324+ limits_handled_models : int = None ,
300325 ):
301326 pass
302327
303328 def __call__ (self , rel_model_path : str ):
304- print (f"[AgentUnittestGeneratorPass] { rel_model_path = } " )
329+ self .resumable_handle_sample (rel_model_path )
330+
331+ def sample_handled (self , rel_model_path : str ) -> bool :
332+ dst_model_path = Path (self .config ["output_dir" ]) / rel_model_path
333+ if not dst_model_path .exists ():
334+ return False
335+ output_name = self ._get_output_name (rel_model_path )
336+ num_model_py_files = len (list (dst_model_path .rglob (output_name )))
337+ assert num_model_py_files <= 1
338+ return num_model_py_files == 1
339+
340+ def _get_output_name (self , rel_model_path : str ):
341+ return f"{ Path (rel_model_path ).name } _test.py"
342+
343+ def resume (self , rel_model_path : str ):
305344 model_path_prefix = Path (self .config ["model_path_prefix" ])
306345 output_dir = Path (self .config ["output_dir" ])
307346 generator = AgentUnittestGenerator (
308347 model_path = str (model_path_prefix / rel_model_path ),
348+ output_name = self ._get_output_name (rel_model_path ),
309349 output_dir = str (output_dir / rel_model_path ),
310350 device = self .config ["device" ],
311351 generate_main = self .config ["generate_main" ],
352+ try_run = self .config ["try_run" ],
312353 data_input_predicator_filepath = self .config [
313354 "data_input_predicator_filepath"
314355 ],
0 commit comments