1414
1515from graph_net import imp_util
1616from graph_net .sample_pass .sample_pass import SamplePass
17+ from graph_net .sample_pass .resumable_sample_pass_mixin import ResumableSamplePassMixin
1718from graph_net .tensor_meta import TensorMeta
1819
1920
@@ -298,7 +299,7 @@ def _render_template(self, graph_module_desc):
298299 return template .render (graph_module_desc = graph_module_desc )
299300
300301
301- class AgentUnittestGeneratorPass (SamplePass ):
302+ class AgentUnittestGeneratorPass (SamplePass , ResumableSamplePassMixin ):
302303 """SamplePass wrapper to generate Torch unittests via model_path_handler."""
303304
304305 def __init__ (self , config = None ):
@@ -313,10 +314,15 @@ def declare_config(
313314 try_run : bool = False ,
314315 data_input_predicator_filepath : str = None ,
315316 data_input_predicator_class_name : str = None ,
317+ resume : bool = False ,
318+ limits_handled_models : int = None ,
316319 ):
317320 pass
318321
319322 def __call__ (self , rel_model_path : str ):
323+ self .resumable_handle_sample (rel_model_path )
324+
325+ def resume (self , rel_model_path : str ):
320326 model_path_prefix = Path (self .config ["model_path_prefix" ])
321327 output_dir = Path (self .config ["output_dir" ])
322328 generator = AgentUnittestGenerator (
0 commit comments