
Commit ebe167e

XLA ddp trainer
1 parent e68c226 commit ebe167e

File tree

  • src/lightning/pytorch/strategies

1 file changed: 30 additions, 168 deletions

src/lightning/pytorch/strategies/xla.py

Lines changed: 30 additions & 168 deletions
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
-import os
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
@@ -21,20 +19,16 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 from lightning.pytorch.plugins import XLAPrecision
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.strategies.launchers.xla import _XLALauncher
 from lightning.pytorch.strategies.strategy import TBroadcast
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 if TYPE_CHECKING:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -56,8 +50,6 @@ def __init__(
         sync_module_states: bool = True,
         **_: Any,
     ) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -66,9 +58,12 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
-        self.debug = debug
-        self._launched = False
-        self._sync_module_states = sync_module_states
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy
+
+        self.xla_strategy_impl = EnterpriseXLAStrategy(
+            outer_object=self, debug=debug, sync_module_states=sync_module_states
+        )
 
     @property
     @override
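
The hunk above is the core of the change: the strategy no longer keeps XLA-specific state (`debug`, `_launched`, `_sync_module_states`) on itself; it constructs an enterprise implementation object and forwards to it. A condensed sketch of that delegation pattern, using only names that appear in this diff (the `pytorch_lightning_enterprise` package and the `XLAStrategyTrainer` signature are taken from the added lines, not verified independently):

# Condensed sketch of the facade/delegation pattern this commit introduces; not the full class.
from typing import Any

import torch

from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.pytorch.strategies.ddp import DDPStrategy


class XLAStrategy(DDPStrategy):
    def __init__(self, debug: bool = False, sync_module_states: bool = True, **kwargs: Any) -> None:
        super().__init__(start_method="fork", **kwargs)
        _raise_enterprise_not_available()  # fail fast if the enterprise package is missing
        from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy

        # The enterprise object gets a back-reference to this strategy and owns all
        # XLA-specific behaviour from here on.
        self.xla_strategy_impl = EnterpriseXLAStrategy(
            outer_object=self, debug=debug, sync_module_states=sync_module_states
        )

    @property
    def root_device(self) -> torch.device:
        # Every public property and method below forwards in this same one-line style.
        return self.xla_strategy_impl.root_device
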
@@ -105,145 +100,64 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
     @property
     @override
     def root_device(self) -> torch.device:
-        if not self._launched:
-            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
-        import torch_xla.core.xla_model as xm
-
-        return xm.xla_device()
+        return self.xla_strategy_impl.root_device
 
     @property
     @override
     def global_rank(self) -> int:
-        return super().global_rank if self._launched else 0
+        return self.xla_strategy_impl.global_rank
 
     @property
     @override
     def local_rank(self) -> int:
-        return super().local_rank if self._launched else 0
+        return self.xla_strategy_impl.local_rank
 
     @property
     @override
     def node_rank(self) -> int:
-        return super().node_rank if self._launched else 0
+        return self.xla_strategy_impl.node_rank
 
     @property
     @override
     def world_size(self) -> int:
-        return super().world_size if self._launched else 1
+        return self.xla_strategy_impl.world_size
 
     @override
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.accelerator is not None
-        self.accelerator.setup(trainer)
-
-        if self.debug:
-            os.environ["PT_XLA_DEBUG"] = "1"
-
-        assert self.model is not None
-        self.precision_plugin.convert_module(self.model)
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        set_shared_parameters(self.model, shared_params)
-
-        self.model = self._setup_model(self.model)
-
-        if self._sync_module_states:
-            if _XLA_GREATER_EQUAL_2_1:
-                from torch_xla.core.xla_model import broadcast_master_param
-            else:
-                from torch_xla.experimental.pjrt import broadcast_master_param
-
-            broadcast_master_param(self.model)
-
-        if trainer.state.fn == TrainerFn.FITTING:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        if trainer.state.fn == TrainerFn.FITTING:
-            _optimizers_to_device(self.optimizers, self.root_device)
+        return self.xla_strategy_impl.setup(trainer=trainer)
 
     @override
     def _setup_model(self, model: Module) -> Module:  # type: ignore
-        return model
+        return self.xla_strategy_impl._setup_model(model=model)
 
     @property
     @override
     def distributed_sampler_kwargs(self) -> dict[str, int]:
-        return {"num_replicas": self.world_size, "rank": self.global_rank}
+        return self.xla_strategy_impl.distributed_sampler_kwargs
 
     @override
     def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
-        from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
-        if isinstance(dataloader, MpDeviceLoader):
-            # dataloader is already wrapped by MpDeviceLoader
-            return dataloader
-
-        dataloader = MpDeviceLoader(dataloader, self.root_device)
-        # Mimic interface to torch.utils.data.DataLoader
-        dataloader.dataset = dataloader._loader.dataset
-        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
-        return dataloader
+        return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
 
     @override
     def configure_ddp(self) -> None:
-        pass
+        return self.xla_strategy_impl.configure_ddp()
 
     @override
     def model_to_device(self) -> None:
-        assert self.model is not None
-        self.model = self.model.to(self.root_device)
+        return self.xla_strategy_impl.model_to_device()
 
     @override
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if not self._launched:
-            return
-
-        import torch_xla.core.xla_model as xm
-
-        if name is None:
-            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
-            name = ""
-        xm.rendezvous(name)
+        return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self._launched:
-            return obj
-
-        import torch_xla.core.xla_model as xm
-
-        is_tensor = isinstance(obj, Tensor)
-        if is_tensor:
-            if obj.dim() == 0:
-                obj = obj.unsqueeze(0)
-            original_device = obj.device
-            # XLA distributed requires that the data is on the XLA device
-            obj = obj.to(self.root_device)
-        else:
-            # support for arbitrary pickle-ables
-            buffer = io.BytesIO()
-            torch.save(obj, buffer)
-            obj = torch.tensor(  # type: ignore[assignment]
-                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
-            )
-
-        obj = [obj]
-        xm.collective_broadcast(obj, root_ordinal=src)
-        obj = obj[0]
-
-        if not is_tensor:
-            # this will preserve the dtype and device of any tensors
-            buffer = io.BytesIO(obj.cpu().byte().numpy())
-            obj = torch.load(buffer)
-        else:
-            obj = obj.to(original_device)
-
-        return obj
+        return self.xla_strategy_impl.broadcast(obj=obj, src=src)
 
     @override
     def reduce(
@@ -252,60 +166,27 @@ def reduce(
         group: Optional[Any] = None,
         reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Tensor:
-        if not isinstance(output, Tensor):
-            output = torch.tensor(output, device=self.root_device)
-
-        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if invalid_reduce_op or invalid_reduce_op_str:
-            raise ValueError(
-                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
-                f" {reduce_op}"
-            )
-
-        import torch_xla.core.xla_model as xm
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
+        return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op)
 
     @override
     def setup_environment(self) -> None:
-        self._launched = True
-        super().setup_environment()
+        return self.xla_strategy_impl.setup_environment()
 
     @override
     def setup_distributed(self) -> None:
-        assert self.parallel_devices is not None
-        if len(self.parallel_devices) == 1:
-            # spawning only 1 device with PjRT is not supported:
-            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
-            raise NotImplementedError(
-                "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
-                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
-            )
-        rank_zero_only.rank = self.global_rank
+        return self.xla_strategy_impl.setup_distributed()
 
     @override
     def set_world_ranks(self) -> None:
-        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
-        # processes (by the accelerator connector), we cannot run the code that would normally be here.
-        # instead it's done in `setup_distributed`
-        pass
+        return self.xla_strategy_impl.set_world_ranks()
 
     @override
     def save_checkpoint(
         self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
-        import torch_xla.core.xla_model as xm
-
-        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
-        xm.mark_step()
-        # save on global rank zero only
-        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+        return self.xla_strategy_impl.save_checkpoint(
+            checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
+        )
 
     @override
     def remove_checkpoint(self, filepath: _PATH) -> None:
@@ -315,8 +196,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
             filepath: Path to checkpoint
 
         """
-        if self.local_rank == 0:
-            self.checkpoint_io.remove_checkpoint(filepath)
+        return self.xla_strategy_impl.remove_checkpoint(filepath=filepath)
 
     @override
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -330,29 +210,11 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             A tensor of shape (world_size, ...)
 
         """
-        if not self._launched:
-            return tensor
-        if not isinstance(tensor, Tensor):
-            raise NotImplementedError(
-                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
-            )
-        if tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        original_device = tensor.device
-        tensor = tensor.to(self.root_device)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-        tensor = tensor.to(original_device)
-        return tensor
+        return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
 
     @override
     def teardown(self) -> None:
-        super().teardown()
-        self._launched = False  # after the Trainer finishes, we aren't inside the spawned region
-        os.environ.pop("PT_XLA_DEBUG", None)
+        return self.xla_strategy_impl.teardown()
 
     @classmethod
     @override

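For reference, the public entry point is unchanged by this refactor; a minimal usage sketch (assumes a TPU host with torch_xla installed and, after this commit, the enterprise package that provides XLAStrategyTrainer; the device count is illustrative):

# Minimal usage sketch; the strategy is still selected from the Trainer as before.
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import XLAStrategy

trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,  # illustrative; match the number of TPU cores on the host
    strategy=XLAStrategy(sync_module_states=True),
    max_epochs=1,
)
trainer.fit(BoringModel())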