@@ -11,30 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
-import os
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 from torch import Tensor
 from torch.nn import Module
 from typing_extensions import override

 import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 from lightning.pytorch.plugins import XLAPrecision
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.strategies.launchers.xla import _XLALauncher
 from lightning.pytorch.strategies.strategy import TBroadcast
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from lightning.pytorch.utilities.rank_zero import rank_zero_only

 if TYPE_CHECKING:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -56,8 +50,6 @@ def __init__(
         sync_module_states: bool = True,
         **_: Any,
     ) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -66,9 +58,12 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
-        self.debug = debug
-        self._launched = False
-        self._sync_module_states = sync_module_states
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy
+
+        self.xla_strategy_impl = EnterpriseXLAStrategy(
+            outer_object=self, debug=debug, sync_module_states=sync_module_states
+        )

     @property
     @override
@@ -105,145 +100,64 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
     @property
     @override
     def root_device(self) -> torch.device:
-        if not self._launched:
-            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
-        import torch_xla.core.xla_model as xm
-
-        return xm.xla_device()
+        return self.xla_strategy_impl.root_device

     @property
     @override
     def global_rank(self) -> int:
-        return super().global_rank if self._launched else 0
+        return self.xla_strategy_impl.global_rank

     @property
     @override
     def local_rank(self) -> int:
-        return super().local_rank if self._launched else 0
+        return self.xla_strategy_impl.local_rank

     @property
     @override
     def node_rank(self) -> int:
-        return super().node_rank if self._launched else 0
+        return self.xla_strategy_impl.node_rank

     @property
     @override
     def world_size(self) -> int:
-        return super().world_size if self._launched else 1
+        return self.xla_strategy_impl.world_size

     @override
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)

     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.accelerator is not None
-        self.accelerator.setup(trainer)
-
-        if self.debug:
-            os.environ["PT_XLA_DEBUG"] = "1"
-
-        assert self.model is not None
-        self.precision_plugin.convert_module(self.model)
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        set_shared_parameters(self.model, shared_params)
-
-        self.model = self._setup_model(self.model)
-
-        if self._sync_module_states:
-            if _XLA_GREATER_EQUAL_2_1:
-                from torch_xla.core.xla_model import broadcast_master_param
-            else:
-                from torch_xla.experimental.pjrt import broadcast_master_param
-
-            broadcast_master_param(self.model)
-
-        if trainer.state.fn == TrainerFn.FITTING:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        if trainer.state.fn == TrainerFn.FITTING:
-            _optimizers_to_device(self.optimizers, self.root_device)
+        return self.xla_strategy_impl.setup(trainer=trainer)

     @override
     def _setup_model(self, model: Module) -> Module:  # type: ignore
-        return model
+        return self.xla_strategy_impl._setup_model(model=model)

     @property
     @override
     def distributed_sampler_kwargs(self) -> dict[str, int]:
-        return {"num_replicas": self.world_size, "rank": self.global_rank}
+        return self.xla_strategy_impl.distributed_sampler_kwargs

     @override
     def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
-        from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
-        if isinstance(dataloader, MpDeviceLoader):
-            # dataloader is already wrapped by MpDeviceLoader
-            return dataloader
-
-        dataloader = MpDeviceLoader(dataloader, self.root_device)
-        # Mimic interface to torch.utils.data.DataLoader
-        dataloader.dataset = dataloader._loader.dataset
-        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
-        return dataloader
+        return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)

     @override
     def configure_ddp(self) -> None:
-        pass
+        return self.xla_strategy_impl.configure_ddp()

     @override
     def model_to_device(self) -> None:
-        assert self.model is not None
-        self.model = self.model.to(self.root_device)
+        return self.xla_strategy_impl.model_to_device()

     @override
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if not self._launched:
-            return
-
-        import torch_xla.core.xla_model as xm
-
-        if name is None:
-            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
-            name = ""
-        xm.rendezvous(name)
+        return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)

     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self._launched:
-            return obj
-
-        import torch_xla.core.xla_model as xm
-
-        is_tensor = isinstance(obj, Tensor)
-        if is_tensor:
-            if obj.dim() == 0:
-                obj = obj.unsqueeze(0)
-            original_device = obj.device
-            # XLA distributed requires that the data is on the XLA device
-            obj = obj.to(self.root_device)
-        else:
-            # support for arbitrary pickle-ables
-            buffer = io.BytesIO()
-            torch.save(obj, buffer)
-            obj = torch.tensor(  # type: ignore[assignment]
-                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
-            )
-
-        obj = [obj]
-        xm.collective_broadcast(obj, root_ordinal=src)
-        obj = obj[0]
-
-        if not is_tensor:
-            # this will preserve the dtype and device of any tensors
-            buffer = io.BytesIO(obj.cpu().byte().numpy())
-            obj = torch.load(buffer)
-        else:
-            obj = obj.to(original_device)
-
-        return obj
+        return self.xla_strategy_impl.broadcast(obj=obj, src=src)

     @override
     def reduce(
@@ -252,60 +166,27 @@ def reduce(
         group: Optional[Any] = None,
         reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Tensor:
-        if not isinstance(output, Tensor):
-            output = torch.tensor(output, device=self.root_device)
-
-        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if invalid_reduce_op or invalid_reduce_op_str:
-            raise ValueError(
-                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
-                f" {reduce_op}"
-            )
-
-        import torch_xla.core.xla_model as xm
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
+        return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op)

     @override
     def setup_environment(self) -> None:
-        self._launched = True
-        super().setup_environment()
+        return self.xla_strategy_impl.setup_environment()

     @override
     def setup_distributed(self) -> None:
-        assert self.parallel_devices is not None
-        if len(self.parallel_devices) == 1:
-            # spawning only 1 device with PjRT is not supported:
-            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
-            raise NotImplementedError(
-                "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
-                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
-            )
-        rank_zero_only.rank = self.global_rank
+        return self.xla_strategy_impl.setup_distributed()

     @override
     def set_world_ranks(self) -> None:
-        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
-        # processes (by the accelerator connector), we cannot run the code that would normally be here.
-        # instead it's done in `setup_distributed`
-        pass
+        return self.xla_strategy_impl.set_world_ranks()

     @override
     def save_checkpoint(
         self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
-        import torch_xla.core.xla_model as xm
-
-        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
-        xm.mark_step()
-        # save on global rank zero only
-        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+        return self.xla_strategy_impl.save_checkpoint(
+            checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
+        )

     @override
     def remove_checkpoint(self, filepath: _PATH) -> None:
@@ -315,8 +196,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
             filepath: Path to checkpoint

         """
-        if self.local_rank == 0:
-            self.checkpoint_io.remove_checkpoint(filepath)
+        return self.xla_strategy_impl.remove_checkpoint(filepath=filepath)

     @override
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -330,29 +210,11 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
             A tensor of shape (world_size, ...)

         """
-        if not self._launched:
-            return tensor
-        if not isinstance(tensor, Tensor):
-            raise NotImplementedError(
-                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
-            )
-        if tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        original_device = tensor.device
-        tensor = tensor.to(self.root_device)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-        tensor = tensor.to(original_device)
-        return tensor
+        return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)

     @override
     def teardown(self) -> None:
-        super().teardown()
-        self._launched = False  # after the Trainer finishes, we aren't inside the spawned region
-        os.environ.pop("PT_XLA_DEBUG", None)
+        return self.xla_strategy_impl.teardown()

     @classmethod
     @override
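
For context, a minimal sketch (not part of this diff) of how the refactored strategy would be driven. It assumes the pytorch_lightning_enterprise package is installed; without it, _raise_enterprise_not_available() raises when the strategy is constructed.

import lightning.pytorch as pl
from lightning.pytorch.strategies import XLAStrategy

# After this change, constructing XLAStrategy builds the enterprise backend,
# and every hook (setup, broadcast, reduce, all_gather, checkpoint I/O)
# forwards to self.xla_strategy_impl instead of running torch_xla code inline.
strategy = XLAStrategy(sync_module_states=True, debug=False)
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy=strategy)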