3 files changed under tests/tests_pytorch/accelerators: +11 -3 lines.

CUDA accelerator tests:

@@ -80,4 +80,4 @@ def test_gpu_device_name():
 
 
 def test_gpu_device_name_no_gpu(cuda_count_0):
-    assert str(False) == CUDAAccelerator.device_name()
+    assert CUDAAccelerator.device_name() == ""
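The old assertion expected `device_name()` to return the literal string "False" (i.e. `str(False)`) when no GPU is visible; the new contract is an empty string. A minimal sketch of the behavior the updated test pins down, using the standard torch.cuda API; this is hypothetical and is not Lightning's actual `CUDAAccelerator.device_name` implementation:

```python
# Hypothetical sketch of the contract the updated test asserts.
import torch


def cuda_device_name(device_index: int = 0) -> str:
    """Return the GPU name, or an empty string when no CUDA device is visible."""
    if not torch.cuda.is_available():
        return ""  # new behavior: "" instead of the old str(False) / "False"
    return torch.cuda.get_device_name(device_index)
```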
MPS accelerator tests:

@@ -47,7 +47,7 @@ def test_mps_device_name():
 
 def test_mps_device_name_not_available():
     with mock.patch("torch.backends.mps.is_available", return_value=False):
-        assert MPSAccelerator.device_name() == "False"
+        assert MPSAccelerator.device_name() == ""
 
 
 def test_warning_if_mps_not_used(mps_count_1):
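The MPS test asserts the same empty-string fallback, forcing the unavailable branch by patching `torch.backends.mps.is_available`. The patching pattern is worth noting on its own: it makes hardware-dependent branches testable on any machine. A self-contained illustration, where `mps_device_name` is a hypothetical stand-in for `MPSAccelerator.device_name`:

```python
# Self-contained illustration of the patching pattern used in the test.
from unittest import mock

import torch


def mps_device_name() -> str:
    # Assumed contract, mirroring the test: empty string when MPS is unavailable.
    return "mps" if torch.backends.mps.is_available() else ""


with mock.patch("torch.backends.mps.is_available", return_value=False):
    assert mps_device_name() == ""  # deterministic even on Apple-silicon machines
```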
XLA accelerator tests:

@@ -304,7 +304,15 @@ def test_warning_if_tpus_not_used(tpu_available):
 
 @RunIf(tpu=True)
 def test_tpu_device_name():
-    assert XLAAccelerator.device_name() == "TPU"
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
+    import torch_xla.core.xla_env_vars as xenv
+
+    assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
 
 
 @pytest.mark.parametrize(
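The TPU test no longer hard-codes "TPU": it reads the expected name from the TPU environment itself via `tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]`, so the assertion tracks whatever hardware the test actually runs on. The import is gated on `_XLA_GREATER_EQUAL_2_1` because the `tpu` helper module moved from `torch_xla.experimental` to the private `torch_xla._internal` package in torch_xla 2.1. A sketch of how such a version gate can be built, assuming the `RequirementCache` helper from `lightning_utilities`; how Lightning defines the real flag is not shown in this diff:

```python
# Hypothetical reconstruction of the version-gated import; the real
# _XLA_GREATER_EQUAL_2_1 lives in lightning.fabric.accelerators.xla.
from lightning_utilities.core.imports import RequirementCache

_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")

if _XLA_GREATER_EQUAL_2_1:
    # torch_xla 2.1 moved the helper module into the private _internal package.
    from torch_xla._internal import tpu
else:
    from torch_xla.experimental import tpu
```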