# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import io
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH, ReduceOp

if TYPE_CHECKING:
@@ -55,22 +53,19 @@ def __init__(
            checkpoint_io=checkpoint_io,
            precision=precision,
        )
-        self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
-        self._launched = False
-        self._sync_module_states = sync_module_states
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy
+
+        self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states)

    @property
    @override
    def root_device(self) -> torch.device:
-        if not self._launched:
-            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
-        import torch_xla.core.xla_model as xm
-
-        return xm.xla_device()
+        return self.xla_strategy_impl.root_device

    @property
    def num_processes(self) -> int:
-        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+        return self.xla_strategy_impl.num_processes

    @property
    @override
@@ -107,71 +102,42 @@ def precision(self, precision: Optional[Precision]) -> None:
    @property
    @override
    def global_rank(self) -> int:
-        return super().global_rank if self._launched else 0
+        return self.xla_strategy_impl.global_rank

    @property
    @override
    def local_rank(self) -> int:
-        return super().local_rank if self._launched else 0
+        return self.xla_strategy_impl.local_rank

    @property
    @override
    def node_rank(self) -> int:
-        return super().node_rank if self._launched else 0
+        return self.xla_strategy_impl.node_rank

    @property
    @override
    def world_size(self) -> int:
-        return super().world_size if self._launched else 1
+        return self.xla_strategy_impl.world_size

    @override
    def _configure_launcher(self) -> None:
        self._launcher = _XLALauncher(self)

    @override
    def setup_environment(self) -> None:
-        assert self.parallel_devices is not None
-        if len(self.parallel_devices) == 1:
-            # spawning only 1 device with PjRT is not supported:
-            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
-            raise NotImplementedError(
-                f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
-                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
-            )
-
-        self._launched = True
-        rank_zero_only.rank = self.global_rank
-        super().setup_environment()
+        return self.xla_strategy_impl.setup_environment()

    @override
    def setup_module(self, module: Module) -> Module:
-        if self._sync_module_states:
-            if _XLA_GREATER_EQUAL_2_1:
-                from torch_xla.core.xla_model import broadcast_master_param
-            else:
-                from torch_xla.experimental.pjrt import broadcast_master_param
-
-            broadcast_master_param(module)
-
-        return module
+        return self.xla_strategy_impl.setup_module(module=module)

    @override
    def module_to_device(self, module: Module) -> None:
-        module.to(self.root_device)
+        return self.xla_strategy_impl.module_to_device(module=module)

    @override
    def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
-        from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
-        if isinstance(dataloader, MpDeviceLoader):
-            # dataloader is already wrapped by MpDeviceLoader
-            return dataloader
-
-        dataloader = MpDeviceLoader(dataloader, self.root_device)
-        # Mimic interface to torch.utils.data.DataLoader
-        dataloader.dataset = dataloader._loader.dataset
-        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
-        return dataloader
+        return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)

    @override
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -185,92 +151,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
            A tensor of shape (world_size, ...)

        """
-        if not self._launched:
-            return tensor
-        if not isinstance(tensor, Tensor):
-            raise NotImplementedError(
-                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
-            )
-        if tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        original_device = tensor.device
-        tensor = tensor.to(self.root_device)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-        tensor = tensor.to(original_device)
-        return tensor
+        return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)

    @override
    def all_reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
-        if not isinstance(output, Tensor):
-            output = torch.tensor(output, device=self.root_device)
-
-        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if invalid_reduce_op or invalid_reduce_op_str:
-            raise ValueError(
-                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
-                f" {reduce_op}"
-            )
-        import torch_xla.core.xla_model as xm
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
+        return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)

    @override
    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if not self._launched:
-            return
-        import torch_xla.core.xla_model as xm
-
-        if name is None:
-            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
-            name = ""
-        xm.rendezvous(name)
+        return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self._launched:
-            return obj
-
-        import torch_xla.core.xla_model as xm
-
-        is_tensor = isinstance(obj, Tensor)
-        if is_tensor:
-            if obj.dim() == 0:
-                obj = obj.unsqueeze(0)
-            original_device = obj.device
-            # XLA distributed requires that the data is on the XLA device
-            obj = obj.to(self.root_device)
-        else:
-            # support for arbitrary pickle-ables
-            buffer = io.BytesIO()
-            torch.save(obj, buffer)
-            obj = torch.tensor(  # type: ignore[assignment]
-                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
-            )
-
-        obj = [obj]
-        xm.collective_broadcast(obj, root_ordinal=src)
-        obj = obj[0]
-
-        if not is_tensor:
-            # this will preserve the dtype and device of any tensors
-            buffer = io.BytesIO(obj.cpu().byte().numpy())
-            obj = torch.load(buffer)
-        else:
-            obj = obj.to(original_device)
-
-        return obj
+        return self.xla_strategy_impl.broadcast(obj=obj, src=src)

    @override
    def save_checkpoint(
@@ -291,12 +186,9 @@ def save_checkpoint(
                boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).

        """
-        import torch_xla.core.xla_model as xm
-
-        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
-        xm.mark_step()
-        # save on global rank zero only
-        super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
+        return self.xla_strategy_impl.save_checkpoint(
+            path=path, state=state, storage_options=storage_options, filter=filter
+        )

    @classmethod
    @override