
Commit 467b935

integrate xla strategies
1 parent c63d855 commit 467b935

File tree

3 files changed: +71, -572 lines changed


src/lightning/fabric/strategies/single_xla.py

Lines changed: 5 additions & 10 deletions
@@ -13,15 +13,14 @@
 # limitations under the License.
 from typing import Optional
 
-import torch
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.single_device import SingleDeviceStrategy
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _DEVICE
 
 
@@ -35,20 +34,16 @@ def __init__(
         checkpoint_io: Optional[XLACheckpointIO] = None,
         precision: Optional[XLAPrecision] = None,
     ):
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if isinstance(device, torch.device):
-            # unwrap the `torch.device` in favor of `xla_device`
-            device = device.index
-
-        import torch_xla.core.xla_model as xm
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.strategies.xla.single import validate_xla_strategy
 
         super().__init__(
             accelerator=accelerator,
-            device=xm.xla_device(device),
+            device=device,
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
+        validate_xla_strategy(strategy=self, device=device)
 
     @property
     @override
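The change above replaces the strategy's own XLA availability check and device resolution with two calls: a gate that fails fast when the enterprise package is missing, and a post-init validation hook. The sketch below illustrates that gate-and-validate shape. Only the names `_raise_enterprise_not_available` and `validate_xla_strategy` appear in the diff; their bodies here are assumptions for illustration, not the real package code.

```python
# Illustrative sketch only; the real implementations live in
# lightning.fabric.utilities.imports and pytorch_lightning_enterprise.
import importlib.util

import torch


def _raise_enterprise_not_available() -> None:
    # Fail fast with an actionable error when the optional enterprise
    # package is not installed (assumed behavior of the real helper).
    if importlib.util.find_spec("pytorch_lightning_enterprise") is None:
        raise ModuleNotFoundError(
            "This strategy requires the `pytorch_lightning_enterprise` package to be installed."
        )


def validate_xla_strategy(strategy, device) -> None:
    # Hypothetical stand-in: after the base SingleDeviceStrategy.__init__ has
    # stored the requested device, confirm it resolves to a real XLA device.
    import torch_xla.core.xla_model as xm  # real torch_xla API

    index = device.index if isinstance(device, torch.device) else device
    xm.xla_device(index)  # raises if the XLA device cannot be created
```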

src/lightning/fabric/strategies/xla.py

Lines changed: 22 additions & 130 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -22,14 +21,13 @@
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 
 if TYPE_CHECKING:
@@ -55,22 +53,19 @@ def __init__(
             checkpoint_io=checkpoint_io,
             precision=precision,
         )
-        self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
-        self._launched = False
-        self._sync_module_states = sync_module_states
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy
+
+        self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states)
 
     @property
     @override
     def root_device(self) -> torch.device:
-        if not self._launched:
-            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
-        import torch_xla.core.xla_model as xm
-
-        return xm.xla_device()
+        return self.xla_strategy_impl.root_device
 
     @property
     def num_processes(self) -> int:
-        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+        return self.xla_strategy_impl.num_processes
 
     @property
     @override
@@ -107,71 +102,42 @@ def precision(self, precision: Optional[Precision]) -> None:
     @property
     @override
     def global_rank(self) -> int:
-        return super().global_rank if self._launched else 0
+        return self.xla_strategy_impl.global_rank
 
     @property
     @override
     def local_rank(self) -> int:
-        return super().local_rank if self._launched else 0
+        return self.xla_strategy_impl.local_rank
 
     @property
     @override
     def node_rank(self) -> int:
-        return super().node_rank if self._launched else 0
+        return self.xla_strategy_impl.node_rank
 
     @property
     @override
     def world_size(self) -> int:
-        return super().world_size if self._launched else 1
+        return self.xla_strategy_impl.world_size
 
     @override
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
     @override
     def setup_environment(self) -> None:
-        assert self.parallel_devices is not None
-        if len(self.parallel_devices) == 1:
-            # spawning only 1 device with PjRT is not supported:
-            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
-            raise NotImplementedError(
-                f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
-                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
-            )
-
-        self._launched = True
-        rank_zero_only.rank = self.global_rank
-        super().setup_environment()
+        return self.xla_strategy_impl.setup_environment()
 
     @override
     def setup_module(self, module: Module) -> Module:
-        if self._sync_module_states:
-            if _XLA_GREATER_EQUAL_2_1:
-                from torch_xla.core.xla_model import broadcast_master_param
-            else:
-                from torch_xla.experimental.pjrt import broadcast_master_param
-
-            broadcast_master_param(module)
-
-        return module
+        return self.xla_strategy_impl.setup_module(module=module)
 
     @override
     def module_to_device(self, module: Module) -> None:
-        module.to(self.root_device)
+        return self.xla_strategy_impl.module_to_device(module=module)
 
     @override
     def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
-        from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
-        if isinstance(dataloader, MpDeviceLoader):
-            # dataloader is already wrapped by MpDeviceLoader
-            return dataloader
-
-        dataloader = MpDeviceLoader(dataloader, self.root_device)
-        # Mimic interface to torch.utils.data.DataLoader
-        dataloader.dataset = dataloader._loader.dataset
-        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
-        return dataloader
+        return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
 
     @override
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -185,92 +151,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             A tensor of shape (world_size, ...)
 
         """
-        if not self._launched:
-            return tensor
-        if not isinstance(tensor, Tensor):
-            raise NotImplementedError(
-                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
-            )
-        if tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        original_device = tensor.device
-        tensor = tensor.to(self.root_device)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-        tensor = tensor.to(original_device)
-        return tensor
+        return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
 
     @override
     def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
-        if not isinstance(output, Tensor):
-            output = torch.tensor(output, device=self.root_device)
-
-        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if invalid_reduce_op or invalid_reduce_op_str:
-            raise ValueError(
-                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
-                f" {reduce_op}"
-            )
-        import torch_xla.core.xla_model as xm
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
+        return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
 
     @override
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if not self._launched:
-            return
-        import torch_xla.core.xla_model as xm
-
-        if name is None:
-            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
-            name = ""
-        xm.rendezvous(name)
+        return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self._launched:
-            return obj
-
-        import torch_xla.core.xla_model as xm
-
-        is_tensor = isinstance(obj, Tensor)
-        if is_tensor:
-            if obj.dim() == 0:
-                obj = obj.unsqueeze(0)
-            original_device = obj.device
-            # XLA distributed requires that the data is on the XLA device
-            obj = obj.to(self.root_device)
-        else:
-            # support for arbitrary pickle-ables
-            buffer = io.BytesIO()
-            torch.save(obj, buffer)
-            obj = torch.tensor(  # type: ignore[assignment]
-                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
-            )
-
-        obj = [obj]
-        xm.collective_broadcast(obj, root_ordinal=src)
-        obj = obj[0]
-
-        if not is_tensor:
-            # this will preserve the dtype and device of any tensors
-            buffer = io.BytesIO(obj.cpu().byte().numpy())
-            obj = torch.load(buffer)
-        else:
-            obj = obj.to(original_device)
-
-        return obj
+        return self.xla_strategy_impl.broadcast(obj=obj, src=src)
 
     @override
     def save_checkpoint(
@@ -291,12 +186,9 @@ def save_checkpoint(
             boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
 
         """
-        import torch_xla.core.xla_model as xm
-
-        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
-        xm.mark_step()
-        # save on global rank zero only
-        super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
+        return self.xla_strategy_impl.save_checkpoint(
+            path=path, state=state, storage_options=storage_options, filter=filter
+        )
 
     @classmethod
     @override
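The xla.py diff above converts `XLAStrategy` from a class that talks to `torch_xla` directly into a thin forwarding layer over a `xla_strategy_impl` object created from the enterprise package. The sketch below shows that composition pattern in isolation. `EnterpriseXLAStrategy`, `outer_object`, `sync_module_states`, and `xla_strategy_impl` come from the diff; the stub implementation and reduced interface are assumptions so the example runs without `torch_xla` or the enterprise package installed.

```python
# Illustrative sketch of the delegation pattern only, not the real classes.
from typing import Any, Optional

import torch


class _StubXLAImpl:
    """Stand-in for the enterprise implementation object (assumed API)."""

    def __init__(self, outer_object: "OuterXLAStrategy", sync_module_states: bool = True) -> None:
        self._outer = outer_object
        self._sync_module_states = sync_module_states

    @property
    def root_device(self) -> torch.device:
        return torch.device("cpu")  # the real impl would return the XLA device

    @property
    def world_size(self) -> int:
        return 1  # the real impl would report the spawned world size

    def barrier(self, name: Optional[str] = None) -> None:
        pass  # the real impl would rendezvous across XLA processes


class OuterXLAStrategy:
    """Mirrors the refactored XLAStrategy: public API kept, behavior delegated."""

    def __init__(self, sync_module_states: bool = True) -> None:
        # Composition over inheritance, mirroring
        # `self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, ...)`.
        self.xla_strategy_impl = _StubXLAImpl(outer_object=self, sync_module_states=sync_module_states)

    @property
    def root_device(self) -> torch.device:
        return self.xla_strategy_impl.root_device

    @property
    def world_size(self) -> int:
        return self.xla_strategy_impl.world_size

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        return self.xla_strategy_impl.barrier(name=name)


strategy = OuterXLAStrategy()
print(strategy.root_device, strategy.world_size)  # -> cpu 1 with the stub impl
```

With this split, the open-source class keeps its registered name, constructor signature, and launcher wiring, while device handling, collectives, and checkpoint syncing live entirely in the injected implementation.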
