diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py
index e19b5761c4d4b..95c7b4b41746e 100644
--- a/src/lightning/pytorch/loops/optimization/automatic.py
+++ b/src/lightning/pytorch/loops/optimization/automatic.py
@@ -23,13 +23,14 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult
 from lightning.pytorch.loops.progress import _OptimizationProgress
 from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache
+from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -320,10 +321,11 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
 
         if training_step_output is None and trainer.world_size > 1:
-            raise RuntimeError(
+            rank_zero_warn(
                 "Skipping the `training_step` by returning None in distributed training is not supported."
                 " It is recommended that you rewrite your training logic to avoid having to skip the step in the first"
-                " place."
+                " place.",
+                category=PossibleUserWarning,
             )
 
         return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py
index e20c1789be023..48e4d5405664d 100644
--- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py
+++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Iterator, Mapping
-from contextlib import nullcontext
 from typing import Generic, TypeVar
 
 import pytest
@@ -84,27 +83,3 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(MisconfigurationException, match=match):
         trainer.fit(model)
-
-
-@pytest.mark.parametrize("world_size", [1, 2])
-def test_skip_training_step_not_allowed(world_size, tmp_path):
-    """Test that skipping the training_step in distributed training is not allowed."""
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            return None
-
-    model = TestModel()
-    trainer = Trainer(
-        default_root_dir=tmp_path,
-        max_steps=1,
-        barebones=True,
-    )
-    trainer.strategy.world_size = world_size  # mock world size without launching processes
-    error_context = (
-        pytest.raises(RuntimeError, match="Skipping the `training_step` .* is not supported")
-        if world_size > 1
-        else nullcontext()
-    )
-    with error_context:
-        trainer.fit(model)
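
The removed test covered the old RuntimeError path and is not replaced in this patch. Below is a minimal sketch (not part of the patch) of how the new warning path could be exercised. It reuses the removed test's trick of assigning `trainer.strategy.world_size` directly instead of launching processes, and assumes `rank_zero_warn` surfaces the `PossibleUserWarning` through Python's `warnings` module so `pytest.warns` can catch it. The model and test names are hypothetical.

# Hypothetical replacement test (sketch only, assumptions noted above).
import pytest

from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class SkippingModel(BoringModel):
    def training_step(self, batch, batch_idx):
        return None  # skip the step by returning nothing


def test_skip_training_step_warns(tmp_path):
    model = SkippingModel()
    trainer = Trainer(default_root_dir=tmp_path, max_steps=1, barebones=True)
    trainer.strategy.world_size = 2  # mock distributed world size without launching processes
    # With this patch, the loop warns instead of raising when `training_step` returns None.
    with pytest.warns(PossibleUserWarning, match="Skipping the `training_step`"):
        trainer.fit(model)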