Commit dab4fb5
fabric tests
1 parent 644f354 commit dab4fb5

4 files changed: +36 −7 lines

src/lightning/fabric/strategies/launchers/xla.py

Lines changed: 10 additions & 2 deletions
@@ -42,14 +42,14 @@ class _XLALauncher(_Launcher):
     def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None:
         super().__init__()
         _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.strategies.launchers.xla import _XLALauncher as EnterpriseXLALauncher
+        from pytorch_lightning_enterprise.strategies.xla.launcher import XLALauncherFabric as EnterpriseXLALauncher
 
         self.xla_impl = EnterpriseXLALauncher(strategy=strategy)
 
     @property
     @override
     def is_interactive_compatible(self) -> bool:
-        return self.xla_impl.is_interactive_compatible()
+        return self.xla_impl.is_interactive_compatible
 
     @override
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
@@ -65,3 +65,11 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
 
         """
         return self.xla_impl.launch(function=function, *args, **kwargs)
+
+    @property
+    def _start_method(self) -> str:
+        return self.xla_impl._start_method
+
+    @_start_method.setter
+    def _start_method(self, start_method: str) -> None:
+        self.xla_impl._start_method = start_method
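Note on the new _start_method property above: it simply forwards reads and writes to the wrapped enterprise launcher, the same delegation style already used by launch and is_interactive_compatible. A minimal, self-contained sketch of that forwarding pattern (the _Impl and _Wrapper names are invented for the example, not the actual enterprise classes):

class _Impl:
    # stand-in for the enterprise launcher implementation
    def __init__(self) -> None:
        self._start_method = "fork"


class _Wrapper:
    # stand-in for the open-source launcher that wraps the implementation
    def __init__(self) -> None:
        self.impl = _Impl()

    @property
    def _start_method(self) -> str:
        # reads are answered by the wrapped object
        return self.impl._start_method

    @_start_method.setter
    def _start_method(self, start_method: str) -> None:
        # writes are forwarded, so callers can keep mutating the wrapper directly
        self.impl._start_method = start_method


w = _Wrapper()
w._start_method = "spawn"
assert w.impl._start_method == "spawn"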

tests/tests_fabric/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -153,6 +153,8 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
 
     # Then patch the _XLA_AVAILABLE flags in various modules
     monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
@@ -161,6 +163,12 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
+    # Patch in the modules where they're used after import
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
 
 
 @pytest.fixture
@@ -172,6 +180,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     mock_xla_available(monkeypatch, value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
    monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+    # Also mock the enterprise XLAAccelerator methods
+    import pytorch_lightning_enterprise.accelerators.xla
+
+    monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
 
 
 @pytest.fixture
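The fixture comment "Patch in the modules where they're used after import" is the key detail behind the string-target setattr calls above: "from module import FLAG" copies the value into the importing module's namespace, so patching only the defining module leaves that copy untouched. A small self-contained sketch of the behaviour (the defs/consumer module names are made up for the example):

import sys
import types
from unittest import mock

# "defs" defines the flag; "consumer" copies it at import time via "from defs import FLAG"
defs = types.ModuleType("defs")
defs.FLAG = False
sys.modules["defs"] = defs

consumer = types.ModuleType("consumer")
exec("from defs import FLAG\n\ndef is_enabled():\n    return FLAG\n", consumer.__dict__)
sys.modules["consumer"] = consumer

# Patching the defining module does not affect the copy the consumer already holds ...
with mock.patch("defs.FLAG", True):
    assert consumer.is_enabled() is False

# ... so the flag has to be patched where it is looked up, i.e. in the consumer module.
with mock.patch("consumer.FLAG", True):
    assert consumer.is_enabled() is True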

tests/tests_fabric/test_connector.py

Lines changed: 6 additions & 4 deletions
@@ -242,8 +242,10 @@ class TestStrategy(DDPStrategy):
         ),
     ],
 )
-@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"])
-@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
+@mock.patch(
+    "pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"]
+)
+@mock.patch("pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
 def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
     with mock.patch.dict(os.environ, env_vars, clear=True):
         connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
@@ -341,8 +343,8 @@ def test_cuda_accelerator_can_not_run_on_system(_):
 
 
 @pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU")
-@mock.patch("lightning.fabric.accelerators.xla._XLA_AVAILABLE", True)
-@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True)
 def test_tpu_accelerator_can_not_run_on_system(_):
     with pytest.raises(RuntimeError, match="XLAAccelerator` can not run on your system"):
         _Connector(accelerator="tpu", devices=8)

tests/tests_fabric/utilities/test_throughput.py

Lines changed: 7 additions & 1 deletion
@@ -45,7 +45,13 @@ def test_get_available_flops(xla_available):
     ):
         assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None
 
-    from torch_xla.experimental import tpu
+    # Import from the right module based on _XLA_GREATER_EQUAL_2_1
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
 
     assert isinstance(tpu, Mock)
