
Commit 7434c47

carmocca and awaelchli authored
Raise an exception when calling fit twice with spawn (#18776)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 5a83f54 commit 7434c47

5 files changed: +52 −16 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks ([#18677](https://github.com/Lightning-AI/lightning/pull/18677))
 - The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
 - The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
+- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))

 ### Deprecated
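For context, a minimal sketch of the user-facing behavior this entry describes (hedged: `BoringModel` is Lightning's demo module, used here purely for illustration):

```python
# Sketch: a second fit() on the same Trainer with a spawn-based strategy
# now fails fast instead of producing a silently broken run.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":  # the spawn start method requires a main guard
    model = BoringModel()
    trainer = Trainer(strategy="ddp_spawn", max_epochs=1)
    trainer.fit(model)

    trainer.fit_loop.max_epochs += 1
    trainer.fit(model)  # raises NotImplementedError as of this commit
```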

src/lightning/pytorch/strategies/launchers/multiprocessing.py

Lines changed: 9 additions & 0 deletions

@@ -80,6 +80,7 @@ def __init__(
                 f" {', '.join(mp.get_all_start_methods())}"
             )
         self.procs: List[mp.Process] = []
+        self._already_fit = False

     @property
     def is_interactive_compatible(self) -> bool:
@@ -106,6 +107,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             _check_bad_cuda_fork()
         if self._start_method == "spawn":
             _check_missing_main_guard()
+        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
+            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
+            raise NotImplementedError(
+                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
+                " supported. You can work around this limitation by creating a new Trainer instance and passing the"
+                " `fit(ckpt_path=...)` argument."
+            )

         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
@@ -137,6 +145,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         if trainer is None:
             return worker_output

+        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
         self._recover_results_in_main_process(worker_output, trainer)
         return worker_output.trainer_results
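The error message points at the supported workaround: start from a fresh `Trainer`, whose launcher begins with `_already_fit = False`, and resume via `ckpt_path`. A hedged sketch (resuming from `best_model_path` is illustrative, not part of this commit):

```python
# Work around the single-fit restriction by resuming in a new Trainer.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    model = BoringModel()
    first = Trainer(strategy="ddp_spawn", max_epochs=1)
    first.fit(model)

    # A new Trainer owns a new launcher, so the _already_fit latch starts clear;
    # training continues from the checkpoint written by the first run.
    second = Trainer(strategy="ddp_spawn", max_epochs=2)
    second.fit(model, ckpt_path=first.checkpoint_callback.best_model_path)
```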

src/lightning/pytorch/strategies/launchers/xla.py

Lines changed: 9 additions & 0 deletions

@@ -72,6 +72,14 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             **kwargs: Optional keyword arguments to be passed to the given function.

         """
+        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
+            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
+            raise NotImplementedError(
+                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
+                " supported. You can work around this by creating a new Trainer instance and passing the"
+                " `fit(ckpt_path=...)` argument."
+            )
+
         using_pjrt = _using_pjrt()
         # pjrt requires that the queue is serializable
         return_queue: Union[queue.Queue, mp.SimpleQueue] = (
@@ -104,6 +112,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         if trainer is None:
             return worker_output

+        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
         self._recover_results_in_main_process(worker_output, trainer)
         return worker_output.trainer_results

src/lightning/pytorch/trainer/trainer.py

Lines changed: 16 additions & 16 deletions

@@ -539,6 +539,9 @@ def fit(
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
+        self.state.fn = TrainerFn.FITTING
+        self.state.status = TrainerStatus.RUNNING
+        self.training = True
         call._call_and_handle_interrupt(
             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )
@@ -553,10 +556,6 @@ def _fit_impl(
     ) -> None:
         log.debug(f"{self.__class__.__name__}: trainer fit stage")

-        self.state.fn = TrainerFn.FITTING
-        self.state.status = TrainerStatus.RUNNING
-        self.training = True
-
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloaders, LightningDataModule):
             datamodule = train_dataloaders
@@ -572,6 +571,7 @@ def _fit_impl(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )

+        assert self.state.fn is not None
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
@@ -640,6 +640,9 @@ def validate(
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(self.lightning_module, self.strategy)
+        self.state.fn = TrainerFn.VALIDATING
+        self.state.status = TrainerStatus.RUNNING
+        self.validating = True
         return call._call_and_handle_interrupt(
             self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
         )
@@ -657,10 +660,6 @@ def _validate_impl(
         # --------------------
         log.debug(f"{self.__class__.__name__}: trainer validate stage")

-        self.state.fn = TrainerFn.VALIDATING
-        self.state.status = TrainerStatus.RUNNING
-        self.validating = True
-
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(dataloaders, LightningDataModule):
             datamodule = dataloaders
@@ -680,6 +679,7 @@ def _validate_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

+        assert self.state.fn is not None
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -749,6 +749,9 @@ def test(
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(self.lightning_module, self.strategy)
+        self.state.fn = TrainerFn.TESTING
+        self.state.status = TrainerStatus.RUNNING
+        self.testing = True
         return call._call_and_handle_interrupt(
             self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
         )
@@ -766,10 +769,6 @@ def _test_impl(
         # --------------------
         log.debug(f"{self.__class__.__name__}: trainer test stage")

-        self.state.fn = TrainerFn.TESTING
-        self.state.status = TrainerStatus.RUNNING
-        self.testing = True
-
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(dataloaders, LightningDataModule):
             datamodule = dataloaders
@@ -789,6 +788,7 @@ def _test_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

+        assert self.state.fn is not None
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -859,6 +859,9 @@ def predict(
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(self.lightning_module, self.strategy)
+        self.state.fn = TrainerFn.PREDICTING
+        self.state.status = TrainerStatus.RUNNING
+        self.predicting = True
         return call._call_and_handle_interrupt(
             self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )
@@ -876,10 +879,6 @@ def _predict_impl(
         # --------------------
         log.debug(f"{self.__class__.__name__}: trainer predict stage")

-        self.state.fn = TrainerFn.PREDICTING
-        self.state.status = TrainerStatus.RUNNING
-        self.predicting = True
-
         self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]

         # if a datamodule comes in as the second arg, then fix it for the user
@@ -898,6 +897,7 @@ def _predict_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

+        assert self.state.fn is not None
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
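The state assignments move out of the `_impl` methods and up into the public entry points so that `trainer.state.fn` is already set when the strategy launcher runs inside `_call_and_handle_interrupt`; the spawn launchers above rely on that to detect a second fit. A self-contained toy (not Lightning code) of that interaction:

```python
# Toy model of the guard: launch() inspects the trainer's current fn to decide
# whether a second FITTING launch should be rejected, and latches _already_fit.
from enum import Enum


class TrainerFn(Enum):
    FITTING = "fit"
    TESTING = "test"


class Launcher:
    def __init__(self) -> None:
        self._already_fit = False

    def launch(self, state_fn: TrainerFn) -> None:
        if self._already_fit and state_fn == TrainerFn.FITTING:
            raise NotImplementedError("fit() twice with a spawn launcher is not supported")
        self._already_fit |= state_fn == TrainerFn.FITTING  # latch stays set


launcher = Launcher()
launcher.launch(TrainerFn.FITTING)   # first fit: ok, latches the flag
launcher.launch(TrainerFn.TESTING)   # testing in between does not trip it
try:
    launcher.launch(TrainerFn.FITTING)  # second fit: rejected
except NotImplementedError as err:
    print(err)
```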

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 17 additions & 0 deletions

@@ -212,3 +212,20 @@ def test_check_for_missing_main_guard():
         return_value=Mock(_inheriting=True),  # pretend that main is importing itself
     ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
         launcher.launch(function=Mock())
+
+
+def test_fit_twice_raises():
+    model = BoringModel()
+    trainer = Trainer(
+        limit_train_batches=1,
+        limit_test_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        strategy="ddp_spawn",
+        barebones=True,
+    )
+    trainer.fit(model)
+    trainer.test(model)  # make sure testing in between doesnt impact the result
+    trainer.fit_loop.max_epochs += 1
+    with pytest.raises(NotImplementedError, match=r"twice.*is not supported"):
+        trainer.fit(model)
