@@ -294,6 +294,11 @@ def bootstrap(env: dict[str, str]):
294294 per_host = {"procs" : num_procs },
295295 bootstrap = functools .partial (bootstrap , env = env_vars ),
296296 )
297+ uid = str (uuid .uuid4 ())
298+ # Generate a unique ID to map procmesh to hostmesh
299+ procs ._uid = uid
300+ print (f"Allocating procmesh with uid={ uid } " )
301+ print (f"Allocating procs._uid: { procs ._uid } " )
297302
298303 if with_gpus :
299304 # Set up environment variables for PyTorch distributed...
@@ -319,28 +324,35 @@ def bootstrap(env: dict[str, str]):
319324 self ._server_names .append (server_name )
320325 self ._proc_server_map [procs ] = server_name
321326
322- self ._proc_host_map [procs ] = host_mesh
327+ self ._proc_host_map [uid ] = host_mesh
323328
324329 # Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
325330 # When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
326331 if not FORGE_DISABLE_METRICS .get_value ():
327332 from forge .observability .metric_actors import get_or_create_metric_logger
328333
329334 _ = await get_or_create_metric_logger (procs , process_name = mesh_name )
330- return procs
335+
336+ print (f"Returning procmesh with uid={ uid } " )
337+ print (f"Returning procs._uid: { procs ._uid } " )
338+ return procs , uid
331339
@endpoint
async def host_mesh_from_proc(self, uid: ProcMesh | str | None):
    """Return the host mesh recorded for a proc mesh.

    Host meshes are stored in ``self._proc_host_map`` under the string uid
    generated when the proc mesh was allocated.

    Args:
        uid: The uid string of the proc mesh. For compatibility with callers
            that still pass the proc mesh itself, any non-string object is
            resolved through its ``_uid`` attribute (the same lookup
            ``stop_proc_mesh`` performs). ``None`` is treated as "unknown".

    Returns:
        The host mesh stored for ``uid``.

    Raises:
        ValueError: If no uid could be determined or the uid has no
            associated host mesh.
    """
    if uid is not None and not isinstance(uid, str):
        # A proc mesh (or similar) was passed; fall back to the uid attached
        # to it, if any.
        uid = getattr(uid, "_uid", None)

    # Log only the key being looked up, not the whole map.
    logger.debug("Looking up host mesh for proc mesh uid=%s", uid)

    if uid is None or uid not in self._proc_host_map:
        raise ValueError(
            "The proc mesh was not allocated with an associated hostmesh."
        )
    return self._proc_host_map[uid]
339350
340351 @endpoint
341352 async def stop_proc_mesh (self , proc_mesh : ProcMesh ):
342353 """Stops a proc mesh."""
343- if proc_mesh not in self ._proc_host_map :
354+ uid : str | None = getattr (proc_mesh , "_uid" , None )
355+ if uid is None or uid not in self ._proc_host_map :
344356 logger .warning (
345357 f"proc mesh { proc_mesh } was requested to be stopped, but was either already stopped or "
346358 "was never registered with the provisioner."
@@ -363,7 +375,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
363375 if proc_mesh in self ._proc_server_map :
364376 server_name = self ._proc_server_map [proc_mesh ]
365377 commands .kill (server_name )
366- del self ._proc_host_map [proc_mesh ]
378+ del self ._proc_host_map [uid ]
367379
368380 @endpoint
369381 def register_service (self , service : "ServiceInterface" ) -> None :
@@ -464,7 +476,7 @@ async def get_proc_mesh(
464476
465477 """
466478 provisioner = await get_or_create_provisioner ()
467- return await provisioner .get_proc_mesh .call_one (
479+ procs , uid = await provisioner .get_proc_mesh .call_one (
468480 num_procs = process_config .procs ,
469481 with_gpus = process_config .with_gpus ,
470482 num_hosts = process_config .hosts ,
@@ -474,17 +486,20 @@ async def get_proc_mesh(
474486 port = port ,
475487 addr = addr ,
476488 )
489+ setattr (procs , "_uid" , uid )
490+ print (f"Setting procs._uid: { procs ._uid } " )
491+ return procs
477492
478493
async def host_mesh_from_proc(uid: ProcMesh | str | None):
    """Returns the host mesh that allocated the original proc_mesh.

    This functionality will be enabled in Monarch, so this is a temporary
    API.

    Args:
        uid: The uid string of the proc mesh, or the proc mesh itself. When a
            non-string object is given, its ``_uid`` attribute (set on the
            mesh returned by ``get_proc_mesh``) is used; a missing attribute
            resolves to ``None``.

    Returns:
        Whatever the provisioner's ``host_mesh_from_proc`` endpoint returns
        for the resolved uid.
    """
    if uid is not None and not isinstance(uid, str):
        # Resolve the uid locally so that only the string key is sent to the
        # provisioner, and callers that pass a proc mesh keep working.
        uid = getattr(uid, "_uid", None)

    provisioner = await get_or_create_provisioner()
    return await provisioner.host_mesh_from_proc.call_one(uid)
488503
489504
490505async def register_service (service : "ServiceInterface" ) -> None :
0 commit comments