1313from dataclasses_json import dataclass_json
1414
1515from .. import envs
16- from ..detector import detect_backend
16+ from ..detector import (
17+ Devices ,
18+ ManufacturerEnum ,
19+ detect_devices ,
20+ manufacturer_to_backend ,
21+ )
1722from .__utils__ import correct_runner_image , safe_json , safe_yaml
1823
1924if TYPE_CHECKING :
@@ -1088,47 +1093,146 @@ class Deployer(ABC):
10881093 """
10891094 Thread pool for the deployer.
10901095 """
1091- _runtime_visible_devices_env_name : str | None = None
1096+ _visible_devices_env : dict [ str , list [ str ]] | None = None
10921097 """
1093- Recorded backend visible devices env name,
1094- such as "NVIDIA_VISIBLE_DEVICES", "AMD_VISIBLE_DEVICES", etc.
1095- If failed to detect backend, it will be "UNKNOWN_VISIBLE_DEVICES".
1098+ Recorded visible devices envs,
1099+ the key is the runtime visible devices env name,
1100+ the value is the list of backend visible devices env names.
1101+ For example:
1102+ {
1103+ "NVIDIA_VISIBLE_DEVICES": ["CUDA_VISIBLE_DEVICES"],
1104+ "AMD_VISIBLE_DEVICES": ["HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]
1105+ }.
10961106 """
1097- _backend_visible_devices_env_names : list [str ] | None = None
1107+ _visible_devices_values : dict [ str , list [str ] ] | None = None
10981108 """
1099- Recorded runtime visible devices env name list,
1100- such as ["CUDA_VISIBLE_DEVICES"], ["ROCR_VISIBLE_DEVICES"], etc.
1101- If failed to detect backend, it will be ["UNKNOWN_VISIBLE_DEVICES"].
1109+ Recorded visible devices values,
1110+ the key is the runtime visible devices env name,
1111+ the value is the list of device indexes or uuids.
1112+ For example:
1113+ {
1114+ "NVIDIA_VISIBLE_DEVICES": ["0"],
1115+ "AMD_VISIBLE_DEVICES": ["0", "1"]
1116+ }.
11021117 """
11031118
1119+ @staticmethod
1120+ @abstractmethod
1121+ def is_supported () -> bool :
1122+ """
1123+ Check if the deployer is supported in the current environment.
1124+
1125+ Returns:
1126+ True if supported, False otherwise.
1127+
1128+ """
1129+ raise NotImplementedError
1130+
1131+ @staticmethod
1132+ def _default_args (func ):
1133+ def wrapper (self , * args , async_mode = None , ** kwargs ):
1134+ if async_mode is None :
1135+ async_mode = envs .GPUSTACK_RUNTIME_DEPLOY_ASYNC
1136+ return func (
1137+ self ,
1138+ * args ,
1139+ async_mode = async_mode ,
1140+ ** kwargs ,
1141+ )
1142+
1143+ return wrapper
1144+
def __init__(self, name: str):
    """
    Initialize the deployer.

    Args:
        name:
            The name of the deployer, stored on the instance as ``_name``.

    """
    self._name = name
1106- self ._runtime_visible_devices_env_name = (
1107- "UNKNOWN_RUNTIME_BACKEND_VISIBLE_DEVICES"
1108- )
1109- self ._backend_visible_devices_env_names = []
1110-
1111- if backend := detect_backend ():
1112- rk = envs .GPUSTACK_RUNTIME_DETECT_BACKEND_MAP_RESOURCE_KEY .get (backend )
1113- ren = envs .GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_RUNTIME_VISIBLE_DEVICES .get (
1114- rk ,
1115- )
1116- ben = envs .GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_BACKEND_VISIBLE_DEVICES .get (
1117- rk ,
1118- )
1119- if ren :
1120- self ._runtime_visible_devices_env_name = ren
1121- if ben :
1122- self ._backend_visible_devices_env_names = ben
11231147
def __enter__(self):
    """
    Enter the ``with`` block.

    Returns:
        This same object, so it can be bound by ``with ... as``.

    """
    return self
11261150
def __exit__(self, exc_type, exc_value, traceback):
    """
    Leave the ``with`` block by calling ``close()``.

    The exception details are not inspected, and nothing is returned, so an
    exception raised inside the ``with`` block is not suppressed.

    Args:
        exc_type:
            Type of the exception raised in the block, if any.
        exc_value:
            The exception instance raised in the block, if any.
        traceback:
            Traceback of the exception raised in the block, if any.

    """
    self.close()
