Skip to content

Commit 90af6f3

Browse files
committed
tmp
1 parent 6ea571b commit 90af6f3

File tree

11 files changed

+31
-159
lines changed

11 files changed

+31
-159
lines changed

src/fairseq2/checkpoint/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,11 @@
1515
CheckpointNotFoundError as CheckpointNotFoundError,
1616
)
1717
from fairseq2.checkpoint._manager import CheckpointSaveError as CheckpointSaveError
18-
from fairseq2.checkpoint._manager import CheckpointSaver as CheckpointSaver
1918
from fairseq2.checkpoint._manager import CheckpointState as CheckpointState
2019
from fairseq2.checkpoint._manager import (
2120
CheckpointStateProcessor as CheckpointStateProcessor,
2221
)
2322
from fairseq2.checkpoint._manager import FileCheckpointManager as FileCheckpointManager
24-
from fairseq2.checkpoint._manager import InProcCheckpointSaver as InProcCheckpointSaver
25-
from fairseq2.checkpoint._manager import (
26-
OutOfProcCheckpointSaver as OutOfProcCheckpointSaver,
27-
)
2823
from fairseq2.checkpoint._manager import Stateful as Stateful
2924
from fairseq2.checkpoint._metadata_provider import (
3025
CheckpointMetadataSaver as CheckpointMetadataSaver,

src/fairseq2/checkpoint/_manager.py

Lines changed: 14 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,12 @@
1010
from collections.abc import Callable, Mapping, Set
1111
from concurrent.futures import Future
1212
from copy import deepcopy
13-
from multiprocessing.pool import Pool
1413
from os import scandir
1514
from pathlib import Path
16-
from pickle import PickleError
1715
from shutil import Error
18-
from signal import SIG_IGN, SIGINT, signal
19-
from typing import ClassVar, Protocol, TypeAlias, cast, final, runtime_checkable
16+
from typing import Protocol, TypeAlias, cast, final, runtime_checkable
2017

2118
import torch
22-
import torch.multiprocessing as mp
2319
from torch import Tensor
2420
from typing_extensions import override
2521

@@ -141,7 +137,7 @@ class FileCheckpointManager(CheckpointManager):
141137
_checkpoint_dir: Path
142138
_gangs: Gangs
143139
_file_system: FileSystem
144-
_saver: CheckpointSaver
140+
_tensor_dumper: TensorDumper
145141
_tensor_loader: TensorLoader
146142
_thread_pool: ThreadPool
147143
_save_op: Future[Callable[[], None]] | None
@@ -152,7 +148,7 @@ def __init__(
152148
checkpoint_dir: Path,
153149
gangs: Gangs,
154150
file_system: FileSystem,
155-
saver: CheckpointSaver,
151+
tensor_dumper: TensorDumper,
156152
tensor_loader: TensorLoader,
157153
thread_pool: ThreadPool,
158154
) -> None:
@@ -167,8 +163,7 @@ def __init__(
167163

168164
self._file_system = file_system
169165

170-
self._saver = saver
171-
166+
self._tensor_dumper = tensor_dumper
172167
self._tensor_loader = tensor_loader
173168

174169
self._thread_pool = thread_pool
@@ -464,7 +459,7 @@ def _do_save_checkpoint(
464459
def save() -> Callable[[], None]:
465460
nonlocal state
466461

467-
self._saver.save(step_nr, state)
462+
self._save_state_files(step_nr, state)
468463

469464
del state
470465

@@ -540,6 +535,15 @@ def move_to_host(item: object) -> object:
540535

541536
return cast(dict[str, object], move_to_host(state_dict))
542537

538+
def _save_state_files(self, step_nr: int, state: CheckpointState) -> None:
539+
for kind, (file, state_dict) in state.items():
540+
try:
541+
self._tensor_dumper.dump(state_dict, file)
542+
except TensorDumpError as ex:
543+
raise CheckpointSaveError(
544+
step_nr, f"The '{kind}' state of step {step_nr} cannot be saved to the '{ex.path}' file. See the nested exception for details." # fmt: skip
545+
) from ex
546+
543547
def _copy_cc(self, step_nr: int) -> None:
544548
gangs = self._gangs
545549

@@ -983,103 +987,11 @@ def load_error() -> CheckpointError:
983987

984988
return scores
985989

986-
@override
987-
def close(self) -> None:
988-
self._saver.close()
989-
990-
991-
class CheckpointSaver(Closable):
992-
@abstractmethod
993-
def save(self, step_nr: int, state: CheckpointState) -> None: ...
994-
995-
996-
@final
997-
class InProcCheckpointSaver(CheckpointSaver):
998-
_tensor_dumper: TensorDumper
999-
1000-
def __init__(self, tensor_dumper: TensorDumper) -> None:
1001-
self._tensor_dumper = tensor_dumper
1002-
1003-
@override
1004-
def save(self, step_nr: int, state: CheckpointState) -> None:
1005-
_save_state_files(self._tensor_dumper, step_nr, state)
1006-
1007990
@override
1008991
def close(self) -> None:
1009992
pass
1010993

1011994

1012-
@final
1013-
class OutOfProcCheckpointSaver(CheckpointSaver):
1014-
_pool: Pool
1015-
1016-
def __init__(self, pool: Pool) -> None:
1017-
self._pool = pool
1018-
1019-
@staticmethod
1020-
def create(tensor_dumper: TensorDumper) -> OutOfProcCheckpointSaver:
1021-
mp.set_sharing_strategy("file_system")
1022-
1023-
ctx = mp.get_context("spawn")
1024-
1025-
# Do not allow the pool process to handle SIGINT. It will be gracefully
1026-
# closed when `close()` is called.
1027-
sig = signal(SIGINT, SIG_IGN)
1028-
1029-
try:
1030-
pool = ctx.Pool(1, _PoolProcess.init, (tensor_dumper,))
1031-
except (RuntimeError, ValueError, PickleError) as ex:
1032-
raise CheckpointError(
1033-
"The checkpoint process pool cannot be initialized. See the nested exception for details." # fmt: skip
1034-
) from ex
1035-
finally:
1036-
signal(SIGINT, sig)
1037-
1038-
return OutOfProcCheckpointSaver(pool)
1039-
1040-
@override
1041-
def save(self, step_nr: int, state: CheckpointState) -> None:
1042-
try:
1043-
self._pool.apply(_PoolProcess.save_state_files, (step_nr, state))
1044-
except RuntimeError as ex:
1045-
raise CheckpointError(
1046-
"The checkpoint process pool has failed to dispatch the save operation. See the nested exception for details." # fmt: skip
1047-
) from ex
1048-
1049-
@override
1050-
def close(self) -> None:
1051-
self._pool.close()
1052-
1053-
self._pool.join()
1054-
1055-
1056-
class _PoolProcess:
1057-
_tensor_dumper: ClassVar[TensorDumper | None] = None
1058-
1059-
@staticmethod
1060-
def init(tensor_dumper: TensorDumper) -> None:
1061-
_PoolProcess._tensor_dumper = tensor_dumper
1062-
1063-
@staticmethod
1064-
def save_state_files(step_nr: int, state: CheckpointState) -> None:
1065-
if _PoolProcess._tensor_dumper is None:
1066-
raise InternalError("`_tensor_dumper` is `None`.")
1067-
1068-
_save_state_files(_PoolProcess._tensor_dumper, step_nr, state)
1069-
1070-
1071-
def _save_state_files(
1072-
tensor_dumper: TensorDumper, step_nr: int, state: CheckpointState
1073-
) -> None:
1074-
for kind, (file, state_dict) in state.items():
1075-
try:
1076-
tensor_dumper.dump(state_dict, file)
1077-
except TensorDumpError as ex:
1078-
raise CheckpointSaveError(
1079-
step_nr, f"The '{kind}' state of step {step_nr} cannot be saved to the '{ex.path}' file. See the nested exception for details." # fmt: skip
1080-
) from ex
1081-
1082-
1083995
class CheckpointNotFoundError(Exception):
1084996
step_nr: int
1085997

src/fairseq2/datasets/instruction.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def splits(self) -> set[str]:
128128

129129

130130
# TODO: FIX, INFER
131-
npc = 10
131+
npc = 5 # 10
132132

133133

134134
GENERIC_INSTRUCTION_DATASET_FAMILY: Final = "generic_instruction"
@@ -223,9 +223,9 @@ def create_reader(
223223
else:
224224
builder = DataPipeline.concat(pipelines)
225225

226-
# Shuffle files. Must be consistent across all processes.
227-
if options.example_shuffle_window != 1:
228-
builder.shuffle(options.example_shuffle_window, seed=seed)
226+
# # Shuffle files. Must be consistent across all processes.
227+
# if options.example_shuffle_window != 1:
228+
# builder.shuffle(options.example_shuffle_window, seed=seed)
229229

230230
seed += 1
231231

@@ -286,9 +286,9 @@ def skip(example: dict[str, Any]) -> bool:
286286
else:
287287
raise NotSupportedError(f"`{batching}` is not supported.")
288288

289-
# Shuffle buckets.
290-
if options.batch_shuffle_window != 1:
291-
builder.shuffle(options.batch_shuffle_window, seed=seed)
289+
## # Shuffle buckets.
290+
# if options.batch_shuffle_window != 1:
291+
# builder.shuffle(options.batch_shuffle_window, seed=seed)
292292

293293
seed += 1
294294

@@ -308,7 +308,7 @@ def skip(example: dict[str, Any]) -> bool:
308308
builder.take(options.max_num_batches)
309309

310310
# Prefetch `num_prefetch` batches in background.
311-
builder.prefetch(options.num_prefetch)
311+
# builder.prefetch(options.num_prefetch)
312312

313313
# Wrap examples with `SequenceBatch`.
314314
def to_batch(example: dict[str, Any]) -> SequenceBatch:

src/fairseq2/recipes/common/_checkpoint.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,15 @@
88

99
from pathlib import Path
1010

11-
from fairseq2.checkpoint import (
12-
CheckpointManager,
13-
CheckpointSaver,
14-
FileCheckpointManager,
15-
InProcCheckpointSaver,
16-
OutOfProcCheckpointSaver,
17-
)
11+
from fairseq2.checkpoint import CheckpointManager, FileCheckpointManager
1812
from fairseq2.context import RuntimeContext
1913
from fairseq2.gang import Gangs
20-
from fairseq2.recipes.config import RegimeSection
2114
from fairseq2.utils.io import TorchTensorDumper, TorchTensorLoader
2215
from fairseq2.utils.threading import get_default_thread_pool
2316

2417

2518
def create_checkpoint_manager(
26-
context: RuntimeContext,
27-
regime_section: RegimeSection,
28-
gangs: Gangs,
29-
output_dir: Path,
19+
context: RuntimeContext, gangs: Gangs, output_dir: Path
3020
) -> CheckpointManager:
3121
checkpoint_dir = output_dir.joinpath("checkpoints")
3222

@@ -35,15 +25,8 @@ def create_checkpoint_manager(
3525
tensor_loader = TorchTensorLoader(file_system)
3626
tensor_dumper = TorchTensorDumper(file_system)
3727

38-
saver: CheckpointSaver
39-
40-
if regime_section.in_proc_checkpoint:
41-
saver = InProcCheckpointSaver(tensor_dumper)
42-
else:
43-
saver = OutOfProcCheckpointSaver.create(tensor_dumper)
44-
4528
thread_pool = get_default_thread_pool()
4629

4730
return FileCheckpointManager(
48-
checkpoint_dir, gangs, file_system, saver, tensor_loader, thread_pool
31+
checkpoint_dir, gangs, file_system, tensor_dumper, tensor_loader, thread_pool
4932
)

src/fairseq2/recipes/config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,6 @@ class RegimeSection:
266266

267267
keep_checkpoint_every_n_steps: int | None = None
268268

269-
in_proc_checkpoint: bool = False
270-
"""
271-
If ``True``, saves checkpoints in a background thread instead of a child
272-
process.
273-
"""
274-
275269
publish_metrics_after_n_steps: int = 0
276270

277271
publish_metrics_every_n_steps: int | None = None

src/fairseq2/recipes/lm/_instruction_finetune.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,7 @@ def load_instruction_finetuner(
227227

228228
gangs = setup_training_gangs(context, config.gang, config.trainer)
229229

230-
checkpoint_manager = create_checkpoint_manager(
231-
context, config.regime, gangs, output_dir
232-
)
230+
checkpoint_manager = create_checkpoint_manager(context, gangs, output_dir)
233231

234232
seed = config.common.seed
235233

src/fairseq2/recipes/lm/_preference_finetune/_recipe.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ def load_po_finetuner(
114114

115115
gangs = setup_training_gangs(context, config.gang, config.trainer)
116116

117-
checkpoint_manager = create_checkpoint_manager(
118-
context, config.regime, gangs, output_dir
119-
)
117+
checkpoint_manager = create_checkpoint_manager(context, gangs, output_dir)
120118

121119
seed = config.common.seed
122120

src/fairseq2/recipes/lm/_train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,7 @@ def load_lm_trainer(
180180

181181
gangs = setup_training_gangs(context, config.gang, config.trainer)
182182

183-
checkpoint_manager = create_checkpoint_manager(
184-
context, config.regime, gangs, output_dir
185-
)
183+
checkpoint_manager = create_checkpoint_manager(context, gangs, output_dir)
186184

187185
seed = config.common.seed
188186

src/fairseq2/recipes/mt/_train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ def load_mt_trainer(
226226

227227
gangs = setup_training_gangs(context, config.gang, config.trainer)
228228

229-
checkpoint_manager = create_checkpoint_manager(
230-
context, config.regime, gangs, output_dir
231-
)
229+
checkpoint_manager = create_checkpoint_manager(context, gangs, output_dir)
232230

233231
seed = config.common.seed
234232

src/fairseq2/recipes/wav2vec2/_train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,7 @@ def load_wav2vec2_trainer(
203203

204204
gangs = setup_training_gangs(context, config.gang, config.trainer)
205205

206-
checkpoint_manager = create_checkpoint_manager(
207-
context, config.regime, gangs, output_dir
208-
)
206+
checkpoint_manager = create_checkpoint_manager(context, gangs, output_dir)
209207

210208
seed = config.common.seed
211209

0 commit comments

Comments
 (0)