@@ -133,6 +133,7 @@ def __init__(
133
133
self ,
134
134
hy_proc_mesh : "Shared[HyProcMesh]" ,
135
135
shape : Shape ,
136
+ _fork_processes : bool ,
136
137
_device_mesh : Optional ["DeviceMesh" ] = None ,
137
138
) -> None :
138
139
self ._proc_mesh = hy_proc_mesh
@@ -146,6 +147,7 @@ def __init__(
146
147
self ._code_sync_client : Optional [CodeSyncMeshClient ] = None
147
148
self ._logging_mesh_client : Optional [LoggingMeshClient ] = None
148
149
self ._maybe_device_mesh : Optional ["DeviceMesh" ] = _device_mesh
150
+ self ._fork_processes = _fork_processes
149
151
self ._stopped = False
150
152
151
153
@property
@@ -163,41 +165,50 @@ async def task() -> Literal[True]:
163
165
164
166
return Future (coro = task ())
165
167
166
- def _init_manager_actors (self , setup : Callable [[], None ] | None = None ) -> None :
168
+ def _init_manager_actors (
169
+ self , setup : Callable [[], None ] | None = None , _fork_processes : bool = True
170
+ ) -> None :
167
171
self ._proc_mesh = PythonTask .from_coroutine (
168
- self ._init_manager_actors_coro (self ._proc_mesh , setup )
172
+ self ._init_manager_actors_coro (self ._proc_mesh , setup , _fork_processes )
169
173
).spawn ()
170
174
171
175
async def _init_manager_actors_coro (
172
176
self ,
173
177
proc_mesh_ : "Shared[HyProcMesh]" ,
174
178
setup : Callable [[], None ] | None = None ,
179
+ _fork_processes : bool = True ,
175
180
) -> "HyProcMesh" :
176
181
proc_mesh : HyProcMesh = await proc_mesh_
177
182
# WARNING: it is unsafe to await self._proc_mesh here
178
183
# because self._proc_mesh is the result of this function itself!
179
184
180
- self ._logging_mesh_client = await LoggingMeshClient .spawn (proc_mesh = proc_mesh )
181
- self ._logging_mesh_client .set_mode (
182
- stream_to_client = True ,
183
- aggregate_window_sec = 3 ,
184
- level = logging .INFO ,
185
- )
186
- if HAS_IPYTHON and get_ipython () is not None :
187
- # For ipython environment, a cell can end fast with threads running in background.
188
- # Flush all the ongoing logs proactively to avoid missing logs.
189
- assert self ._logging_mesh_client is not None
190
- logging_client : LoggingMeshClient = self ._logging_mesh_client
191
- ipython = get_ipython ()
185
+ if _fork_processes :
186
+ # logging mesh is only makes sense with forked (remote or local) processes
187
+ self ._logging_mesh_client = await LoggingMeshClient .spawn (
188
+ proc_mesh = proc_mesh
189
+ )
190
+ self ._logging_mesh_client .set_mode (
191
+ stream_to_client = True ,
192
+ aggregate_window_sec = 3 ,
193
+ level = logging .INFO ,
194
+ )
195
+ if HAS_IPYTHON and get_ipython () is not None :
196
+ # For ipython environment, a cell can end fast with threads running in background.
197
+ # Flush all the ongoing logs proactively to avoid missing logs.
198
+ assert self ._logging_mesh_client is not None
199
+ logging_client : LoggingMeshClient = self ._logging_mesh_client
200
+ ipython = get_ipython ()
192
201
193
- # pyre-ignore[21]
194
- from IPython .core .interactiveshell import ExecutionResult
202
+ # pyre-ignore[21]
203
+ from IPython .core .interactiveshell import ExecutionResult
195
204
196
- # pyre-ignore[11]
197
- def flush_logs (_ : ExecutionResult ) -> None :
198
- return Future (coro = logging_client .flush (proc_mesh ).spawn ().task ()).get ()
205
+ # pyre-ignore[11]
206
+ def flush_logs (_ : ExecutionResult ) -> None :
207
+ return Future (
208
+ coro = logging_client .flush (proc_mesh ).spawn ().task ()
209
+ ).get ()
199
210
200
- ipython .events .register ("post_run_cell" , flush_logs )
211
+ ipython .events .register ("post_run_cell" , flush_logs )
201
212
202
213
_rdma_manager = (
203
214
# type: ignore[16]
@@ -239,7 +250,12 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
239
250
if self ._maybe_device_mesh is None
240
251
else self ._device_mesh ._new_with_shape (shape )
241
252
)
242
- pm = ProcMesh (self ._proc_mesh , shape , _device_mesh = device_mesh )
253
+ pm = ProcMesh (
254
+ self ._proc_mesh ,
255
+ shape ,
256
+ _device_mesh = device_mesh ,
257
+ _fork_processes = self ._fork_processes ,
258
+ )
243
259
pm ._slice = True
244
260
return pm
245
261
@@ -284,6 +300,7 @@ def from_alloc(
284
300
alloc : AllocHandle ,
285
301
setup : Callable [[], None ] | None = None ,
286
302
_init_manager_actors : bool = True ,
303
+ _fork_processes : bool = True ,
287
304
) -> "ProcMesh" :
288
305
"""
289
306
Allocate a process mesh according to the provided alloc.
@@ -311,10 +328,10 @@ async def task() -> HyProcMesh:
311
328
list (alloc ._extent .keys ()),
312
329
Slice .new_row_major (list (alloc ._extent .values ())),
313
330
)
314
- pm = ProcMesh (PythonTask .from_coroutine (task ()).spawn (), shape )
331
+ pm = ProcMesh (PythonTask .from_coroutine (task ()).spawn (), shape , _fork_processes )
315
332
316
333
if _init_manager_actors :
317
- pm ._init_manager_actors (setup )
334
+ pm ._init_manager_actors (setup , _fork_processes )
318
335
return pm
319
336
320
337
def __repr__ (self ) -> str :
@@ -420,6 +437,11 @@ async def logging_option(
420
437
Returns:
421
438
None
422
439
"""
440
+ if not self ._fork_processes :
441
+ raise RuntimeError (
442
+ "Logging option is only available for allocators that fork processes. Allocators like LocalAllocator are not supported."
443
+ )
444
+
423
445
if level < 0 or level > 255 :
424
446
raise ValueError ("Invalid logging level: {}" .format (level ))
425
447
await self .initialized
@@ -510,7 +532,9 @@ def _proc_mesh_from_allocator(
510
532
# in the order of the dimensions.
511
533
spec : AllocSpec = AllocSpec (AllocConstraints (), hosts = hosts , gpus = gpus )
512
534
alloc = allocator .allocate (spec )
513
- return ProcMesh .from_alloc (alloc , setup , _init_manager_actors )
535
+ return ProcMesh .from_alloc (
536
+ alloc , setup , _init_manager_actors , allocator .fork_processses ()
537
+ )
514
538
515
539
516
540
def proc_mesh (
0 commit comments