Commit 5a83f54

Minor strategy fixes [TPU] (#18774)

1 parent 4df6e13 commit 5a83f54

File tree: 17 files changed, +186 -184 lines changed

src/lightning/fabric/connector.py
Lines changed: 1 addition & 14 deletions

@@ -454,10 +454,9 @@ def _init_strategy(self) -> None:
         self.strategy = self._strategy_flag
 
     def _check_and_init_precision(self) -> Precision:
-        self._validate_precision_choice()
         if isinstance(self._precision_instance, Precision):
             return self._precision_instance
-        if isinstance(self.accelerator, XLAAccelerator):
+        if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
             return XLAPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore

@@ -492,18 +491,6 @@ def _check_and_init_precision(self) -> Precision:
 
         raise RuntimeError("No precision set")
 
-    def _validate_precision_choice(self) -> None:
-        """Validate the combination of choices for precision, and accelerator."""
-        if (
-            isinstance(self.accelerator, XLAAccelerator)
-            and self._precision_instance
-            and not isinstance(self._precision_instance, XLAPrecision)
-        ):
-            raise ValueError(
-                f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
-                f" found: {self._precision_instance}."
-            )
-
     def _lazy_init_strategy(self) -> None:
         """Lazily set missing attributes on the previously instantiated strategy."""
         self.strategy.accelerator = self.accelerator
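
In short, precision is now selected from the chosen strategy rather than from the accelerator, and the standalone `_validate_precision_choice` check is dropped (the XLA strategies validate their plugins in the property setters added below). A minimal sketch of the new dispatch rule, assuming the classes are imported from their public locations; `pick_precision` is a hypothetical helper, not the connector's actual API:

# Illustration only; instantiating XLAPrecision requires a torch_xla installation.
from lightning.fabric.plugins import Precision, XLAPrecision
from lightning.fabric.strategies import SingleDeviceXLAStrategy, XLAFSDPStrategy, XLAStrategy

def pick_precision(strategy: object, precision_input: str = "32-true") -> Precision:
    # The XLA precision plugin is keyed off the strategy type, not the accelerator type.
    if isinstance(strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
        return XLAPrecision(precision_input)
    # ... other strategy/precision combinations are resolved as before ...
    return Precision()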

src/lightning/fabric/strategies/fsdp.py
Lines changed: 7 additions & 17 deletions

@@ -305,25 +305,15 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         flattened parameters.
 
         """
-        if _TORCH_GREATER_EQUAL_2_0:
-            return optimizer
-
-        from torch.distributed.fsdp import FlatParameter
-
-        num_groups = len(optimizer.param_groups)
-        if num_groups > 1:
+        if self._fsdp_kwargs.get("use_orig_params"):
+            return super().setup_optimizer(optimizer)
+        if not _optimizer_has_flat_params(optimizer):
+            # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
-                "An optimizer used with an FSDP model does not support multiple param groups."
-                f" Found {num_groups} parameter groups."
+                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
+                " after setting up the model."
             )
-
-        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
-            return optimizer
-
-        raise ValueError(
-            "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
-            " after setting up the model."
-        )
+        return optimizer
 
     def module_to_device(self, module: Module) -> None:
         pass
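
The user-facing contract is unchanged by this rewrite: unless `use_orig_params=True`, the optimizer must be created from the already-wrapped module so that it references FSDP's flat parameters, which is what `_optimizer_has_flat_params` checks. A rough Fabric usage sketch, assuming a multi-GPU CUDA machine (the model and optimizer are placeholders):

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy())
fabric.launch()

model = fabric.setup_module(torch.nn.Linear(32, 2))  # wrap with FSDP first
optimizer = torch.optim.Adam(model.parameters())     # then create the optimizer
optimizer = fabric.setup_optimizers(optimizer)       # passes the flat-parameter check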

src/lightning/fabric/strategies/single_xla.py
Lines changed: 27 additions & 10 deletions

@@ -17,9 +17,8 @@
 
 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.single_device import SingleDeviceStrategy
 from lightning.fabric.utilities.types import _DEVICE

@@ -32,8 +31,8 @@ def __init__(
         self,
         device: _DEVICE,
         accelerator: Optional[Accelerator] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
+        precision: Optional[XLAPrecision] = None,
     ):
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))

@@ -50,16 +49,34 @@ def __init__(
             precision=precision,
         )
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: CheckpointIO) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @classmethod
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("single_xla", cls, description=cls.__name__)
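
With the narrowed plugin types, the strategy only accepts XLA-specific plugins: anything else is rejected by the new setters, and when nothing is passed the getters fall back to `XLACheckpointIO()` and `XLAPrecision("32-true")`. A small construction sketch, assuming a host with `torch_xla` installed (the device index and precision value are illustrative):

from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import SingleDeviceXLAStrategy

strategy = SingleDeviceXLAStrategy(
    device=0,  # XLA device index
    checkpoint_io=XLACheckpointIO(),
    precision=XLAPrecision("bf16-true"),
)
# Assigning a non-XLA plugin now fails fast:
# strategy.precision = SomeOtherPrecision()  # raises TypeError in the setter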

src/lightning/fabric/strategies/xla.py
Lines changed: 27 additions & 11 deletions

@@ -22,10 +22,9 @@
 
 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, _using_pjrt
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast

@@ -44,8 +43,8 @@ def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
+        precision: Optional[XLAPrecision] = None,
         sync_module_states: bool = True,
     ) -> None:
         super().__init__(

@@ -55,7 +54,6 @@ def __init__(
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
         self._sync_module_states = sync_module_states

@@ -72,16 +70,34 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: CheckpointIO) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @property
     def global_rank(self) -> int:
         return super().global_rank if self._launched else 0
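
`XLAStrategy` gets the same getter/setter pattern: when no plugin was set, the getters return a fresh default rather than caching one, and `None` is an accepted setter value that restores the default. A behavioral sketch, assuming `torch_xla` is available:

from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import XLAStrategy

strategy = XLAStrategy()
assert isinstance(strategy.checkpoint_io, XLACheckpointIO)  # default instance
assert isinstance(strategy.precision, XLAPrecision)         # defaults to XLAPrecision("32-true")

strategy.precision = XLAPrecision("bf16-true")  # accepted
strategy.precision = None                       # accepted, getter falls back to the default again
# strategy.precision = SomeOtherPrecision()     # would raise TypeError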

src/lightning/fabric/strategies/xla_fsdp.py
Lines changed: 29 additions & 24 deletions

@@ -24,10 +24,9 @@
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _using_pjrt
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.fsdp import _apply_filter

@@ -85,22 +84,23 @@ def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
         precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         sequential_save: bool = False,
         **kwargs: Any,
     ) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
             cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()
 
         self._auto_wrap_policy = auto_wrap_policy

@@ -122,16 +122,34 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @property
     def global_rank(self) -> int:
         return super().global_rank if self._launched else 0

@@ -227,21 +245,8 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         flattened parameters.
 
         """
-        if _TORCH_GREATER_EQUAL_2_0:
-            return optimizer
-
-        from torch_xla.distributed.fsdp.xla_flatten_params_wrapper import FlatParameter
-
-        num_groups = len(optimizer.param_groups)
-        if num_groups > 1:
-            raise ValueError(
-                "An optimizer used with an XLAFSDP model does not support multiple param groups."
-                f" Found {num_groups} parameter groups."
-            )
-
-        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
+        if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
             return optimizer
-
         raise ValueError(
             "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
            " after setting up the model."

src/lightning/pytorch/strategies/fsdp.py
Lines changed: 1 addition & 10 deletions

@@ -327,16 +327,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-
-        invalid_params_error = False
-        try:
-            super().setup_optimizers(trainer)
-        except ValueError as ex:
-            if "optimizer got an empty parameter list" not in str(ex):
-                raise
-            invalid_params_error = True
-
-        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"

src/lightning/pytorch/strategies/launchers/xla.py
Lines changed: 6 additions & 5 deletions

@@ -13,11 +13,10 @@
 # limitations under the License.
 import os
 import queue
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
 
-import lightning.pytorch as pl
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.strategies.launchers.xla import _rank_teardown
 from lightning.fabric.utilities import move_data_to_device

@@ -29,6 +28,9 @@
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.rank_zero import rank_zero_debug
 
+if TYPE_CHECKING:
+    import lightning.pytorch as pl
+
 
 class _XLALauncher(_MultiProcessingLauncher):
     r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the

@@ -145,12 +147,11 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
             else None
         )
 
-        # requires to compute the state_dict on all processes in case Metrics are present
-        state_dict = trainer.lightning_module.state_dict()
-
         # save the last weights
         weights_path = None
         if trainer.state.fn == TrainerFn.FITTING:
+            # requires to compute the state_dict on all processes in case Metrics are present
+            state_dict = self._strategy.lightning_module_state_dict()
             weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
             self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
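
Two small behavior changes here: `lightning.pytorch` becomes a typing-only import in the launcher, and the weights `state_dict` is now produced through the strategy and only in the fitting phase. A minimal sketch of the typing-only import pattern used above, with a hypothetical function just for illustration:

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # seen by type checkers only, no runtime import
    import lightning.pytorch as pl

def describe(trainer: "pl.Trainer") -> str:  # hypothetical helper
    return type(trainer).__name__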
