Skip to content

Commit daa23ea

Browse files
authored
Update xla.py
1 parent e6f3056 commit daa23ea

File tree

1 file changed

+5
-1
lines changed
  • src/lightning/fabric/accelerators

1 file changed

+5
-1
lines changed

src/lightning/fabric/accelerators/xla.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
107107

108108
def _using_pjrt() -> bool:
109109
# `using_pjrt` is removed in torch_xla 2.5
110-
if _XLA_GREATER_EQUAL_2_5:
110+
if True:
111111
from torch_xla import runtime as xr
112112

113113
return xr.device_type() is not None
@@ -117,6 +117,10 @@ def _using_pjrt() -> bool:
117117

118118
return xr.using_pjrt()
119119

120+
from torch_xla.experimental import pjrt
121+
122+
return pjrt.using_pjrt()
123+
120124

121125
def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
122126
"""Parses the TPU devices given in the format as accepted by the

0 commit comments

Comments (0)