@@ -12,6 +12,8 @@
 import socket
 import uuid
 
+import torch
+
 from monarch._src.actor.actor_mesh import ActorMesh
 from monarch._src.actor.shape import Extent
 
@@ -40,8 +42,14 @@ class _RemoteInfoFetcher(Actor):
4042 """An actor responsible for getting remote host information."""
4143
4244 @endpoint
43- def get_info (self ) -> tuple [str , str ]:
44- return socket .gethostname (), _get_port ()
45+ def get_info (self ) -> tuple [str , str , int ]:
46+ """Returns hostname, port, and GPU count."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If torch or CUDA is unavailable, assume no GPUs
+            gpu_count = 0
+        return socket.gethostname(), _get_port(), gpu_count
 
 
 class EnvSetter(Actor):
@@ -77,8 +85,8 @@ def set_env(self, env_vars: dict[str, str]):
             os.environ[k] = v
 
 
-async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
-    """Returns the host name and port of the host mesh."""
+async def get_host_info(host_mesh: HostMesh) -> tuple[str, str, int]:
+    """Returns the host name, port, and GPU count of the host mesh."""
     throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
     fetcher = throwaway_procs.spawn("_fetcher", _RemoteInfoFetcher)
 
@@ -88,11 +96,11 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     fetcher = fetcher.slice(**singleton_slice)
     # Fetcher should be a singleton at this point - call_one() will fail otherwise
 
-    host, port = await fetcher.get_info.call_one()
+    host, port, gpu_count = await fetcher.get_info.call_one()
 
     # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
     # await throwaway_procs.stop()
-    return host, port
+    return host, port, gpu_count
 
 
 async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
@@ -110,14 +118,37 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
 
 
 class GpuManager:
-    """Tracks and assigns GPU devices on a host."""
+    """Tracks and assigns GPU devices on a host.
+
+    Args:
+        available_devices: Set of GPU device IDs to manage. If None, uses all
+            devices from 0 to max_device_count - 1.
+        max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
+
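+    Example (illustrative only; device IDs might come from CUDA_VISIBLE_DEVICES):
+        >>> manager = GpuManager({0, 1, 9})
+        >>> manager.max_device_count  # grows to fit the highest device ID
+        10
+        >>> sorted(manager.available_gpus)
+        [0, 1, 9]
+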
127+ """
114128
-    def __init__(self, available_devices: set[int] | None = None):
+    def __init__(
+        self, available_devices: set[int] | None = None, max_device_count: int = 8
+    ):
         if available_devices is None:
-            available_devices = set(range(0, 8))
-        assert all(isinstance(x, int) for x in available_devices)
-        assert all(x >= 0 and x < 8 for x in available_devices)
+            available_devices = set(range(0, max_device_count))
+        else:
+            # Validate types first so max() below cannot fail on non-integers
+            assert all(
+                isinstance(x, int) for x in available_devices
+            ), f"All device IDs must be integers, got: {available_devices}"
+            # When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
+            # adjust max_device_count to accommodate the highest device ID
+            if available_devices:
+                max_device_count = max(max(available_devices) + 1, max_device_count)
+
+        assert all(
+            isinstance(x, int) for x in available_devices
+        ), f"All device IDs must be integers, got: {available_devices}"
+        assert all(
+            x >= 0 for x in available_devices
+        ), f"All device IDs must be non-negative, got: {available_devices}"
         self.available_gpus = available_devices
+        self.max_device_count = max_device_count
 
     def get_available_gpus(self) -> list[str]:
         """Returns a list of available GPU devices."""
@@ -166,8 +197,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
166197 f"Invalid CUDA_VISIBLE_DEVICES format: '{ cuda_visible_devices } '. "
167198 f"Expected comma-separated integers (e.g., '0,1,2'). Error: { e } "
168199 ) from e
200+
201+ # Get the actual GPU count for the local host
202+ try :
203+ local_gpu_count = torch .cuda .device_count ()
204+ except Exception :
205+ # If torch is not available or CUDA is not available, assume no GPUs
206+ local_gpu_count = 0
207+
169208 self ._host_gpu_map = {
170- self ._this_host_id : GpuManager (available_local_devices ),
209+ self ._this_host_id : GpuManager (
210+ available_local_devices , max_device_count = local_gpu_count
211+ ),
171212 }
172213 self ._proc_host_map = {}
173214 self ._host_mesh_map = {}
@@ -272,9 +313,18 @@ async def get_proc_mesh(
                     num_hosts=num_hosts,
                 )
                 host_id = uuid.uuid1()
-                gpu_manager = GpuManager()
+                # Get host info including GPU count from the remote host
+                host_addr, host_port, remote_gpu_count = await get_host_info(
+                    host_mesh
+                )
+                gpu_manager = GpuManager(max_device_count=remote_gpu_count)
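+                # Sizes the manager from the reported count instead of the old default of 8.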
                 self._host_gpu_map[host_id] = gpu_manager
                 host_mesh._host_id = host_id
+                # Use the fetched addr/port if not explicitly provided
+                if addr is None:
+                    addr = host_addr
+                if port is None:
+                    port = host_port
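+                # With addr/port cached here, the fallback fetch below only runs for reused host meshes.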
             else:
                 host_id = host_mesh._host_id
                 gpu_manager = self._host_gpu_map[host_id]
@@ -286,7 +336,7 @@ async def get_proc_mesh(
 
             if with_gpus:
                 if not addr or not port:
-                    addr, port = await get_remote_info(host_mesh)
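+                    # The returned gpu_count is discarded; this host's GpuManager was sized when first registered.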
+                    addr, port, _ = await get_host_info(host_mesh)
                 gpu_ids = gpu_manager.get_gpus(num_procs)
 
                 env_vars["MASTER_ADDR"] = addr