@@ -24,10 +24,9 @@
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _using_pjrt
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.fsdp import _apply_filter
@@ -85,22 +84,23 @@ def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
+        checkpoint_io: Optional[XLACheckpointIO] = None,
         precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         sequential_save: bool = False,
         **kwargs: Any,
     ) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
             cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()
 
         self._auto_wrap_policy = auto_wrap_policy
@@ -122,16 +122,34 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def checkpoint_io(self) -> CheckpointIO:
-        if self._checkpoint_io is None:
-            self._checkpoint_io = XLACheckpointIO()
-        return self._checkpoint_io
+    @property  # type: ignore[override]
+    def checkpoint_io(self) -> XLACheckpointIO:
+        plugin = self._checkpoint_io
+        if plugin is not None:
+            assert isinstance(plugin, XLACheckpointIO)
+            return plugin
+        return XLACheckpointIO()
 
     @checkpoint_io.setter
-    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+        if io is not None and not isinstance(io, XLACheckpointIO):
+            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
 
+    @property  # type: ignore[override]
+    def precision(self) -> XLAPrecision:
+        plugin = self._precision
+        if plugin is not None:
+            assert isinstance(plugin, XLAPrecision)
+            return plugin
+        return XLAPrecision("32-true")
+
+    @precision.setter
+    def precision(self, precision: Optional[XLAPrecision]) -> None:
+        if precision is not None and not isinstance(precision, XLAPrecision):
+            raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}")
+        self._precision = precision
+
     @property
     def global_rank(self) -> int:
         return super().global_rank if self._launched else 0
@@ -227,21 +245,8 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         flattened parameters.
 
         """
-        if _TORCH_GREATER_EQUAL_2_0:
-            return optimizer
-
-        from torch_xla.distributed.fsdp.xla_flatten_params_wrapper import FlatParameter
-
-        num_groups = len(optimizer.param_groups)
-        if num_groups > 1:
-            raise ValueError(
-                "An optimizer used with an XLAFSDP model does not support multiple param groups."
-                f" Found {num_groups} parameter groups."
-            )
-
-        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
+        if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
             return optimizer
-
         raise ValueError(
             "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
             " after setting up the model."
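For context, here is a minimal usage sketch of the constraints this diff introduces: the strategy only accepts the XLA-specific `XLACheckpointIO`/`XLAPrecision` plugins, requires `torch_xla` to be installed, and expects the optimizer to be built from the wrapped model's parameters. The `Fabric` wiring, the `accelerator="tpu"` choice, and the toy `torch.nn.Linear` model are assumptions for illustration, not part of this change.

```python
# Hypothetical usage sketch; assumes a machine with torch_xla installed,
# otherwise XLAFSDPStrategy's constructor now raises ModuleNotFoundError.
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import XLACheckpointIO, XLAPrecision
from lightning.fabric.strategies import XLAFSDPStrategy

strategy = XLAFSDPStrategy(
    checkpoint_io=XLACheckpointIO(),    # any other CheckpointIO now raises a TypeError
    precision=XLAPrecision("32-true"),  # likewise restricted to XLAPrecision
    state_dict_type="sharded",
)
fabric = Fabric(accelerator="tpu", strategy=strategy)
fabric.launch()

model = fabric.setup_module(torch.nn.Linear(32, 2))  # wrap the model first ...
optimizer = torch.optim.Adam(model.parameters())     # ... then create the optimizer from the wrapped parameters
optimizer = fabric.setup_optimizers(optimizer)       # otherwise setup_optimizer raises the ValueError above
```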