File tree Expand file tree Collapse file tree 1 file changed +19
-1
lines changed
tests/tests_pytorch/graveyard Expand file tree Collapse file tree 1 file changed +19
-1
lines changed Original file line number Diff line number Diff line change 55import torch
66
77
8+ # mimics `lightning_utilites.RequirementCache`
9+ class MockXLAAvailable :
10+ def __init__ (self , available : bool , pkg_name : str = "torch_xla" ):
11+ self .available = available
12+ self .pkg_name = pkg_name
13+
14+ def __bool__ (self ):
15+ return self .available
16+
17+ def __str__ (self ):
18+ if self .available :
19+ return f"Requirement '{ self .pkg_name } ' met"
20+ return f"Module not found: { self .pkg_name !r} . HINT: Try running `pip install -U { self .pkg_name } `"
21+
22+
823@pytest .mark .parametrize (
924 ("import_path" , "name" ),
1025 [
@@ -35,7 +50,10 @@ def test_graveyard_single_tpu(import_path, name):
3550 ("lightning.pytorch.plugins.precision.xlabf16" , "XLABf16PrecisionPlugin" ),
3651 ],
3752)
38- def test_graveyard_no_device (import_path , name ):
53+ def test_graveyard_no_device (import_path , name , monkeypatch ):
54+ monkeypatch .setattr ("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE" , MockXLAAvailable (False ))
55+ monkeypatch .setattr ("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE" , MockXLAAvailable (False ))
56+
3957 module = import_module (import_path )
4058 cls = getattr (module , name )
4159 with pytest .deprecated_call (match = "is deprecated" ), pytest .raises (ModuleNotFoundError , match = "torch_xla" ):
You can’t perform that action at this time.
0 commit comments