Commit 7e52cb1

samlurye authored and facebook-github-bot committed

Enable log forwarding for v1 API (#1375)

Summary: Pull Request resolved: #1375

Use the logging manager to spawn the log client and log forwarders on v1 proc meshes.

ghstack-source-id: 313150889
Reviewed By: mariusae
Differential Revision: D83360166
fbshipit-source-id: 456673311006e001569c4b02923ee881f970e3d9

1 parent 5a863c3
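In effect, proc meshes created with v1=True now forward actor logs to the client the way v0 meshes already did. A minimal sketch of the flow this commit enables, borrowing spawn_procs_on_this_host and the Printer actor from the tests below (test helpers, not public API):

import asyncio

async def demo() -> None:
    # With v1=True, proc mesh creation now awaits LoggingManager.init(),
    # which spawns the log client and the per-proc log forwarders.
    pm = spawn_procs_on_this_host(v1=True, per_host={"gpus": 2})
    am = pm.spawn("printer", Printer)
    # This print is forwarded to (and aggregated on) the client process.
    await am.print.call("hello from a v1 proc")
    # Sleep past the default 3-second aggregation window so the aggregated
    # line is flushed; the tests below use the same trick instead of stop().
    await asyncio.sleep(5)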

File tree: 4 files changed, +91 -40 lines changed

monarch_hyperactor/src/v1/logging.rs

Lines changed: 5 additions & 1 deletion

@@ -17,6 +17,7 @@ use hyperactor_mesh::logging::LogClientMessage;
 use hyperactor_mesh::logging::LogForwardActor;
 use hyperactor_mesh::logging::LogForwardMessage;
 use hyperactor_mesh::v1::ActorMesh;
+use hyperactor_mesh::v1::Name;
 use hyperactor_mesh::v1::actor_mesh::ActorMeshRef;
 use ndslice::View;
 use pyo3::Bound;
@@ -80,7 +81,10 @@ impl LoggingMeshClient {
         PyPythonTask::new(async move {
             let client_actor: ActorHandle<LogClientActor> =
                 instance_dispatch!(instance, async move |cx_instance| {
-                    cx_instance.proc().spawn("log_client", ()).await
+                    cx_instance
+                        .proc()
+                        .spawn(&Name::new("log_client").to_string(), ())
+                        .await
                 })?;
             let client_actor_ref = client_actor.bind();
             let forwarder_mesh = instance_dispatch!(instance, async |cx_instance| {
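The only behavioral change above is the spawn name: the log client is registered under Name::new("log_client").to_string() rather than the fixed string "log_client". A hedged reading, since Name's implementation is not part of this diff: Name::new presumably derives a unique per-call name from the base string, so each v1 proc mesh can spawn its own LogClientActor in the same client proc without the names colliding.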

python/monarch/_src/actor/logging.py

Lines changed: 2 additions & 1 deletion

@@ -50,7 +50,8 @@ def flush_all_proc_mesh_logs(v1: bool = False) -> None:
         from monarch._src.actor.v1.proc_mesh import get_active_proc_meshes

         for pm in get_active_proc_meshes():
-            pm._logging_manager.flush()
+            if pm._logging_manager._logging_mesh_client is not None:
+                pm._logging_manager.flush()


 class LoggingManager:
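The new guard appears to cover v1 proc meshes that show up in get_active_proc_meshes() before their logging manager has finished init (the proc_mesh.py change below): for those, _logging_mesh_client is still None and there is no client-side log state to flush, so they are skipped rather than crashing the ipython post_run_cell hook.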

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 1 addition & 2 deletions

@@ -200,8 +200,7 @@ async def task(
     ) -> HyProcMesh:
         hy_proc_mesh = await hy_proc_mesh_task

-        # FIXME: Fix log forwarding.
-        # await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
+        await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)

         if setup_actor is not None:
             await setup_actor.setup.call()
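This is the heart of the commit: the FIXME is resolved and the previously commented-out call is restored, so v1 proc mesh creation now awaits LoggingManager.init before running any setup actor. Per the Rust change above, that init spawns the LogClientActor and the LogForwardActor forwarder mesh for the new procs.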

python/tests/test_python_actors.py

Lines changed: 83 additions & 36 deletions

@@ -37,17 +37,19 @@
 )
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
+from monarch._rust_bindings.monarch_hyperactor.shape import Extent

 from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
-from monarch._src.actor.allocator import AllocHandle
+from monarch._src.actor.allocator import AllocHandle, ProcessAllocator
 from monarch._src.actor.future import Future
 from monarch._src.actor.host_mesh import (
     create_local_host_mesh,
     fake_in_process_host,
     HostMesh,
 )
-from monarch._src.actor.proc_mesh import ProcMesh
+from monarch._src.actor.proc_mesh import _get_bootstrap_args, ProcMesh
 from monarch._src.actor.v1.host_mesh import (
+    _bootstrap_cmd,
     fake_in_process_host as fake_in_process_host_v1,
     HostMesh as HostMeshV1,
     this_host as this_host_v1,
@@ -466,7 +468,7 @@ async def no_more(self) -> None:


 @pytest.mark.parametrize("v1", [True, False])
-@pytest.mark.timeout(30)
+@pytest.mark.timeout(60)
 async def test_async_concurrency(v1: bool):
     """Test that async endpoints will be processed concurrently."""
     pm = spawn_procs_on_this_host(v1, {})
@@ -603,8 +605,9 @@ def _handle_undeliverable_message(
     return True


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_actor_log_streaming() -> None:
+async def test_actor_log_streaming(v1: bool) -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -631,7 +634,7 @@ async def test_actor_log_streaming() -> None:
     sys.stderr = stderr_file

     try:
-        pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)

         # Disable streaming logs to client
@@ -671,7 +674,10 @@ async def test_actor_log_streaming() -> None:
         await am.print.call("has print streaming too")
         await am.log.call("has log streaming as level matched")

-        await pm.stop()
+        if not v1:
+            await pm.stop()
+        else:
+            await asyncio.sleep(1)

         # Flush all outputs
         stdout_file.flush()
@@ -752,8 +758,9 @@ async def test_actor_log_streaming() -> None:
         pass


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_alloc_based_log_streaming() -> None:
+async def test_alloc_based_log_streaming(v1: bool) -> None:
     """Test both AllocHandle.stream_logs = False and True cases."""

     async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
@@ -770,23 +777,45 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:

         try:
             # Create proc mesh with custom stream_logs setting
-            host_mesh = create_local_host_mesh()
-            alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+            if not v1:
+                host_mesh = create_local_host_mesh()
+                alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                # Override the stream_logs setting
+                custom_alloc_handle = AllocHandle(
+                    alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                )
+
+                pm = ProcMesh.from_alloc(custom_alloc_handle)
+            else:

-            # Override the stream_logs setting
-            custom_alloc_handle = AllocHandle(
-                alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
-            )
+                class ProcessAllocatorStreamLogs(ProcessAllocator):
+                    def _stream_logs(self) -> bool:
+                        return stream_logs
+
+                alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args())
+
+                host_mesh = HostMeshV1.allocate_nonblocking(
+                    "host",
+                    Extent(["hosts"], [1]),
+                    alloc,
+                    bootstrap_cmd=_bootstrap_cmd(),
+                )
+
+                pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})

-            pm = ProcMesh.from_alloc(custom_alloc_handle)
             am = pm.spawn("printer", Printer)

             await pm.initialized

             for _ in range(5):
                 await am.print.call(f"{test_name} print streaming")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                # Wait for at least the aggregation window (3 seconds)
+                await asyncio.sleep(5)

             # Flush all outputs
             stdout_file.flush()
@@ -810,18 +839,18 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
                 # When stream_logs=False, logs should not be streamed to client
                 assert not re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
                 assert re.search(
                     rf"{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
             else:
                 # When stream_logs=True, logs should be streamed to client (no aggregation by default)
                 assert re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"
                 assert not re.search(
                     rf"\[[0-9]\]{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"

         finally:
             # Ensure file descriptors are restored even if something goes wrong
@@ -836,8 +865,9 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
     await test_stream_logs_case(True, "stream_logs_true")


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_logging_option_defaults() -> None:
+async def test_logging_option_defaults(v1: bool) -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -864,14 +894,18 @@ async def test_logging_option_defaults() -> None:
     sys.stderr = stderr_file

     try:
-        pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)

         for _ in range(5):
             await am.print.call("print streaming")
             await am.log.call("log streaming")

-        await pm.stop()
+        if not v1:
+            await pm.stop()
+        else:
+            # Wait for > default aggregation window (3 seconds)
+            await asyncio.sleep(5)

         # Flush all outputs
         stdout_file.flush()
@@ -949,7 +983,8 @@ def __init__(self):

 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-async def test_flush_called_only_once() -> None:
+@pytest.mark.parametrize("v1", [True, False])
+async def test_flush_called_only_once(v1: bool) -> None:
     """Test that flush is called only once when ending an ipython cell"""
     mock_ipython = MockIPython()
     with unittest.mock.patch(
@@ -961,8 +996,8 @@ async def test_flush_called_only_once() -> None:
         "monarch._src.actor.logging.flush_all_proc_mesh_logs"
     ) as mock_flush:
         # Create 2 proc meshes with a large aggregation window
-        pm1 = this_host().spawn_procs(per_host={"gpus": 2})
-        _ = this_host().spawn_procs(per_host={"gpus": 2})
+        pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        _ = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
         # flush not yet called unless post_run_cell
         assert mock_flush.call_count == 0
         assert mock_ipython.events.registers == 0
@@ -976,8 +1011,9 @@ async def test_flush_called_only_once() -> None:

 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(180)
-async def test_flush_logs_ipython() -> None:
+async def test_flush_logs_ipython(v1: bool) -> None:
     """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
@@ -1003,8 +1039,8 @@ async def test_flush_logs_ipython() -> None:
     ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
         # Make sure we can register and unregister callbacks
         for _ in range(3):
-            pm1 = this_host().spawn_procs(per_host={"gpus": 2})
-            pm2 = this_host().spawn_procs(per_host={"gpus": 2})
+            pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm2 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am1 = pm1.spawn("printer", Printer)
             am2 = pm2.spawn("printer", Printer)
@@ -1108,8 +1144,9 @@ async def test_flush_logs_fast_exit() -> None:
     ), process.stdout


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_flush_on_disable_aggregation() -> None:
+async def test_flush_on_disable_aggregation(v1: bool) -> None:
     """Test that logs are flushed when disabling aggregation.

     This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1130,7 +1167,7 @@ async def test_flush_on_disable_aggregation() -> None:
     sys.stdout = stdout_file

     try:
-        pm = this_host().spawn_procs(per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)

         # Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1151,7 +1188,11 @@ async def test_flush_on_disable_aggregation() -> None:
         for _ in range(5):
             await am.print.call("single log line")

-        await pm.stop()
+        if not v1:
+            await pm.stop()
+        else:
+            # Wait for > default aggregation window (3 secs)
+            await asyncio.sleep(5)

         # Flush all outputs
         stdout_file.flush()
@@ -1197,14 +1238,15 @@ async def test_flush_on_disable_aggregation() -> None:
         pass


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
     """
     The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
     Because now a flush call is purely sync, it is very easy to get into a deadlock.
     So we assert the last flush call will not get into such a state.
     """
-    pm = this_host().spawn_procs(per_host={"gpus": 4})
+    pm = spawn_procs_on_this_host(v1, per_host={"gpus": 4})
     am = pm.spawn("printer", Printer)

     # Generate some logs that will be aggregated but not flushed immediately
@@ -1227,8 +1269,9 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
     futures[-1].get()


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_adjust_aggregation_window() -> None:
+async def test_adjust_aggregation_window(v1: bool) -> None:
     """Test that the flush deadline is updated when the aggregation window is adjusted.

     This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1249,7 +1292,7 @@ async def test_adjust_aggregation_window() -> None:
     sys.stdout = stdout_file

     try:
-        pm = this_host().spawn_procs(per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)

         # Set a long aggregation window initially
@@ -1267,7 +1310,11 @@ async def test_adjust_aggregation_window() -> None:
         for _ in range(3):
             await am.print.call("second batch of logs")

-        await pm.stop()
+        if not v1:
+            await pm.stop()
+        else:
+            # Wait for > aggregation window (2 secs)
+            await asyncio.sleep(4)

         # Flush all outputs
         stdout_file.flush()
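A pattern repeated throughout the v1 branches above: instead of await pm.stop(), the tests sleep past the log aggregation window so aggregated lines reach the client before the assertions run (presumably because stopping a v1 proc mesh does not yet flush the forwarders the way the v0 stop path does). A hedged helper capturing that pattern; the name settle_logs and the exact margin are mine, not the repo's:

import asyncio

async def settle_logs(pm, v1: bool, window_secs: float = 3.0) -> None:
    """Test-teardown sketch: v0 meshes flush their logs on stop(); for v1,
    wait out the aggregation window instead (an assumption drawn from the
    sleep(4)/sleep(5) calls in this diff)."""
    if not v1:
        await pm.stop()
    else:
        await asyncio.sleep(window_secs + 2)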
