@@ -80,12 +80,14 @@ def __init__(
80
80
hy_proc_mesh : "Shared[HyProcMesh]" ,
81
81
host_mesh : "HostMesh" ,
82
82
region : Region ,
83
+ root_region : Region ,
83
84
_device_mesh : Optional ["DeviceMesh" ] = None ,
84
85
) -> None :
85
86
_proc_mesh_registry .add (self )
86
87
self ._proc_mesh = hy_proc_mesh
87
88
self ._host_mesh = host_mesh
88
89
self ._region = region
90
+ self ._root_region = root_region
89
91
self ._maybe_device_mesh = _device_mesh
90
92
self ._logging_manager = LoggingManager ()
91
93
self ._controller_controller : Optional ["_ControllerController" ] = None
@@ -107,7 +109,16 @@ async def task() -> Literal[True]:
107
109
108
110
@property
def host_mesh(self) -> "HostMesh":
    """Return the host mesh backing this (singleton) proc mesh.

    Returns:
        A freshly created local host mesh named ``"root_host"`` when the
        stored host mesh reports ``is_fake_in_process``; otherwise the
        result of ``self._host(0)``.

    Raises:
        NotImplementedError: If ``self.extent.nelements`` is not exactly 1.
    """
    # Guard: only single-element proc meshes are supported for now.
    if self.extent.nelements != 1:
        raise NotImplementedError(
            "`ProcMesh.host_mesh` is not yet supported for non-singleton proc meshes."
        )

    if self._host_mesh.is_fake_in_process:
        # Imported lazily inside the branch, exactly as the original does.
        # NOTE(review): presumably this avoids a circular import with
        # host_mesh -- confirm.
        from monarch._src.actor.v1.host_mesh import create_local_host_mesh

        return create_local_host_mesh("root_host")

    # Real host mesh: resolve the host for the sole proc (rank 0).
    return self._host(0)
111
122
112
123
@property
113
124
def _ndslice (self ) -> Slice :
@@ -134,6 +145,7 @@ async def task() -> HyProcMesh:
134
145
PythonTask .from_coroutine (task ()).spawn (),
135
146
self ._host_mesh ,
136
147
shape .region ,
148
+ self ._root_region ,
137
149
_device_mesh = device_mesh ,
138
150
)
139
151
@@ -176,7 +188,7 @@ def from_host_mesh(
176
188
setup : Callable [[], None ] | None = None ,
177
189
_attach_controller_controller : bool = True ,
178
190
) -> "ProcMesh" :
179
- pm = ProcMesh (hy_proc_mesh , host_mesh , region )
191
+ pm = ProcMesh (hy_proc_mesh , host_mesh , region , region )
180
192
181
193
if _attach_controller_controller :
182
194
instance = context ().actor_instance
@@ -341,7 +353,11 @@ async def __aexit__(
341
353
342
354
@classmethod
343
355
def _from_initialized_hy_proc_mesh (
344
- cls , hy_proc_mesh : HyProcMesh , host_mesh : "HostMesh" , region : Region
356
+ cls ,
357
+ hy_proc_mesh : HyProcMesh ,
358
+ host_mesh : "HostMesh" ,
359
+ region : Region ,
360
+ root_region : Region ,
345
361
) -> "ProcMesh" :
346
362
async def task () -> HyProcMesh :
347
363
return hy_proc_mesh
@@ -350,13 +366,25 @@ async def task() -> HyProcMesh:
350
366
PythonTask .from_coroutine (task ()).spawn (),
351
367
host_mesh ,
352
368
region ,
369
+ root_region ,
353
370
)
354
371
355
372
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
    """Describe how to rebuild this proc mesh when it is pickled.

    Args:
        protocol: Pickle protocol; not consulted by this implementation.

    Returns:
        A ``(callable, args)`` pair: the
        ``ProcMesh._from_initialized_hy_proc_mesh`` constructor and the
        four positional arguments it is called with on unpickling.
    """
    rebuild = ProcMesh._from_initialized_hy_proc_mesh
    state = (
        # Result of block_on() on the stored proc-mesh handle.
        # NOTE(review): presumably this waits for the underlying mesh to
        # be initialized -- confirm.
        self._proc_mesh.block_on(),
        # The raw stored attribute is used here, not the `host_mesh`
        # property.
        self._host_mesh,
        self._region,
        self._root_region,
    )
    return rebuild, state
379
+
380
def _host(self, proc_rank: int) -> "HostMesh":
    """Return the slice of the host mesh that holds the given proc.

    Args:
        proc_rank: Rank of the proc, looked up in this mesh's region slice.

    Returns:
        The result of slicing ``self._host_mesh`` at the point that
        corresponds to the proc's host.

    NOTE(review): the integer divisions assume procs are spread evenly
    and contiguously over hosts, and that the host mesh is non-empty
    (otherwise a ZeroDivisionError is raised) -- confirm with callers.
    """
    # Translate the rank through this mesh's region slice.
    rank_in_root = self._region.slice().get(proc_rank)

    # Procs per host = size of the root region / size of the host region.
    total_procs = len(self._root_region.slice())
    host_count = len(self._host_mesh.region.slice())
    host_index = rank_in_root // (total_procs // host_count)

    # Map the host index to a base rank, then to a point used as
    # keyword arguments for slicing the host mesh.
    host_base_rank = self._host_mesh.region.slice().get(host_index)
    coords = self._host_mesh.region.point_of_base_rank(host_base_rank)
    return self._host_mesh.slice(**coords)
361
389
362
390
0 commit comments