 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.device_parser import _check_data_type
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 class XLAAccelerator(Accelerator):
@@ -31,38 +35,38 @@ class XLAAccelerator(Accelerator):
     """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if not _using_pjrt():
-            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
+        _raise_enterprise_not_available()
         super().__init__(*args, **kwargs)
 
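+        # Deferred import: the `pytorch_lightning_enterprise` package is only required once an
+        # XLA accelerator is actually instantiated, not at module import time.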
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
+
     @override
     def setup_device(self, device: torch.device) -> None:
-        pass
+        return self.accelerator_impl.setup_device(device)
 
     @override
     def teardown(self) -> None:
-        pass
+        return self.accelerator_impl.teardown()
 
     @staticmethod
     @override
     def parse_devices(devices: int | str | list[int]) -> int | list[int]:
         """Accelerator device parsing logic."""
-        return _parse_tpu_devices(devices)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.parse_devices(devices)
 
     @staticmethod
     @override
     def get_parallel_devices(devices: int | list[int]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = _parse_tpu_devices(devices)
-        if isinstance(devices, int):
-            return [torch.device("xla", i) for i in range(devices)]
-        # list of devices is not supported, just a specific index, fine to access [0]
-        return [torch.device("xla", devices[0])]
-        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
-        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
-        # it will be replaced with `xla_device` (also a `torch.device`, but with extra logic) in the strategy
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.get_parallel_devices(devices)
 
     @staticmethod
     @override
@@ -71,16 +75,10 @@ def get_parallel_devices(devices: int | list[int]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        if not _XLA_AVAILABLE:
-            return 0
-        if _XLA_GREATER_EQUAL_2_1:
-            from torch_xla._internal import tpu
-
-            return tpu.num_available_devices()
-        from torch_xla.experimental import tpu
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
 
-        device_count_on_version = {2: 8, 3: 8, 4: 4}
-        return device_count_on_version.get(tpu.version(), 8)
+        return EnterpriseXLAAccelerator.auto_device_count()
 
     @staticmethod
     @override
@@ -92,6 +90,9 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
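+        # If the module cannot be imported at all, surface the error as a warning and report XLA as unavailable.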
+        except ModuleNotFoundError as e:
+            warnings.warn(str(e))
+            return False
 
     @staticmethod
     @override
@@ -106,74 +107,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             cls,
             description=cls.__name__,
         )
-
-
-# PJRT support requires this minimum version
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
-
-
-def _using_pjrt() -> bool:
-    # `using_pjrt` is removed in torch_xla 2.5
-    if _XLA_GREATER_EQUAL_2_5:
-        from torch_xla import runtime as xr
-
-        return xr.device_type() is not None
-    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
-    if _XLA_GREATER_EQUAL_2_1:
-        from torch_xla import runtime as xr
-
-        return xr.using_pjrt()
-
-    from torch_xla.experimental import pjrt
-
-    return pjrt.using_pjrt()
-
-
-def _parse_tpu_devices(devices: int | str | list[int]) -> int | list[int]:
-    """Parses the TPU devices given in the format as accepted by the
-    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
-
-    Args:
-        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
-            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
-            A single element list of int or string can be used to indicate the specific TPU core to use.
-
-    Returns:
-        A list of tpu cores to be used.
-
-    """
-    _check_data_type(devices)
-    if isinstance(devices, str):
-        devices = _parse_tpu_devices_str(devices)
-    _check_tpu_devices_valid(devices)
-    return devices
-
-
-def _check_tpu_devices_valid(devices: object) -> None:
-    device_count = XLAAccelerator.auto_device_count()
-    if (
-        # support number of devices
-        isinstance(devices, int)
-        and devices in {1, device_count}
-        # support picking a specific device
-        or isinstance(devices, (list, tuple))
-        and len(devices) == 1
-        and 0 <= devices[0] <= device_count - 1
-    ):
-        return
-    raise ValueError(
-        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
-    )
-
-
-def _parse_tpu_devices_str(devices: str) -> int | list[int]:
-    devices = devices.strip()
-    try:
-        return int(devices)
-    except ValueError:
-        try:
-            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
-        except ValueError:
-            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
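
For reference, a minimal usage sketch of the device-selection contract the removed helpers implemented, assuming a host where auto_device_count() returns 8; the enterprise implementation that parse_devices now delegates to is expected to accept the same values:

    from lightning.fabric.accelerators.xla import XLAAccelerator

    XLAAccelerator.parse_devices(1)      # -> 1: run one core with multi-processing
    XLAAccelerator.parse_devices("8")    # -> 8: run all eight cores
    XLAAccelerator.parse_devices([3])    # -> [3]: pin a single specific core by index
    XLAAccelerator.parse_devices(5)      # raises ValueError: only 1, 8 or a one-element list [<0-7>] is accepted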