Commit 8c9cf00

awaelchli authored and lantiga committed
Fix num_nodes not set for FSDPStrategy (#17438)
(cherry picked from commit 2e5a7f9)
1 parent 8b11e99 commit 8c9cf00

3 files changed: 16 additions and 1 deletion

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))
 
+- Fixed an issue that caused `num_nodes` not to be set correctly for `FSDPStrategy` ([#17438](https://github.com/Lightning-AI/lightning/pull/17438))
 
 
 ## [2.0.1] - 2023-03-30

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 1 addition & 1 deletion
@@ -571,7 +571,7 @@ def _lazy_init_strategy(self) -> None:
             else:
                 self.strategy.parallel_devices = self._parallel_devices
         if hasattr(self.strategy, "num_nodes"):
-            self.strategy._num_nodes = self._num_nodes_flag
+            self.strategy.num_nodes = self._num_nodes_flag
         if hasattr(self.strategy, "_layer_sync"):
             self.strategy._layer_sync = self._layer_sync
         if hasattr(self.strategy, "set_world_ranks"):
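
The one-line change above swaps a write to the private `_num_nodes` name for a write to the public `num_nodes` name. A minimal sketch of why that distinction can matter follows. It assumes one strategy keeps `num_nodes` as a plain attribute and another exposes it as a property backed by `_num_nodes`. Both classes and both helper functions are hypothetical illustrations, not Lightning's actual `FSDPStrategy` or connector code.

```python
class PlainAttributeStrategy:
    """Hypothetical strategy that stores num_nodes as a plain public attribute."""

    def __init__(self) -> None:
        self.num_nodes = 1


class PropertyStrategy:
    """Hypothetical strategy that exposes num_nodes as a property over _num_nodes."""

    def __init__(self) -> None:
        self._num_nodes = 1

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, value: int) -> None:
        self._num_nodes = value


def configure_via_private_name(strategy, num_nodes_flag: int):
    """Mirrors the removed line: writes to the private name."""
    if hasattr(strategy, "num_nodes"):
        # On a plain-attribute strategy this creates an unrelated attribute
        # and leaves the public num_nodes untouched.
        strategy._num_nodes = num_nodes_flag
    return strategy


def configure_via_public_name(strategy, num_nodes_flag: int):
    """Mirrors the added line: writes to the public name."""
    if hasattr(strategy, "num_nodes"):
        # Works for both a plain attribute and a property with a setter.
        strategy.num_nodes = num_nodes_flag
    return strategy


if __name__ == "__main__":
    for cls in (PlainAttributeStrategy, PropertyStrategy):
        old = configure_via_private_name(cls(), 2).num_nodes
        new = configure_via_public_name(cls(), 2).num_nodes
        print(f"{cls.__name__}: private write -> {old}, public write -> {new}")
```

Under these assumptions, writing through the public name is the only variant that behaves the same for both kinds of strategy, which is consistent with the changelog entry for this commit.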

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,17 @@ def _mock_tpu_available(value):
971971
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
972972
assert connector.strategy._start_method == "fork"
973973
assert connector.strategy.launcher.is_interactive_compatible
974+
975+
976+
@pytest.mark.parametrize(
977+
"strategy",
978+
[
979+
"ddp",
980+
"ddp_spawn",
981+
pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
982+
pytest.param("fsdp", marks=RunIf(min_torch="1.12.0")),
983+
],
984+
)
985+
def test_connector_sets_num_nodes(strategy, cuda_count_2):
986+
trainer = Trainer(accelerator="cuda", strategy=strategy, devices=2, num_nodes=2)
987+
assert trainer.strategy.num_nodes == 2