@@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
3333
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
    """Wrap ``checkpoint_io`` and initialise the asynchronous-save state.

    Args:
        checkpoint_io: The checkpoint IO plugin to wrap; forwarded unchanged to the
            parent initialiser.

    """
    super().__init__(checkpoint_io)
    # Holds the exception recorded for re-raising in ``teardown``; ``None`` means no failure so far.
    self._error: Optional[BaseException] = None
    # No thread pool is created here: ``_ensure_setup`` builds it when it finds ``None``.
    self._executor: Optional[ThreadPoolExecutor] = None
3938
39+ # CheckpointIO doesn't have a setup method, so we have to create the executor lazily like this.
40+ # We can't do the setup in __init__ because if train or validate is called more than once, the
41+ # teardown method deletes the executor.
def _ensure_setup(self) -> None:
    """Create the single-worker executor if there currently is none.

    ``teardown`` sets ``self._executor`` back to ``None``, so this has to run before
    work is submitted rather than once in ``__init__``. When a new executor is
    created, the stored error is cleared as well; when an executor already exists,
    nothing is touched, so a pending error is kept.
    """
    if self._executor is None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        # Plain assignment: the attribute's type is already declared in ``__init__``,
        # and annotating it a second time is reported by type checkers as a redefinition.
        self._error = None
46+
4047 @override
4148 def save_checkpoint (self , * args : Any , ** kwargs : Any ) -> None :
4249 """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
4350
51+ self ._ensure_setup ()
52+
4453 def _save_checkpoint (* args : Any , ** kwargs : Any ) -> None :
4554 try :
4655 assert self .checkpoint_io is not None
@@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
@override
def teardown(self) -> None:
    """Shut down the worker thread and surface any stored save error.

    Does nothing when no executor exists. Otherwise the executor is shut down with
    ``wait=True``, the reference to it is dropped, and a stored error, if any, is
    re-raised.

    Raises:
        BaseException: The error held in ``self._error``, if one is set.

    """
    pool = self._executor
    if pool is None:
        return

    pool.shutdown(wait=True)
    self._executor = None

    # surface a failure recorded during any of the `executor.submit` calls
    failure = self._error
    if failure:
        raise failure
0 commit comments