From e6f30560bf5be1f517bab6a5fb472b8ff781c2c7 Mon Sep 17 00:00:00 2001 From: dimzhead <44970988+dimzhead@users.noreply.github.com> Date: Wed, 27 Nov 2024 08:38:13 -0500 Subject: [PATCH 1/2] Update xla.py --- src/lightning/fabric/accelerators/xla.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index d438197329939..c8d6a4a538d31 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -117,10 +117,6 @@ def _using_pjrt() -> bool: return xr.using_pjrt() - from torch_xla.experimental import pjrt - - return pjrt.using_pjrt() - def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Parses the TPU devices given in the format as accepted by the From daa23ea1785a4a8be7c1df1da73ef7e25cd356b1 Mon Sep 17 00:00:00 2001 From: dimzhead <44970988+dimzhead@users.noreply.github.com> Date: Wed, 27 Nov 2024 08:49:27 -0500 Subject: [PATCH 2/2] Update xla.py --- src/lightning/fabric/accelerators/xla.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index c8d6a4a538d31..07c51ceb38286 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -107,7 +107,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No def _using_pjrt() -> bool: # `using_pjrt` is removed in torch_xla 2.5 - if _XLA_GREATER_EQUAL_2_5: + if True: from torch_xla import runtime as xr return xr.device_type() is not None @@ -117,6 +117,10 @@ def _using_pjrt() -> bool: return xr.using_pjrt() + from torch_xla.experimental import pjrt + + return pjrt.using_pjrt() + def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Parses the TPU devices given in the format as accepted by the