11291153
def _fetch_visible_devices_env_values(self):
    """
    Fetch the visible devices environment variables and values.

    Devices are detected and grouped by manufacturer. For each manufacturer,
    the backend is resolved to a resource key, which is then looked up in the
    runtime and backend visible devices env maps. When both lookups yield a
    value, the runtime env name is recorded with its backend env names and
    with one value per device: the device index (as a string) when the
    configured value mode is "index", otherwise the device uuid.

    If nothing could be recorded, a single "UNKNOWN_RUNTIME_VISIBLE_DEVICES"
    entry is stored with no backend env names and the value ["all"].

    The result is stored in `_visible_devices_env` and
    `_visible_devices_values`; later calls return immediately once
    `_visible_devices_env` is non-empty.
    """
    if self._visible_devices_env:
        return

    # Build into locals and publish only at the very end. Filling the
    # attributes in place would let an exception part-way through leave a
    # non-empty, partial mapping that the guard above treats as complete.
    env_names: dict[str, list[str]] = {}
    env_values: dict[str, list[str]] = {}

    # Group the detected devices by manufacturer.
    devices: dict[ManufacturerEnum, Devices] = {}
    for dev in detect_devices(fast=False):
        devices.setdefault(dev.manufacturer, []).append(dev)

    if devices:
        value_with_index = (
            envs.GPUSTACK_RUNTIME_DEPLOY_RUNTIME_VISIBLE_DEVICES_VALUE_MODE.lower()
            == "index"
        )

        for manu, devs in devices.items():
            backend = manufacturer_to_backend(manu)
            rk = envs.GPUSTACK_RUNTIME_DETECT_BACKEND_MAP_RESOURCE_KEY.get(backend)
            ren = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_RUNTIME_VISIBLE_DEVICES.get(
                rk,
            )
            ben = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_BACKEND_VISIBLE_DEVICES.get(
                rk,
            )
            # Manufacturers without both mappings are skipped.
            if ren and ben:
                env_names[ren] = ben
                env_values[ren] = [
                    (str(dev.index) if value_with_index else dev.uuid)
                    for dev in devs
                ]

    if not env_names:
        # Fallback to unknown backend
        env_names["UNKNOWN_RUNTIME_VISIBLE_DEVICES"] = []
        env_values["UNKNOWN_RUNTIME_VISIBLE_DEVICES"] = ["all"]

    # Publish the values first: `_visible_devices_env` is what the guard
    # checks, so it must only become non-empty once the values are in place.
    self._visible_devices_values = env_values
    self._visible_devices_env = env_names
1198+
def visible_devices_env_values(
    self,
) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
    """
    Return the visible devices environment variables and values mappings.

    The mappings are fetched through `_fetch_visible_devices_env_values()`
    before being returned.

    For example:
    (
        {
            "NVIDIA_VISIBLE_DEVICES": ["CUDA_VISIBLE_DEVICES"],
            "AMD_VISIBLE_DEVICES": ["HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]
        },
        {
            "NVIDIA_VISIBLE_DEVICES": ["0"],
            "AMD_VISIBLE_DEVICES": ["0", "1"]
        }
    ).

    Returns:
        A tuple of two dictionaries:
        - The first dictionary maps runtime visible devices environment variable names
          to lists of backend visible devices environment variable names.
        - The second dictionary maps runtime visible devices environment variable names
          to lists of device indexes or UUIDs.

    """
    self._fetch_visible_devices_env_values()
    return self._visible_devices_env, self._visible_devices_values
1226+
@property
def name(self) -> str:
    """
    Name of the deployer.

    Returns:
        The value held in ``_name``.

    """
    return self._name
11331237
11341238 def close (self ):
@@ -1152,32 +1256,6 @@ def pool(self):
11521256 self ._pool = ThreadPoolExecutor (max_workers = pool_threads )
11531257 return self ._pool
11541258
1155- @staticmethod
1156- @abstractmethod
1157- def is_supported () -> bool :
1158- """
1159- Check if the deployer is supported in the current environment.
1160-
1161- Returns:
1162- True if supported, False otherwise.
1163-
1164- """
1165- raise NotImplementedError
1166-
1167- @staticmethod
1168- def _default_args (func ):
1169- def wrapper (self , * args , async_mode = None , ** kwargs ):
1170- if async_mode is None :
1171- async_mode = envs .GPUSTACK_RUNTIME_DEPLOY_ASYNC
1172- return func (
1173- self ,
1174- * args ,
1175- async_mode = async_mode ,
1176- ** kwargs ,
1177- )
1178-
1179- return wrapper
1180-
11811259 @_default_args
11821260 def create (
11831261 self ,
0 commit comments