@@ -42,14 +42,19 @@ class _RemoteInfoFetcher(Actor):
4242 """An actor responsible for getting remote host information."""
4343
4444 @endpoint
45- def get_info (self ) -> tuple [str , str , int ]:
46- """Returns hostname, port, and GPU count."""
45+ def get_info (self ) -> tuple [str , str ]:
46+ """Returns hostname and port."""
47+ return socket .gethostname (), _get_port ()
48+
49+ @endpoint
50+ def get_gpu_count (self ) -> int :
51+ """Returns the number of GPUs available on this host."""
4752 try :
4853 gpu_count = torch .cuda .device_count ()
4954 except Exception :
5055 # If torch is not available or CUDA is not available, assume no GPUs
5156 gpu_count = 0
52- return socket . gethostname (), _get_port (), gpu_count
57+ return gpu_count
5358
5459
5560class EnvSetter (Actor ):
@@ -85,8 +90,8 @@ def set_env(self, env_vars: dict[str, str]):
             os.environ[k] = v
 
 
-async def get_host_info(host_mesh: HostMesh) -> tuple[str, str, int]:
-    """Returns the host name, port, and GPU count of the host mesh."""
+async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
+    """Returns the host name and port of the host mesh."""
     throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
     fetcher = throwaway_procs.spawn("_fetcher", _RemoteInfoFetcher)
 
@@ -95,12 +100,24 @@ async def get_host_info(host_mesh: HostMesh) -> tuple[str, str, int]:
     singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
     fetcher = fetcher.slice(**singleton_slice)
     # Fetcher should be a singleton at this point - call_one() will fail otherwise
-
-    host, port, gpu_count = await fetcher.get_info.call_one()
+    host, port = await fetcher.get_info.call_one()
 
     # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
     # await throwaway_procs.stop()
-    return host, port, gpu_count
+    return host, port
+
+
+async def get_host_gpus(host_mesh: HostMesh) -> int:
+    """Returns the number of GPUs available on the host mesh."""
+    throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
+    fetcher = throwaway_procs.spawn("_gpu_counter", _RemoteInfoFetcher)
+
+    # Reduce to a singleton
+    singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
+    fetcher = fetcher.slice(**singleton_slice)
+
+    gpu_count = await fetcher.get_gpu_count.call_one()
+    return gpu_count
 
 
 async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
@@ -313,18 +330,11 @@ async def get_proc_mesh(
                 num_hosts=num_hosts,
             )
             host_id = uuid.uuid1()
-            # Get host info including GPU count from the remote host
-            host_addr, host_port, remote_gpu_count = await get_host_info(
-                host_mesh
-            )
+            # Get the GPU count from the remote host
+            remote_gpu_count = await get_host_gpus(host_mesh)
             gpu_manager = GpuManager(max_device_count=remote_gpu_count)
             self._host_gpu_map[host_id] = gpu_manager
             host_mesh._host_id = host_id
-            # Use the fetched addr/port if not explicitly provided
-            if addr is None:
-                addr = host_addr
-            if port is None:
-                port = host_port
         else:
             host_id = host_mesh._host_id
             gpu_manager = self._host_gpu_map[host_id]
@@ -336,7 +346,7 @@ async def get_proc_mesh(
 
         if with_gpus:
             if not addr or not port:
-                addr, port, _ = await get_host_info(host_mesh)
+                addr, port = await get_remote_info(host_mesh)
             gpu_ids = gpu_manager.get_gpus(num_procs)
 
             env_vars["MASTER_ADDR"] = addr
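
Usage sketch (illustration only, not part of the diff): one way a caller might combine the two helpers introduced above. Only get_remote_info and get_host_gpus come from this change; the host_mesh value and the enclosing event loop are assumed.

async def describe_host(host_mesh):
    # Hostname and rendezvous port come from get_remote_info; the GPU count
    # comes from get_host_gpus and is 0 when torch/CUDA is unavailable.
    addr, port = await get_remote_info(host_mesh)
    gpu_count = await get_host_gpus(host_mesh)
    print(f"{addr}:{port} exposes {gpu_count} GPU(s)")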