 import warnings
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, Generator, List, Optional, Set, Union
+from typing import cast, Dict, Generator, List, Optional, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_info

 from lightning_fabric.accelerators.accelerator import Accelerator
-from lightning_fabric.utilities.imports import (
-    _TORCH_GREATER_EQUAL_1_12,
-    _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_2_0,
-)
+from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0


 class CUDAAccelerator(Accelerator):
@@ -161,11 +157,11 @@ def num_cuda_devices() -> int:
     Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
     if the platform allows it.
     """
-    if _TORCH_GREATER_EQUAL_1_13:
+    if _TORCH_GREATER_EQUAL_2_0:
         return torch.cuda.device_count()

     # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
-    # TODO: Remove once minimum supported PyTorch version is 1.13
+    # TODO: Remove once minimum supported PyTorch version is 2.0
     nvml_count = _device_count_nvml()
     return torch.cuda.device_count() if nvml_count < 0 else nvml_count

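For context, a minimal sketch of why the NVML-based count matters (assumptions: a Linux host with `fork`, and that this module is importable as `lightning_fabric.accelerators.cuda`). Counting devices through NVML leaves the parent process free of a CUDA context, so forked workers can still initialize CUDA:

import multiprocessing as mp

from lightning_fabric.accelerators.cuda import num_cuda_devices


def _worker(device_idx: int) -> None:
    import torch

    # Safe only because the parent never created a CUDA context.
    torch.cuda.set_device(device_idx)


if __name__ == "__main__":
    n = num_cuda_devices()  # NVML-backed on torch < 2.0, so no CUDA context is created here
    processes = [mp.get_context("fork").Process(target=_worker, args=(i,)) for i in range(n)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()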
@@ -180,63 +176,167 @@ def is_cuda_available() -> bool:
     return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0


-# TODO: Remove once minimum supported PyTorch version is 1.13
-def _parse_visible_devices() -> Set[int]:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _parse_visible_devices() -> Union[List[int], List[str]]:
+    """Parse the CUDA_VISIBLE_DEVICES environment variable."""
     var = os.getenv("CUDA_VISIBLE_DEVICES")
     if var is None:
-        return {x for x in range(64)}
+        return list(range(64))

     def _strtoul(s: str) -> int:
-        """Return -1 or integer sequence string starts with."""
-        if len(s) == 0:
+        """Return -1 or the positive integer the string starts with."""
+        if not s:
             return -1
         for idx, c in enumerate(s):
-            if not c.isdigit():
+            if not (c.isdigit() or (idx == 0 and c in "+-")):
                 break
         if idx + 1 == len(s):
             idx += 1
         return int(s[:idx]) if idx > 0 else -1

+    def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
+        rcs: List[str] = []
+        for elem in lst.split(","):
+            # A repeated id results in an empty list
+            if elem in rcs:
+                return cast(List[str], [])
+            # Anything that doesn't start with the prefix stops parsing
+            if not elem.startswith(prefix):
+                break
+            rcs.append(elem)
+        return rcs
+
+    if var.startswith("GPU-"):
+        return parse_list_with_prefix(var, "GPU-")
+    if var.startswith("MIG-"):
+        return parse_list_with_prefix(var, "MIG-")
     # CUDA_VISIBLE_DEVICES uses something like strtoul
     # which makes `1gpu2,2ampere` equivalent to `1,2`
-    rc: Set[int] = set()
+    rc: List[int] = []
     for elem in var.split(","):
-        rc.add(_strtoul(elem.strip()))
+        x = _strtoul(elem.strip())
+        # A repeated ordinal results in an empty list
+        if x in rc:
+            return cast(List[int], [])
+        # A negative value aborts the sequence
+        if x < 0:
+            break
+        rc.append(x)
     return rc


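The strtoul-like parsing above is easiest to see by example. A hedged sketch, assuming the module is importable as `lightning_fabric.accelerators.cuda`:

import os

from lightning_fabric.accelerators.cuda import _parse_visible_devices

os.environ["CUDA_VISIBLE_DEVICES"] = "1gpu2,2ampere"
assert _parse_visible_devices() == [1, 2]  # trailing junk after the digits is ignored

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,1"
assert _parse_visible_devices() == []  # a repeated ordinal empties the result

os.environ["CUDA_VISIBLE_DEVICES"] = "0,-1,2"
assert _parse_visible_devices() == [0]  # a negative value aborts the sequence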
-# TODO: Remove once minimum supported PyTorch version is 1.13
+# TODO: Remove once minimum supported PyTorch version is 2.0
 def _raw_device_count_nvml() -> int:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
-    from ctypes import c_int, CDLL
+    """Return number of devices as reported by NVML, or a negative value if NVML discovery/initialization failed."""
+    from ctypes import byref, c_int, CDLL

     nvml_h = CDLL("libnvidia-ml.so.1")
     rc = nvml_h.nvmlInit()
     if rc != 0:
         warnings.warn("Can't initialize NVML")
         return -1
-    dev_arr = (c_int * 1)(-1)
-    rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
     if rc != 0:
         warnings.warn("Can't get nvml device count")
         return -1
     del nvml_h
-    return dev_arr[0]
+    return dev_count.value
+

+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _raw_device_uuid_nvml() -> Optional[List[str]]:
+    """Return a list of device UUIDs as reported by NVML, or None if NVML discovery/initialization failed."""
+    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer

-# TODO: Remove once minimum supported PyTorch version is 1.13
+    nvml_h = CDLL("libnvidia-ml.so.1")
+    rc = nvml_h.nvmlInit()
+    if rc != 0:
+        warnings.warn("Can't initialize NVML")
+        return None
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
+    if rc != 0:
+        warnings.warn("Can't get nvml device count")
+        return None
+    uuids: List[str] = []
+    for idx in range(dev_count.value):
+        dev_id = c_void_p()
+        rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
+        if rc != 0:
+            warnings.warn("Can't get device handle")
+            return None
+        buf_len = 96
+        buf = create_string_buffer(buf_len)
+        rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
+        if rc != 0:
+            warnings.warn("Can't get device UUID")
+            return None
+        uuids.append(buf.raw.decode("ascii").strip("\0"))
+    del nvml_h
+    return uuids
+
+
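A short usage sketch for the helper above (same import-path assumption; actual UUIDs vary by machine): each entry is a full device UUID in the `GPU-` prefix format that CUDA_VISIBLE_DEVICES accepts, indexed by NVML ordinal.

from lightning_fabric.accelerators.cuda import _raw_device_uuid_nvml

uuids = _raw_device_uuid_nvml()
if uuids is None:
    print("NVML unavailable; callers fall back to regular CUDA device counting")
else:
    for ordinal, uuid in enumerate(uuids):
        print(ordinal, uuid)  # e.g. `0 GPU-...`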
+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
+    """Given a list of partial UUIDs and a list of known UUIDs, build a list of ordinals, excluding ambiguous
+    partial IDs."""
+
+    def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
+        best_match = -1
+        for idx, uuid in enumerate(uuids):
+            if not uuid.startswith(candidate):
+                continue
+            # Ambiguous candidate
+            if best_match != -1:
+                return -1
+            best_match = idx
+        return best_match
+
+    rc: List[int] = []
+    for candidate in candidates:
+        idx = uuid_to_ordinal(candidate, uuids)
+        # First invalid ordinal stops parsing
+        if idx < 0:
+            break
+        # Duplicates result in an empty list
+        if idx in rc:
+            return cast(List[int], [])
+        rc.append(idx)
+    return rc
+
+
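Illustrative behavior of the matching rules above (the UUIDs are made up; same import-path assumption):

from lightning_fabric.accelerators.cuda import _transform_uuid_to_ordinals

uuids = ["GPU-11111111-aaaa", "GPU-22222222-bbbb", "GPU-22223333-cccc"]

# An unambiguous prefix maps to the device's index in the NVML-reported list.
assert _transform_uuid_to_ordinals(["GPU-1111"], uuids) == [0]

# "GPU-2222" matches two devices, so it is ambiguous and parsing stops there.
assert _transform_uuid_to_ordinals(["GPU-1111", "GPU-2222"], uuids) == [0]

# A duplicate ordinal empties the result entirely.
assert _transform_uuid_to_ordinals(["GPU-1111", "GPU-11111111"], uuids) == []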
+# TODO: Remove once minimum supported PyTorch version is 2.0
 def _device_count_nvml() -> int:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+    """Return the number of devices as reported by NVML, taking CUDA_VISIBLE_DEVICES into account.
+
+    A negative value is returned if NVML discovery or initialization failed.
+    """
+    visible_devices = _parse_visible_devices()
+    if not visible_devices:
+        return 0
     try:
-        raw_cnt = _raw_device_count_nvml()
-        if raw_cnt <= 0:
-            return raw_cnt
-        return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
+        if type(visible_devices[0]) is str:
+            # Skip MIG parsing
+            if visible_devices[0].startswith("MIG-"):
+                return -1
+            uuids = _raw_device_uuid_nvml()
+            if uuids is None:
+                return -1
+            visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
+        else:
+            raw_cnt = _raw_device_count_nvml()
+            if raw_cnt <= 0:
+                return raw_cnt
+            # Trim the list at the first ordinal that exceeds the available device count
+            for idx, val in enumerate(visible_devices):
+                if cast(int, val) >= raw_cnt:
+                    return idx
     except OSError:
         return -1
     except AttributeError:
         return -1
+    return len(visible_devices)


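Taken together, a hedged end-to-end sketch (same import-path assumption; the MIG value is illustrative): on a host without NVML these calls return -1, and `num_cuda_devices` above then falls back to `torch.cuda.device_count()`.

import os

from lightning_fabric.accelerators.cuda import _device_count_nvml

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(_device_count_nvml())  # visible count capped by the raw NVML count, or -1 without NVML

os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-0ab1"
print(_device_count_nvml())  # -1: MIG parsing is deliberately skipped, so the caller falls back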
 def _check_cuda_matmul_precision(device: torch.device) -> None: