@@ -210,6 +210,7 @@ def __init__(
210210 self ,
211211 framework : str ,
212212 model_path : str ,
213+ output_name : str ,
213214 output_dir : str ,
214215 device : Literal ["auto" , "cpu" , "cuda" ] = "auto" ,
215216 generate_main : bool = False ,
@@ -219,6 +220,7 @@ def __init__(
219220 ):
220221 self .framework = framework
221222 self .model_path = Path (model_path ).resolve ()
223+ self .output_name = output_name
222224 self .output_dir = Path (output_dir )
223225 self .device = self ._choose_device (device )
224226 self .generate_main = generate_main
@@ -283,7 +285,7 @@ def _make_data_input_predicator(
283285 return lambda * args , ** kwargs : True
284286
285287 def _write_to_file (self , unittest , output_dir ):
286- output_path = Path (output_dir ) / f" { self .model_path . name } _test.py"
288+ output_path = Path (output_dir ) / self .output_name
287289 output_path .parent .mkdir (parents = True , exist_ok = True )
288290 output_path .write_text (unittest , encoding = "utf-8" )
289291 print (
@@ -447,12 +449,25 @@ def declare_config(
447449 def __call__ (self , rel_model_path : str ):
448450 self .resumable_handle_sample (rel_model_path )
449451
452+ def sample_handled (self , rel_model_path : str ) -> bool :
453+ dst_model_path = Path (self .config ["output_dir" ]) / rel_model_path
454+ if not dst_model_path .exists ():
455+ return False
456+ output_name = self ._get_output_name (rel_model_path )
457+ num_model_py_files = len (list (dst_model_path .rglob (output_name )))
458+ assert num_model_py_files <= 1
459+ return num_model_py_files == 1
460+
461+ def _get_output_name (self , rel_model_path : str ):
462+ return f"{ Path (rel_model_path ).name } _test.py"
463+
450464 def resume (self , rel_model_path : str ):
451465 model_path_prefix = Path (self .config ["model_path_prefix" ])
452466 output_dir = Path (self .config ["output_dir" ])
453467 generator = AgentUnittestGenerator (
454468 framework = self .config ["framework" ],
455469 model_path = str (model_path_prefix / rel_model_path ),
470+ output_name = self ._get_output_name (rel_model_path ),
456471 output_dir = str (output_dir / rel_model_path ),
457472 device = self .config ["device" ],
458473 generate_main = self .config ["generate_main" ],
0 commit comments