3 files changed: +11 -3 lines changed

tests/tests_pytorch/accelerators

@@ -80,4 +80,4 @@ def test_gpu_device_name():
 
 
 def test_gpu_device_name_no_gpu(cuda_count_0):
-    assert str(False) == CUDAAccelerator.device_name()
+    assert CUDAAccelerator.device_name() == ""
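The assertion now expects an empty string rather than str(False) when no CUDA device is visible. For reference, a minimal sketch of a helper consistent with that expectation — this is not the PR's implementation (only the test is shown in this diff); it only uses the standard torch.cuda calls:

    import torch

    def cuda_device_name(index: int = 0) -> str:
        # Empty string when CUDA is unavailable, matching the updated test;
        # otherwise the human-readable device name, e.g. "NVIDIA A100-SXM4-40GB".
        if not torch.cuda.is_available():
            return ""
        return torch.cuda.get_device_name(index)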

@@ -47,7 +47,7 @@ def test_mps_device_name():
 
 def test_mps_device_name_not_available():
     with mock.patch("torch.backends.mps.is_available", return_value=False):
-        assert MPSAccelerator.device_name() == "False"
+        assert MPSAccelerator.device_name() == ""
 
 
 def test_warning_if_mps_not_used(mps_count_1):
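As with the CUDA test, the mocked-unavailable case now yields "" instead of the stringified boolean. A sketch of the behaviour this test implies; the name returned when MPS is available is not visible in this hunk, so the "MPS" label below is purely a placeholder assumption:

    import torch

    def mps_device_name() -> str:
        # Empty string when the MPS backend is unavailable, per the updated test.
        if not torch.backends.mps.is_available():
            return ""
        # Placeholder: the actual available-device label is not shown in this diff.
        return "MPS"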

@@ -304,7 +304,15 @@ def test_warning_if_tpus_not_used(tpu_available):
 
 @RunIf(tpu=True)
 def test_tpu_device_name():
-    assert XLAAccelerator.device_name() == "TPU"
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
+    import torch_xla.core.xla_env_vars as xenv
+
+    assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
 
 
 @pytest.mark.parametrize(
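The TPU test no longer hard-codes "TPU"; it reads the accelerator type from the XLA runtime metadata, guarding the import because torch_xla moved the tpu helper module from torch_xla.experimental to torch_xla._internal (hence the _XLA_GREATER_EQUAL_2_1 flag in the hunk above). A standalone sketch of that lookup, runnable only on a host with torch_xla and a reachable TPU runtime; the empty-string fallback is an assumption, not something shown in this PR:

    def tpu_device_name() -> str:
        try:
            from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1

            if _XLA_GREATER_EQUAL_2_1:
                from torch_xla._internal import tpu  # torch_xla >= 2.1
            else:
                from torch_xla.experimental import tpu  # older torch_xla
            import torch_xla.core.xla_env_vars as xenv

            # get_tpu_env() returns the TPU runtime metadata; ACCELERATOR_TYPE
            # holds the hardware label (e.g. a value such as "v4-8").
            return tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
        except Exception:
            # Assumed fallback when no TPU runtime is available.
            return ""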