1010from collections .abc import Callable , Mapping , Set
1111from concurrent .futures import Future
1212from copy import deepcopy
13- from multiprocessing .pool import Pool
1413from os import scandir
1514from pathlib import Path
16- from pickle import PickleError
1715from shutil import Error
18- from signal import SIG_IGN , SIGINT , signal
19- from typing import ClassVar , Protocol , TypeAlias , cast , final , runtime_checkable
16+ from typing import Protocol , TypeAlias , cast , final , runtime_checkable
2017
2118import torch
22- import torch .multiprocessing as mp
2319from torch import Tensor
2420from typing_extensions import override
2521
@@ -141,7 +137,7 @@ class FileCheckpointManager(CheckpointManager):
141137 _checkpoint_dir : Path
142138 _gangs : Gangs
143139 _file_system : FileSystem
144- _saver : CheckpointSaver
140+ _tensor_dumper : TensorDumper
145141 _tensor_loader : TensorLoader
146142 _thread_pool : ThreadPool
147143 _save_op : Future [Callable [[], None ]] | None
@@ -152,7 +148,7 @@ def __init__(
152148 checkpoint_dir : Path ,
153149 gangs : Gangs ,
154150 file_system : FileSystem ,
155- saver : CheckpointSaver ,
151+ tensor_dumper : TensorDumper ,
156152 tensor_loader : TensorLoader ,
157153 thread_pool : ThreadPool ,
158154 ) -> None :
@@ -167,8 +163,7 @@ def __init__(
167163
168164 self ._file_system = file_system
169165
170- self ._saver = saver
171-
166+ self ._tensor_dumper = tensor_dumper
172167 self ._tensor_loader = tensor_loader
173168
174169 self ._thread_pool = thread_pool
@@ -464,7 +459,7 @@ def _do_save_checkpoint(
464459 def save () -> Callable [[], None ]:
465460 nonlocal state
466461
467- self ._saver . save (step_nr , state )
462+ self ._save_state_files (step_nr , state )
468463
469464 del state
470465
@@ -540,6 +535,15 @@ def move_to_host(item: object) -> object:
540535
541536 return cast (dict [str , object ], move_to_host (state_dict ))
542537
538+ def _save_state_files (self , step_nr : int , state : CheckpointState ) -> None :
539+ for kind , (file , state_dict ) in state .items ():
540+ try :
541+ self ._tensor_dumper .dump (state_dict , file )
542+ except TensorDumpError as ex :
543+ raise CheckpointSaveError (
544+ step_nr , f"The '{ kind } ' state of step { step_nr } cannot be saved to the '{ ex .path } ' file. See the nested exception for details." # fmt: skip
545+ ) from ex
546+
543547 def _copy_cc (self , step_nr : int ) -> None :
544548 gangs = self ._gangs
545549
@@ -983,103 +987,11 @@ def load_error() -> CheckpointError:
983987
984988 return scores
985989
986- @override
987- def close (self ) -> None :
988- self ._saver .close ()
989-
990-
991- class CheckpointSaver (Closable ):
992- @abstractmethod
993- def save (self , step_nr : int , state : CheckpointState ) -> None : ...
994-
995-
996- @final
997- class InProcCheckpointSaver (CheckpointSaver ):
998- _tensor_dumper : TensorDumper
999-
1000- def __init__ (self , tensor_dumper : TensorDumper ) -> None :
1001- self ._tensor_dumper = tensor_dumper
1002-
1003- @override
1004- def save (self , step_nr : int , state : CheckpointState ) -> None :
1005- _save_state_files (self ._tensor_dumper , step_nr , state )
1006-
1007990 @override
1008991 def close (self ) -> None :
1009992 pass
1010993
1011994
1012- @final
1013- class OutOfProcCheckpointSaver (CheckpointSaver ):
1014- _pool : Pool
1015-
1016- def __init__ (self , pool : Pool ) -> None :
1017- self ._pool = pool
1018-
1019- @staticmethod
1020- def create (tensor_dumper : TensorDumper ) -> OutOfProcCheckpointSaver :
1021- mp .set_sharing_strategy ("file_system" )
1022-
1023- ctx = mp .get_context ("spawn" )
1024-
1025- # Do not allow the pool process to handle SIGINT. It will be gracefully
1026- # closed when `close()` is called.
1027- sig = signal (SIGINT , SIG_IGN )
1028-
1029- try :
1030- pool = ctx .Pool (1 , _PoolProcess .init , (tensor_dumper ,))
1031- except (RuntimeError , ValueError , PickleError ) as ex :
1032- raise CheckpointError (
1033- "The checkpoint process pool cannot be initialized. See the nested exception for details." # fmt: skip
1034- ) from ex
1035- finally :
1036- signal (SIGINT , sig )
1037-
1038- return OutOfProcCheckpointSaver (pool )
1039-
1040- @override
1041- def save (self , step_nr : int , state : CheckpointState ) -> None :
1042- try :
1043- self ._pool .apply (_PoolProcess .save_state_files , (step_nr , state ))
1044- except RuntimeError as ex :
1045- raise CheckpointError (
1046- "The checkpoint process pool has failed to dispatch the save operation. See the nested exception for details." # fmt: skip
1047- ) from ex
1048-
1049- @override
1050- def close (self ) -> None :
1051- self ._pool .close ()
1052-
1053- self ._pool .join ()
1054-
1055-
1056- class _PoolProcess :
1057- _tensor_dumper : ClassVar [TensorDumper | None ] = None
1058-
1059- @staticmethod
1060- def init (tensor_dumper : TensorDumper ) -> None :
1061- _PoolProcess ._tensor_dumper = tensor_dumper
1062-
1063- @staticmethod
1064- def save_state_files (step_nr : int , state : CheckpointState ) -> None :
1065- if _PoolProcess ._tensor_dumper is None :
1066- raise InternalError ("`_tensor_dumper` is `None`." )
1067-
1068- _save_state_files (_PoolProcess ._tensor_dumper , step_nr , state )
1069-
1070-
1071- def _save_state_files (
1072- tensor_dumper : TensorDumper , step_nr : int , state : CheckpointState
1073- ) -> None :
1074- for kind , (file , state_dict ) in state .items ():
1075- try :
1076- tensor_dumper .dump (state_dict , file )
1077- except TensorDumpError as ex :
1078- raise CheckpointSaveError (
1079- step_nr , f"The '{ kind } ' state of step { step_nr } cannot be saved to the '{ ex .path } ' file. See the nested exception for details." # fmt: skip
1080- ) from ex
1081-
1082-
1083995class CheckpointNotFoundError (Exception ):
1084996 step_nr : int
1085997
0 commit comments