
Commit f124d55 (parent 8e91f8f)

fix: fix unittests.

3 files changed: +11, -3 lines

tests/tests_pytorch/accelerators/test_gpu.py
Lines changed: 1 addition & 1 deletion

@@ -80,4 +80,4 @@ def test_gpu_device_name():


 def test_gpu_device_name_no_gpu(cuda_count_0):
-    assert str(False) == CUDAAccelerator.device_name()
+    assert CUDAAccelerator.device_name() == ""
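For context, a minimal sketch (an assumption, not the actual Lightning implementation) of the contract the updated assertion implies: device_name() returns the CUDA device's name when a GPU is present and an empty string when none is, rather than the stringified boolean "False" the old test expected.

    import torch

    def device_name() -> str:
        # Assumed contract from the test: "" when no CUDA device is visible,
        # otherwise the name of the current device.
        if not torch.cuda.is_available():
            return ""
        return torch.cuda.get_device_name(torch.cuda.current_device())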

tests/tests_pytorch/accelerators/test_mps.py
Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def test_mps_device_name():


 def test_mps_device_name_not_available():
     with mock.patch("torch.backends.mps.is_available", return_value=False):
-        assert MPSAccelerator.device_name() == "False"
+        assert MPSAccelerator.device_name() == ""


 def test_warning_if_mps_not_used(mps_count_1):
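The MPS change mirrors the CUDA one. A minimal sketch of the asserted behavior, assuming device_name() consults torch.backends.mps; the non-empty value returned when MPS is available is a hypothetical placeholder, not taken from the library source:

    import torch

    def device_name() -> str:
        # Assumed contract from the test: "" when MPS is unavailable.
        if not torch.backends.mps.is_available():
            return ""
        return "MPS"  # hypothetical placeholder for the available case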

tests/tests_pytorch/accelerators/test_xla.py
Lines changed: 9 additions & 1 deletion

@@ -304,7 +304,15 @@ def test_warning_if_tpus_not_used(tpu_available):

 @RunIf(tpu=True)
 def test_tpu_device_name():
-    assert XLAAccelerator.device_name() == "TPU"
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
+    import torch_xla.core.xla_env_vars as xenv
+
+    assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]


 @pytest.mark.parametrize(
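The XLA test no longer hardcodes "TPU": it resolves the expected name from the TPU environment, importing the tpu module from torch_xla._internal on torch_xla >= 2.1 and from torch_xla.experimental on older versions, since the module moved between releases. A standalone sketch of the same lookup, assuming it runs on an actual TPU host so that tpu.get_tpu_env() is populated:

    # Version-gated import, mirroring the test's _XLA_GREATER_EQUAL_2_1 branch
    # with a plain try/except instead of Lightning's version flag.
    try:
        from torch_xla._internal import tpu  # torch_xla >= 2.1
    except ImportError:
        from torch_xla.experimental import tpu  # older torch_xla
    import torch_xla.core.xla_env_vars as xenv

    # The accelerator type advertised by the TPU environment; this is the
    # value the updated assertion compares device_name() against.
    expected = tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
    print(expected)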
