Commit fcb8e17
[TPU] Preserve the device with XLA's collectives (#18275)
1 parent 683faaa

File tree: 6 files changed (+55, −22)

src/lightning/fabric/CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -138,6 +138,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - DataLoader re-instantiation is now only performed when a distributed sampler is required ([#18191](https://github.com/Lightning-AI/lightning/pull/18191))
 
 
+- Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device ([#18275](https://github.com/Lightning-AI/lightning/pull/18275))
+
 ### Deprecated
 
 - Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))

src/lightning/fabric/strategies/xla.py
Lines changed: 11 additions & 5 deletions

@@ -161,13 +161,15 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             )
         if tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
-        if tensor.device.type != "xla":
-            tensor = tensor.to(self.root_device)
+        original_device = tensor.device
+        tensor = tensor.to(self.root_device)
 
         import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = tensor.to(original_device)
+        return tensor
 
     def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
@@ -211,8 +213,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if is_tensor:
             if obj.dim() == 0:
                 obj = obj.unsqueeze(0)
-            if obj.device.type != "xla":
-                obj = obj.to(self.root_device)
+            original_device = obj.device
+            # XLA distributed requires that the data is on the XLA device
+            obj = obj.to(self.root_device)
         else:
             # support for arbitrary pickle-ables
             buffer = io.BytesIO()
@@ -226,8 +229,11 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = obj[0]
 
         if not is_tensor:
+            # this will preserve the dtype and device of any tensors
            buffer = io.BytesIO(obj.cpu().byte().numpy())
            obj = torch.load(buffer)
+        else:
+            obj = obj.to(original_device)
 
         return obj

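Both collectives now follow the same pattern: record the input's device, run the collective on the XLA device, and move the result back. Below is a minimal, runnable sketch of that round trip, with a hypothetical `fake_collective` standing in for `xm.all_gather`/`xf.all_gather` (the differentiable variant selected when `sync_grads=True`) and the CPU standing in for a real XLA device:

```python
import torch


def fake_collective(tensor: torch.Tensor) -> torch.Tensor:
    # hypothetical single-process stand-in for xm.all_gather: "gathers" a world of size 1
    return tensor.unsqueeze(0)


def device_preserving_collective(tensor: torch.Tensor, root_device: torch.device) -> torch.Tensor:
    original_device = tensor.device    # remember where the input lives
    tensor = tensor.to(root_device)    # XLA's collectives require XLA tensors
    tensor = fake_collective(tensor)   # stand-in for the real collective
    return tensor.to(original_device)  # hand the result back on the caller's device


out = device_preserving_collective(torch.tensor([1.0]), torch.device("cpu"))
assert out.device.type == "cpu"  # matches the input's device, as in the diff above
```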
src/lightning/fabric/strategies/xla_fsdp.py
Lines changed: 11 additions & 5 deletions

@@ -269,13 +269,15 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             )
         if tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
-        if tensor.device.type != "xla":
-            tensor = tensor.to(self.root_device)
+        original_device = tensor.device
+        tensor = tensor.to(self.root_device)
 
         import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = tensor.to(original_device)
+        return tensor
 
     def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
@@ -319,8 +321,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if is_tensor:
             if obj.dim() == 0:
                 obj = obj.unsqueeze(0)
-            if obj.device.type != "xla":
-                obj = obj.to(self.root_device)
+            original_device = obj.device
+            # XLA distributed requires that the data is on the XLA device
+            obj = obj.to(self.root_device)
         else:
             # support for arbitrary pickle-ables
             buffer = io.BytesIO()
@@ -334,8 +337,11 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = obj[0]
 
         if not is_tensor:
+            # this will preserve the dtype and device of any tensors
            buffer = io.BytesIO(obj.cpu().byte().numpy())
            obj = torch.load(buffer)
+        else:
+            obj = obj.to(original_device)
 
         return obj

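The `xla_fsdp.py` changes are identical to those in `xla.py`. For the non-tensor branch, the new comment refers to `torch.load` restoring the dtype and device of any tensors embedded in the pickled object. A standalone sketch of that byte-tensor round trip, with the actual broadcast across ranks elided and `payload` chosen purely for illustration:

```python
import io

import torch

# an arbitrary picklable object containing a tensor with a non-default dtype
payload = ("ver_0.5", "logger_name", torch.tensor(3.0, dtype=torch.bfloat16))

# source rank: serialize to a uint8 tensor that a tensor collective can broadcast
buffer = io.BytesIO()
torch.save(payload, buffer)
byte_tensor = torch.tensor(bytearray(buffer.getvalue()), dtype=torch.uint8)

# receiving rank: decode the byte tensor; torch.load restores the embedded
# tensor's dtype (and device) exactly as they were on the source rank
restored = torch.load(io.BytesIO(byte_tensor.cpu().byte().numpy()))
assert restored[2].dtype == torch.bfloat16
```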
src/lightning/pytorch/CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The input tensors now get cast to the right precision type before transfer to the device ([#18264](https://github.com/Lightning-AI/lightning/pull/18264))
 
 
+- Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device ([#18275](https://github.com/Lightning-AI/lightning/pull/18275))
+
 ### Deprecated
 
 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))

src/lightning/pytorch/strategies/xla.py
Lines changed: 11 additions & 5 deletions

@@ -186,8 +186,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if is_tensor:
             if obj.dim() == 0:
                 obj = obj.unsqueeze(0)
-            if obj.device.type != "xla":
-                obj = obj.to(self.root_device)
+            original_device = obj.device
+            # XLA distributed requires that the data is on the XLA device
+            obj = obj.to(self.root_device)
         else:
             # support for arbitrary pickle-ables
             buffer = io.BytesIO()
@@ -201,8 +202,11 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = obj[0]
 
         if not is_tensor:
+            # this will preserve the dtype and device of any tensors
            buffer = io.BytesIO(obj.cpu().byte().numpy())
            obj = torch.load(buffer)
+        else:
+            obj = obj.to(original_device)
 
         return obj
 
@@ -290,13 +294,15 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             )
         if tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
-        if tensor.device.type != "xla":
-            tensor = tensor.to(self.root_device)
+        original_device = tensor.device
+        tensor = tensor.to(self.root_device)
 
         import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = tensor.to(original_device)
+        return tensor
 
     def teardown(self) -> None:
         super().teardown()

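From the user's perspective, the visible effect of the `lightning.pytorch` change is that collectives called from hooks no longer move results onto the XLA device as a side effect. A hedged illustration of the intended behavior on a TPU runtime; the module and hook body are assumptions for illustration, only `self.all_gather` is standard `LightningModule` API:

```python
import torch
from lightning.pytorch import LightningModule


class MyModule(LightningModule):  # hypothetical user module
    def on_train_epoch_end(self) -> None:
        metric = torch.tensor(1.0)          # created on the CPU
        gathered = self.all_gather(metric)  # collective still runs on the XLA device
        # after this change, the result comes back on the input's device
        assert gathered.device.type == "cpu"
```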
tests/tests_fabric/strategies/test_xla.py
Lines changed: 18 additions & 7 deletions

@@ -14,13 +14,13 @@
 import os
 from functools import partial
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import ANY, MagicMock, Mock
 
 import pytest
 import torch
 from torch.utils.data import DataLoader
 
-from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, XLAAccelerator
+from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1, XLAAccelerator
 from lightning.fabric.strategies import XLAStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.utilities.distributed import ReduceOp
@@ -52,19 +52,30 @@ def xla_launch(fn, strategy=None):
 def broadcast_on_tpu_fn(strategy):
     # test broadcasting a tensor
     obj = torch.tensor(strategy.global_rank)
+    assert obj.device.type == "cpu"
     # In PjRT, the local rank and global rank have no solid relation.
     # global rank may not even be contiguous on a host, because it depends on the 3D mesh structure that is formed by
     # the TPUs on all hosts in a pod. So checking a different src is not reliable
     # https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/experimental/pjrt.py#L161-L163
     src = 0
     result = strategy.broadcast(obj, src)
     assert result.item() == src
-    assert result.device.type == "xla"
+    assert result.device.type == "cpu"  # the original device is preserved
 
     # test broadcasting an arbitrary object
-    obj = ("ver_0.5", "logger_name", strategy.global_rank)
-    result = strategy.broadcast(obj, src=src)
-    assert result == ("ver_0.5", "logger_name", src)
+    if _using_pjrt():
+        tensor = torch.tensor(strategy.global_rank, device=strategy.root_device, dtype=torch.bfloat16)
+        obj = ("ver_0.5", "logger_name", strategy.global_rank, tensor)
+        result = strategy.broadcast(obj, src=src)
+        assert result == ("ver_0.5", "logger_name", src, ANY)
+        assert result[3].device.type == "xla"  # the original device is preserved
+        assert result[3].dtype == torch.bfloat16
+    else:
+        # XRT fails to unpickle tensors, segfaults with
+        # RuntimeError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)
+        obj = ("ver_0.5", "logger_name", strategy.global_rank)
+        result = strategy.broadcast(obj, src=src)
+        assert result == ("ver_0.5", "logger_name", src)
 
 
 @RunIf(tpu=True)
@@ -134,7 +145,7 @@ def tpu_all_gather_fn(strategy):
     tensor = torch.tensor(1.0, requires_grad=True)
     result = strategy.all_gather(tensor, sync_grads=sync_grads)
     summed = result.sum()
-    assert summed.device.type == "xla"
+    assert summed.device.type == "cpu"  # the original device is preserved
     assert torch.equal(summed, torch.tensor(strategy.world_size, dtype=torch.float32))
     if not _XLA_GREATER_EQUAL_2_1:
         summed.backward()

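The updated test leans on `unittest.mock.ANY`, which compares equal to any object, so the tuple assertion can match the tensor element positionally before its dtype and device are checked explicitly. A self-contained illustration of that mechanic, with a CPU tensor standing in for the XLA one:

```python
from unittest.mock import ANY

import torch

result = ("ver_0.5", "logger_name", 0, torch.tensor(0, dtype=torch.bfloat16))
# mock.ANY compares equal to anything, so the tensor slot matches positionally
assert result == ("ver_0.5", "logger_name", 0, ANY)
# the tensor's properties are then asserted separately
assert result[3].dtype == torch.bfloat16
```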