Skip to content

Commit 90c8ae0

Browse files
committed
update
1 parent ecb8e00 commit 90c8ae0

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

tests/tests_pytorch/graveyard/test_tpu.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@
55
import torch
66

77

8+
# mimics `lightning_utilites.RequirementCache`
9+
class MockXLAAvailable:
10+
def __init__(self, available: bool, pkg_name: str = "torch_xla"):
11+
self.available = available
12+
self.pkg_name = pkg_name
13+
14+
def __bool__(self):
15+
return self.available
16+
17+
def __str__(self):
18+
if self.available:
19+
return f"Requirement '{self.pkg_name}' met"
20+
return f"Module not found: {self.pkg_name!r}. HINT: Try running `pip install -U {self.pkg_name}`"
21+
22+
823
@pytest.mark.parametrize(
924
("import_path", "name"),
1025
[
@@ -35,7 +50,10 @@ def test_graveyard_single_tpu(import_path, name):
3550
("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"),
3651
],
3752
)
38-
def test_graveyard_no_device(import_path, name):
53+
def test_graveyard_no_device(import_path, name, monkeypatch):
54+
monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", MockXLAAvailable(False))
55+
monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", MockXLAAvailable(False))
56+
3957
module = import_module(import_path)
4058
cls = getattr(module, name)
4159
with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):

0 commit comments

Comments
 (0)