11import re
2+ import os
3+ import sys
24import ast
35import inspect
46import jinja2
57import textwrap
8+ import tempfile
69from pathlib import Path
710from typing import Literal
811from collections import namedtuple
@@ -127,18 +130,21 @@ def __init__(
127130 output_dir : str ,
128131 device : Literal ["auto" , "cpu" , "cuda" ] = "auto" ,
129132 generate_main : bool = False ,
133+ try_run : bool = False ,
130134 data_input_predicator_filepath : str = None ,
131135 data_input_predicator_class_name : str = None ,
132136 ):
133137 self .model_path = Path (model_path ).resolve ()
134138 self .output_dir = Path (output_dir )
135139 self .device = self ._choose_device (device )
136140 self .generate_main = generate_main
141+ self .try_run = try_run and generate_main
137142 self .data_input_predicator = self ._make_data_input_predicator (
138143 data_input_predicator_filepath , data_input_predicator_class_name
139144 )
140145
141146 def generate (self ):
147+ print (f"[AgentUnittestGenerator] Generate unittest for { self .model_path } " )
142148 model_name = "" .join (
143149 word .capitalize () for word in re .split (r"[_.-]" , self .model_path .name )
144150 )
@@ -152,7 +158,6 @@ def generate(self):
152158 input_tensor_metas ,
153159 weight_tensor_metas ,
154160 ) = self ._get_input_and_weight_tensor_metas (input_arg_names , weight_arg_names )
155- print (f"{ input_arg_names = } " )
156161 graph_module_desc = GraphModuleDescriptor (
157162 device = self .device ,
158163 generate_main = self .generate_main ,
@@ -166,7 +171,8 @@ def generate(self):
166171 ),
167172 )
168173 unittest = self ._render_template (graph_module_desc )
169- self ._write_to_file (unittest )
174+ if self ._try_to_run_unittest (unittest ):
175+ self ._write_to_file (unittest , self .output_dir )
170176
171177 def _choose_device (self , device ) -> str :
172178 if device in ["cpu" , "cuda" ]:
@@ -180,15 +186,25 @@ def _make_data_input_predicator(
180186 module = imp_util .load_module (data_input_predicator_filepath )
181187 cls = getattr (module , data_input_predicator_class_name )
182188 return cls (config = {})
183- return lambda * args , ** kwargs : False
189+ return lambda * args , ** kwargs : True
184190
185- def _write_to_file (self , unittest ):
186- output_path = Path (self . output_dir ) / f"{ self .model_path .name } _test.py"
191+ def _write_to_file (self , unittest , output_dir ):
192+ output_path = Path (output_dir ) / f"{ self .model_path .name } _test.py"
187193 output_path .parent .mkdir (parents = True , exist_ok = True )
188194 output_path .write_text (unittest , encoding = "utf-8" )
189195 print (
190196 f"[AgentUnittestGenerator] Generate unittest: { output_path } (device={ self .device } )"
191197 )
198+ return output_path
199+
200+ def _try_to_run_unittest (self , unittest ):
201+ if not self .try_run :
202+ return True
203+
204+ with tempfile .TemporaryDirectory (prefix = "unittest_" ) as temp_dir :
205+ output_path = self ._write_to_file (unittest , temp_dir )
206+ cmd = f"{ sys .executable } { output_path } "
207+ return os .system (cmd ) == 0
192208
193209 def _get_input_and_weight_arg_names (self , graph_module ):
194210 input_arg_names = []
@@ -287,28 +303,28 @@ class AgentUnittestGeneratorPass(SamplePass):
287303
288304 def __init__ (self , config = None ):
289305 super ().__init__ (config )
290- print (f"[AgentUnittestGeneratorPass] { self .config = } " )
291306
292307 def declare_config (
293308 self ,
294309 model_path_prefix : str ,
295310 output_dir : str ,
296311 device : str = "auto" ,
297312 generate_main : bool = False ,
313+ try_run : bool = False ,
298314 data_input_predicator_filepath : str = None ,
299315 data_input_predicator_class_name : str = None ,
300316 ):
301317 pass
302318
303319 def __call__ (self , rel_model_path : str ):
304- print (f"[AgentUnittestGeneratorPass] { rel_model_path = } " )
305320 model_path_prefix = Path (self .config ["model_path_prefix" ])
306321 output_dir = Path (self .config ["output_dir" ])
307322 generator = AgentUnittestGenerator (
308323 model_path = str (model_path_prefix / rel_model_path ),
309324 output_dir = str (output_dir / rel_model_path ),
310325 device = self .config ["device" ],
311326 generate_main = self .config ["generate_main" ],
327+ try_run = self .config ["try_run" ],
312328 data_input_predicator_filepath = self .config [
313329 "data_input_predicator_filepath"
314330 ],
0 commit comments