@@ -129,6 +129,7 @@ class AgentUnittestGenerator:
129129 def __init__ (
130130 self ,
131131 model_path : str ,
132+ output_name : str ,
132133 output_dir : str ,
133134 device : Literal ["auto" , "cpu" , "cuda" ] = "auto" ,
134135 generate_main : bool = False ,
@@ -137,6 +138,7 @@ def __init__(
137138 data_input_predicator_class_name : str = None ,
138139 ):
139140 self .model_path = Path (model_path ).resolve ()
141+ self .output_name = output_name
140142 self .output_dir = Path (output_dir )
141143 self .device = self ._choose_device (device )
142144 self .generate_main = generate_main
@@ -191,7 +193,7 @@ def _make_data_input_predicator(
191193 return lambda * args , ** kwargs : True
192194
193195 def _write_to_file (self , unittest , output_dir ):
194- output_path = Path (output_dir ) / f" { self .model_path . name } _test.py"
196+ output_path = Path (output_dir ) / self .output_name
195197 output_path .parent .mkdir (parents = True , exist_ok = True )
196198 output_path .write_text (unittest , encoding = "utf-8" )
197199 print (
@@ -326,11 +328,24 @@ def declare_config(
326328 def __call__ (self , rel_model_path : str ):
327329 self .resumable_handle_sample (rel_model_path )
328330
331+ def sample_handled (self , rel_model_path : str ) -> bool :
332+ dst_model_path = Path (self .config ["output_dir" ]) / rel_model_path
333+ if not dst_model_path .exists ():
334+ return False
335+ output_name = self ._get_output_name (rel_model_path )
336+ num_model_py_files = len (list (dst_model_path .rglob (output_name )))
337+ assert num_model_py_files <= 1
338+ return num_model_py_files == 1
339+
340+ def _get_output_name (self , rel_model_path : str ):
341+ return f"{ Path (rel_model_path ).name } _test.py"
342+
329343 def resume (self , rel_model_path : str ):
330344 model_path_prefix = Path (self .config ["model_path_prefix" ])
331345 output_dir = Path (self .config ["output_dir" ])
332346 generator = AgentUnittestGenerator (
333347 model_path = str (model_path_prefix / rel_model_path ),
348+ output_name = self ._get_output_name (rel_model_path ),
334349 output_dir = str (output_dir / rel_model_path ),
335350 device = self .config ["device" ],
336351 generate_main = self .config ["generate_main" ],
0 commit comments