14 | 14 | import os |
15 | 15 | from functools import partial |
16 | 16 | from unittest import mock |
17 | | -from unittest.mock import MagicMock, Mock |
| 17 | +from unittest.mock import ANY, MagicMock, Mock |
18 | 18 | |
19 | 19 | import pytest |
20 | 20 | import torch |
21 | 21 | from torch.utils.data import DataLoader |
22 | 22 | |
23 | | -from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, XLAAccelerator |
| 23 | +from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1, XLAAccelerator |
24 | 24 | from lightning.fabric.strategies import XLAStrategy |
25 | 25 | from lightning.fabric.strategies.launchers.xla import _XLALauncher |
26 | 26 | from lightning.fabric.utilities.distributed import ReduceOp |
@@ -52,19 +52,30 @@ def xla_launch(fn, strategy=None): |
52 | 52 | def broadcast_on_tpu_fn(strategy): |
53 | 53 | # test broadcasting a tensor |
54 | 54 | obj = torch.tensor(strategy.global_rank) |
| 55 | + assert obj.device.type == "cpu" |
55 | 56 | # In PjRT, the local rank and the global rank have no fixed relationship. |
56 | 57 | # The global rank may not even be contiguous on a host, because it depends on the 3D mesh structure formed by |
57 | 58 | # the TPUs across all hosts in a pod, so checking a different src is not reliable. |
58 | 59 | # https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/experimental/pjrt.py#L161-L163 |
59 | 60 | src = 0 |
60 | 61 | result = strategy.broadcast(obj, src) |
61 | 62 | assert result.item() == src |
62 | | - assert result.device.type == "xla" |
| 63 | + assert result.device.type == "cpu" # the original device is preserved |
63 | 64 | |
64 | 65 | # test broadcasting an arbitrary object |
65 | | - obj = ("ver_0.5", "logger_name", strategy.global_rank) |
66 | | - result = strategy.broadcast(obj, src=src) |
67 | | - assert result == ("ver_0.5", "logger_name", src) |
| 66 | + if _using_pjrt(): |
| 67 | + tensor = torch.tensor(strategy.global_rank, device=strategy.root_device, dtype=torch.bfloat16) |
| 68 | + obj = ("ver_0.5", "logger_name", strategy.global_rank, tensor) |
| 69 | + result = strategy.broadcast(obj, src=src) |
| 70 | + assert result == ("ver_0.5", "logger_name", src, ANY) |
| 71 | + assert result[3].device.type == "xla" # the original device is preserved |
| 72 | + assert result[3].dtype == torch.bfloat16 |
| 73 | + else: |
| 74 | + # XRT fails to unpickle tensors, segfaults with |
| 75 | + # RuntimeError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) |
| 76 | + obj = ("ver_0.5", "logger_name", strategy.global_rank) |
| 77 | + result = strategy.broadcast(obj, src=src) |
| 78 | + assert result == ("ver_0.5", "logger_name", src) |
68 | 79 | |
69 | 80 | |
70 | 81 | @RunIf(tpu=True) |
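
For context, a minimal per-rank sketch of the broadcast behavior asserted in the hunk above. It assumes a TPU environment and that, like the other per-rank functions in this file, it would be driven through the `xla_launch` helper named in the hunk header, which passes in a launched `XLAStrategy`; the function name here is illustrative and not part of the change.

import torch

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.strategies import XLAStrategy


def broadcast_keeps_original_device(strategy: XLAStrategy) -> None:
    # A plain CPU tensor comes back on the CPU, carrying rank 0's value.
    cpu_tensor = torch.tensor(strategy.global_rank)
    result = strategy.broadcast(cpu_tensor, src=0)
    assert result.item() == 0
    assert result.device.type == "cpu"

    if _using_pjrt():
        # Under PjRT, a tensor nested inside an arbitrary object keeps its XLA device
        # and bfloat16 dtype across the broadcast; under XRT this case is skipped
        # because unpickling the tensor segfaults.
        xla_tensor = torch.tensor(strategy.global_rank, device=strategy.root_device, dtype=torch.bfloat16)
        result = strategy.broadcast(("tag", xla_tensor), src=0)
        assert result[0] == "tag"
        assert result[1].device.type == "xla"
        assert result[1].dtype == torch.bfloat16

Like the per-rank functions in the diff, it would presumably be run as `xla_launch(broadcast_keeps_original_device)`.
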
@@ -134,7 +145,7 @@ def tpu_all_gather_fn(strategy): |
134 | 145 | tensor = torch.tensor(1.0, requires_grad=True) |
135 | 146 | result = strategy.all_gather(tensor, sync_grads=sync_grads) |
136 | 147 | summed = result.sum() |
137 | | - assert summed.device.type == "xla" |
| 148 | + assert summed.device.type == "cpu" # the original device is preserved |
138 | 149 | assert torch.equal(summed, torch.tensor(strategy.world_size, dtype=torch.float32)) |
139 | 150 | if not _XLA_GREATER_EQUAL_2_1: |
140 | 151 | summed.backward() |
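
Similarly, a minimal sketch of the all_gather behavior asserted in the second hunk, under the same assumptions (TPU environment, driven through `xla_launch`, illustrative function name). `sync_grads=True` is fixed here for illustration, whereas the surrounding test takes `sync_grads` from its enclosing scope.

import torch

from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.strategies import XLAStrategy


def all_gather_keeps_original_device(strategy: XLAStrategy) -> None:
    # The gathered result follows the CPU input back to the CPU and sums to world_size.
    tensor = torch.tensor(1.0, requires_grad=True)
    gathered = strategy.all_gather(tensor, sync_grads=True)
    summed = gathered.sum()
    assert summed.device.type == "cpu"
    assert torch.equal(summed, torch.tensor(strategy.world_size, dtype=torch.float32))
    if not _XLA_GREATER_EQUAL_2_1:
        # On XLA releases before 2.1, the original test also backpropagates through the gather.
        summed.backward()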