11"""Computational configuration for 2DTM."""
22
3- from typing import Annotated
3+ from typing import Annotated , Optional , Union
44
55import torch
6- from pydantic import BaseModel , Field , field_validator
6+ from pydantic import BaseModel , Field
7+
8+ # Type alias for non-negative integer
9+ NonNegativeInt = Annotated [int , Field (ge = 0 )]
710
811
912class ComputationalConfig (BaseModel ):
1013 """Serialization of computational resources allocated for 2DTM.
1114
15+ NOTE: The field `gpu_ids` is not validated at instantiation past being one of the
16+ valid types. For example, if "cuda:0" is specified but no CUDA device is available,
17+ the instantiation will succeed, and only upon translating `gpu_ids` to a list of
18+ `torch.device` objects will an error be raised. This is done to allow for
19+ configuration files to be loaded without requiring the actual hardware to be
20+ present at the time of loading.
21+
1222 Attributes
1323 ----------
14- gpu_ids : list[int]
15- Which GPU(s) to use for computation, defaults to 0 which will use device at
16- index 0. A value of -2 or less corresponds to CPU device. A value of -1 will
17- use all available GPUs.
24+ gpu_ids : Optional[Union[int, list[int], str, list[str]]]
25+ Field which specifies which GPUs to use for computation. The following types
26+ of values are allowed:
27+ - A single integer, e.g. 0, which means to use GPU with ID 0.
28+ - A list of integers, e.g. [0, 2], which means to use GPUs with IDs 0 and 2.
29+ - A device specifier string, e.g. "cuda:0", which means to use GPU with ID 0.
30+ - A list of device specifier strings, e.g. ["cuda:0", "cuda:1"], which means to
31+ use GPUs with IDs 0 and 1.
32+ - The specific string "all" which means to use all available GPUs identified
33+ by torch.cuda.device_count().
34+ - The specific string "cpu" which means to use CPU.
1835 num_cpus : int
1936 Total number of CPUs to use, defaults to 1.
2037 """
2138
22- gpu_ids : int | list [int ] = [0 ]
23- num_cpus : Annotated [int , Field (ge = 1 )] = 1
24-
25- @field_validator ("gpu_ids" ) # type: ignore
26- def validate_gpu_ids (cls , v ): # pylint: disable=no-self-argument
27- """Validate input value for GPU ids."""
28- if isinstance (v , int ):
29- v = [v ]
30-
31- # Check if -1 appears, it is only value in list
32- if - 1 in v and len (v ) > 1 :
33- raise ValueError (
34- "If -1 (all GPUs) is in the list, it must be the only value."
35- )
36-
37- # Check if -2 appears, it is only value in list
38- if - 2 in v and len (v ) > 1 :
39- raise ValueError ("If -2 (CPU) is in the list, it must be the only value." )
40-
41- return v
39+ # Type-hinting here is ensuring non-negative integers, and list of at least one
40+ gpu_ids : Optional [
41+ Union [
42+ str ,
43+ NonNegativeInt ,
44+ Annotated [list [NonNegativeInt ], Field (min_length = 1 )],
45+ Annotated [list [str ], Field (min_length = 1 )],
46+ ]
47+ ] = [0 ]
48+ num_cpus : NonNegativeInt = 1
4249
4350 @property
4451 def gpu_devices (self ) -> list [torch .device ]:
@@ -48,13 +55,29 @@ def gpu_devices(self) -> list[torch.device]:
4855 -------
4956 list[torch.device]
5057 """
51- # Case where gpu_ids is integer
52- if isinstance (self .gpu_ids , int ):
53- self .gpu_ids = [self .gpu_ids ]
54-
55- if - 1 in self .gpu_ids :
58+ # Handle special string cases first
59+ if self .gpu_ids == "all" :
60+ if not torch .cuda .is_available ():
61+ raise ValueError ("No CUDA devices available." )
5662 return [torch .device (f"cuda:{ i } " ) for i in range (torch .cuda .device_count ())]
57- if - 2 in self .gpu_ids :
63+
64+ if self .gpu_ids == "cpu" :
5865 return [torch .device ("cpu" )]
5966
60- return [torch .device (f"cuda:{ gpu_id } " ) for gpu_id in self .gpu_ids ]
67+ # Normalize to list for uniform processing
68+ gpu_list = self .gpu_ids if isinstance (self .gpu_ids , list ) else [self .gpu_ids ]
69+
70+ # Process each item in the normalized list
71+ devices = []
72+ for gpu_id in gpu_list :
73+ if isinstance (gpu_id , int ):
74+ devices .append (torch .device (f"cuda:{ gpu_id } " ))
75+ elif isinstance (gpu_id , str ):
76+ devices .append (torch .device (gpu_id ))
77+ else :
78+ raise TypeError (
79+ f"Invalid type for gpu_ids element: { type (gpu_id )} . "
80+ "Expected int or str."
81+ )
82+
83+ return devices
0 commit comments