
Commit 95d6b6b

Disable skipping training step in distributed training (#19918)
1 parent 5d79325 commit 95d6b6b

5 files changed: +41, -4 lines changed

src/lightning/pytorch/CHANGELOG.md
Lines changed: 2 additions & 1 deletion

@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
 
--
+- It is no longer allowed to skip `training_step()` by returning `None` in distributed training ([#19918](https://github.com/Lightning-AI/pytorch-lightning/pull/19918))
+
 
 ### Deprecated
 
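
For readers coming from the changelog alone, here is a minimal sketch of the pattern this entry affects. The module below is hypothetical and not part of this commit: returning `None` from `training_step` tells Lightning to skip that optimizer step, and with the check added in this commit it now raises a `RuntimeError` whenever `trainer.world_size > 1`.

    import torch
    from lightning.pytorch import LightningModule


    class SkippingModel(LightningModule):
        """Hypothetical module that skips 'bad' batches by returning None."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            if not torch.isfinite(loss):
                # Tolerated on a single process; with world_size > 1 this now raises
                # "Skipping the `training_step` by returning None in distributed
                # training is not supported."
                return None
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)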

src/lightning/pytorch/loops/optimization/automatic.py
Lines changed: 7 additions & 1 deletion

@@ -314,8 +314,14 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
         """
         trainer = self.trainer
 
-        # manually capture logged metrics
         training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
 
+        if training_step_output is None and trainer.world_size > 1:
+            raise RuntimeError(
+                "Skipping the `training_step` by returning None in distributed training is not supported."
+                " It is recommended that you rewrite your training logic to avoid having to skip the step in the first"
+                " place."
+            )
+
         return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
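
The new error message recommends rewriting the training logic rather than skipping the step. One possible rewrite, sketched below under assumptions not taken from this commit (the dataset and its non-finite filter are made up for illustration), is to drop problematic samples before they ever reach `training_step`, so every rank executes the same number of optimizer steps and no batch has to be skipped.

    import torch
    from torch.utils.data import DataLoader, Dataset


    class FilteredDataset(Dataset):
        """Hypothetical dataset that drops non-finite samples up front,
        so training_step never needs to return None."""

        def __init__(self, samples: torch.Tensor):
            # keep only rows whose values are all finite
            keep = torch.isfinite(samples).all(dim=1)
            self.samples = samples[keep]

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]


    # usage sketch: the cleaned data feeds the usual training loop, and every
    # process sees only valid batches
    loader = DataLoader(FilteredDataset(torch.randn(64, 32)), batch_size=8)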

tests/tests_pytorch/loops/optimization/test_optimizer_loop.py renamed to tests/tests_pytorch/loops/optimization/test_automatic_loop.py
Lines changed: 25 additions & 1 deletion

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from contextlib import nullcontext
 from typing import Dict, Generic, Iterator, Mapping, TypeVar
 
 import pytest

@@ -82,3 +82,27 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(MisconfigurationException, match=match):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_skip_training_step_not_allowed(world_size, tmp_path):
+    """Test that skipping the training_step in distributed training is not allowed."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            return None
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_steps=1,
+        barebones=True,
+    )
+    trainer.strategy.world_size = world_size  # mock world size without launching processes
+    error_context = (
+        pytest.raises(RuntimeError, match="Skipping the `training_step` .* is not supported")
+        if world_size > 1
+        else nullcontext()
+    )
+    with error_context:
+        trainer.fit(model)
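
A brief aside on the idiom the new test relies on (the toy function below is illustrative only, not from this repository): `contextlib.nullcontext` is a no-op context manager, so the same `with` block can expect an error in one parametrization and nothing in the other.

    from contextlib import nullcontext

    import pytest


    def might_fail(flag: bool) -> None:
        # illustrative helper: raises only when the flag is set
        if flag:
            raise RuntimeError("boom")


    @pytest.mark.parametrize("flag", [False, True])
    def test_conditional_context(flag):
        # expect the error only in the failing case; otherwise use a no-op context
        ctx = pytest.raises(RuntimeError, match="boom") if flag else nullcontext()
        with ctx:
            might_fail(flag)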

tests/tests_pytorch/models/test_hooks.py
Lines changed: 2 additions & 0 deletions

@@ -178,6 +178,8 @@ class TestModel(BoringModel):
     def training_step(self, batch, batch_idx):
         assert batch.samples.device == self.device
         assert isinstance(batch_idx, int)
+        # the actual training step is not needed for the assertions
+        return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
 
     def train_dataloader(self):
         return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)

tests/tests_pytorch/trainer/test_dataloaders.py
Lines changed: 5 additions & 1 deletion

@@ -641,6 +641,8 @@ def __init__(self):
 
     def training_step(self, batch, batch_idx):
         self.batches_seen.append(batch)
+        # the actual training step is not needed for the assertions below
+        return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
 
     def on_train_epoch_end(self):
         world_size = 2

@@ -810,8 +812,10 @@ def __init__(self):
         super().__init__()
         self.seen_samples = []
 
-    def training_step(self, batch):
+    def training_step(self, batch, batch_idx):
         self.seen_samples.extend(batch.tolist())
+        # the actual training step is not needed for the test
+        return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
 
     def on_train_end(self):
         seen_samples = self.all_gather(self.seen_samples)
