Skip to content

Commit 3bee819

Browse files
authored
Add strategy="auto" support on the 1.9.x branch (#16916)
* Fix auto support on the 1.9.x branch
* CHANGELOG
* CHANGELOG
* Fix CHANGELOG
1 parent 8e55ff7 commit 3bee819

File tree

7 files changed

+54
-7
lines changed

7 files changed

+54
-7
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [1.9.4] - 2023-02-28
8+
## [1.9.4] - 2023-03-01
99

1010
### Removed
1111

src/lightning_fabric/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [1.9.4] - 2023-02-28
8+
## [1.9.4] - 2023-03-01
9+
10+
### Added
11+
12+
- Added `Fabric(strategy="auto")` support ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))
913

1014
### Fixed
1115

src/lightning_fabric/connector.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()
146146

147147
# 4. Instantiate Strategy - Part 1
148-
if self._strategy_flag is None:
148+
if self._strategy_flag in (None, "auto"):
149149
self._strategy_flag = self._choose_strategy()
150150
# In specific cases, ignore user selection and fall back to a different strategy
151151
self._check_strategy_and_fallback()
@@ -184,7 +184,11 @@ def _check_config_and_set_final_flags(
184184
if strategy is not None:
185185
self._strategy_flag = strategy
186186

187-
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
187+
if (
188+
strategy not in (None, "auto")
189+
and strategy not in self._registered_strategies
190+
and not isinstance(strategy, Strategy)
191+
):
188192
raise ValueError(
189193
f"You selected an invalid strategy name: `strategy={strategy!r}`."
190194
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."

src/pytorch_lightning/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [1.9.4] - 2023-02-28
8+
## [1.9.4] - 2023-03-01
9+
10+
### Added
11+
12+
- Added `Trainer(strategy="auto")` support. It will choose DDP over DDP-spawn, contrary to `strategy=None` (default) ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))
913

1014
### Fixed
1115

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def __init__(
207207
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()
208208

209209
# 4. Instantiate Strategy - Part 1
210-
if self._strategy_flag is None:
210+
if self._strategy_flag in (None, "auto"):
211211
self._strategy_flag = self._choose_strategy()
212212
# In specific cases, ignore user selection and fall back to a different strategy
213213
self._check_strategy_and_fallback()
@@ -273,7 +273,11 @@ def _check_config_and_set_final_flags(
273273
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
274274
)
275275

276-
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
276+
if (
277+
strategy not in (None, "auto")
278+
and strategy not in self._registered_strategies
279+
and not isinstance(strategy, Strategy)
280+
):
277281
raise ValueError(
278282
f"You selected an invalid strategy name: `strategy={strategy!r}`."
279283
" It must be either a string or an instance of `pytorch_lightning.strategies.Strategy`."
@@ -639,6 +643,9 @@ def _choose_strategy(self) -> Union[Strategy, str]:
639643
if len(self._parallel_devices) > 1:
640644
if _IS_INTERACTIVE:
641645
return "ddp_fork"
646+
if self._strategy_flag == "auto":
647+
# None chooses "ddp_spawn" for backwards compatibility, auto chooses "ddp" for future compatibility
648+
return "ddp"
642649
return "ddp_spawn"
643650

644651
return DDPStrategy.strategy_name

tests/tests_fabric/test_connector.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,3 +893,19 @@ def get_defaults(cls):
893893
# defaults should match on the intersection of argument names
894894
for name, connector_default in connector_defaults.items():
895895
assert connector_default == fabric_defaults[name]
896+
897+
898+
@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
899+
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
900+
def test_connector_auto_selection(*_):
901+
connector = _Connector(accelerator="auto", strategy=None, devices="auto")
902+
assert isinstance(connector.accelerator, CUDAAccelerator)
903+
assert isinstance(connector.strategy, DDPStrategy)
904+
assert isinstance(connector.strategy.launcher, _SubprocessScriptLauncher)
905+
assert connector._devices_flag == [0, 1]
906+
907+
connector = _Connector(accelerator="auto", strategy="auto", devices="auto")
908+
assert isinstance(connector.accelerator, CUDAAccelerator)
909+
assert isinstance(connector.strategy, DDPStrategy)
910+
assert isinstance(connector.strategy.launcher, _SubprocessScriptLauncher)
911+
assert connector._devices_flag == [0, 1]

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,3 +892,15 @@ def get_defaults(cls):
892892
for name, connector_default in connector_defaults.items():
893893
name = lut.get(name, name)
894894
assert connector_default == trainer_defaults[name]
895+
896+
897+
def test_connector_auto_selection(cuda_count_2, mps_count_0):
898+
trainer = Trainer(accelerator="auto", strategy=None, devices="auto")
899+
assert isinstance(trainer.accelerator, CUDAAccelerator)
900+
assert isinstance(trainer.strategy, DDPSpawnStrategy)
901+
assert trainer.num_devices == 2
902+
903+
trainer = Trainer(accelerator="auto", strategy="auto", devices="auto")
904+
assert isinstance(trainer.accelerator, CUDAAccelerator)
905+
assert isinstance(trainer.strategy, DDPStrategy)
906+
assert trainer.num_devices == 2

0 commit comments

Comments
 (0)