diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 14e381c87..10ca33684 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -13,6 +13,7 @@ import inspect import itertools import logging +import threading from abc import abstractproperty from dataclasses import dataclass @@ -49,6 +50,8 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.config import configure from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -230,6 +233,39 @@ def context() -> Context: return c +_transport: Optional[ChannelTransport] = None +_transport_lock = threading.Lock() + + +def enable_transport(transport: ChannelTransport) -> None: + """ + Allow monarch to communicate with transport type 'transport' + This must be called before any other calls in the monarch API. + If it isn't called, we will implicitly call + `monarch.enable_transport(ChannelTransport.Unix)` on the first monarch call. + + Currently only one transport type may be enabled at one time. + In the future we may allow multiple to be enabled. + """ + if _context.get(None) is not None: + raise RuntimeError( + "`enable_transport()` must be called before any other calls in the monarch API. " + "If it isn't called, we will implicitly call `monarch.enable_transport(ChannelTransport.Unix)` " + "on the first monarch call." + ) + + global _transport + with _transport_lock: + if _transport is not None and _transport != transport: + raise RuntimeError( + f"Only one transport type may be enabled at one time. " + f"Currently enabled transport type is `{_transport}`. " + f"Attempted to enable transport type `{transport}`." + ) + _transport = transport + configure(default_transport=transport) + + @dataclass class DebugContext: pdb_wrapper: Optional[PdbWrapper] = None diff --git a/python/monarch/_src/actor/debugger/debug_controller.py b/python/monarch/_src/actor/debugger/debug_controller.py index 245a70d80..8f62f37c0 100644 --- a/python/monarch/_src/actor/debugger/debug_controller.py +++ b/python/monarch/_src/actor/debugger/debug_controller.py @@ -9,7 +9,7 @@ import functools from typing import Dict, List, Optional, Tuple -from monarch._src.actor.actor_mesh import Actor +from monarch._src.actor.actor_mesh import Actor, context from monarch._src.actor.debugger.debug_command import ( Attach, Cast, @@ -33,8 +33,11 @@ ) from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite from monarch._src.actor.endpoint import endpoint -from monarch._src.actor.proc_mesh import get_or_spawn_controller +from monarch._src.actor.proc_mesh import ( + get_or_spawn_controller as get_or_spawn_controller_v0, +) from monarch._src.actor.sync_state import fake_sync_state +from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh from monarch.tools.debug_env import ( _get_debug_server_host, _get_debug_server_port, @@ -243,4 +246,7 @@ async def debugger_write( @functools.cache def debug_controller() -> DebugController: with fake_sync_state(): - return get_or_spawn_controller("debug_controller", DebugController).get() + if isinstance(context().actor_instance.proc_mesh, ProcMesh): + return get_or_spawn_controller("debug_controller", DebugController).get() + else: + return get_or_spawn_controller_v0("debug_controller", DebugController).get() diff --git a/python/monarch/_src/actor/source_loader.py b/python/monarch/_src/actor/source_loader.py index 0765ad27c..6eeb3e2a3 100644 --- a/python/monarch/_src/actor/source_loader.py +++ b/python/monarch/_src/actor/source_loader.py @@ -10,10 +10,13 @@ import importlib.abc import linecache -from monarch._src.actor.actor_mesh import _context, Actor +from monarch._src.actor.actor_mesh import _context, Actor, context from monarch._src.actor.endpoint import endpoint -from monarch._src.actor.proc_mesh import get_or_spawn_controller +from monarch._src.actor.proc_mesh import ( + get_or_spawn_controller as get_or_spawn_controller_v0, +) from monarch._src.actor.sync_state import fake_sync_state +from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh class SourceLoaderController(Actor): @@ -25,7 +28,14 @@ def get_source(self, filename: str) -> str: @functools.cache def source_loader_controller() -> SourceLoaderController: with fake_sync_state(): - return get_or_spawn_controller("source_loader", SourceLoaderController).get() + if isinstance(context().actor_instance.proc_mesh, ProcMesh): + return get_or_spawn_controller( + "source_loader", SourceLoaderController + ).get() + else: + return get_or_spawn_controller_v0( + "source_loader", SourceLoaderController + ).get() @functools.cache diff --git a/python/monarch/_src/actor/telemetry/__init__.py b/python/monarch/_src/actor/telemetry/__init__.py index cf76b0849..46b2fb44d 100644 --- a/python/monarch/_src/actor/telemetry/__init__.py +++ b/python/monarch/_src/actor/telemetry/__init__.py @@ -30,12 +30,14 @@ class TracingForwarder(logging.Handler): def emit(self, record: logging.LogRecord) -> None: # Try to add actor_id from the current context to the logging record try: - from monarch._src.actor.actor_mesh import context - - ctx = context() - if ctx and ctx.actor_instance and ctx.actor_instance.actor_id: - # Add actor_id as an attribute to the logging record - setattr(record, "actor_id", str(ctx.actor_instance.actor_id)) + from monarch._src.actor.actor_mesh import _context, context + + # Don't initialize the context if it hasn't been initialized yet. + if _context.get(None) is not None: + ctx = context() + if ctx and ctx.actor_instance and ctx.actor_instance.actor_id: + # Add actor_id as an attribute to the logging record + setattr(record, "actor_id", str(ctx.actor_instance.actor_id)) except Exception: # If we can't get the context or actor_id for any reason, just continue # without adding the actor_id field diff --git a/python/monarch/_src/actor/v1/host_mesh.py b/python/monarch/_src/actor/v1/host_mesh.py index 84bb8c649..0eb789cce 100644 --- a/python/monarch/_src/actor/v1/host_mesh.py +++ b/python/monarch/_src/actor/v1/host_mesh.py @@ -142,7 +142,7 @@ def spawn_procs( name = "" return self._spawn_nonblocking( - name, Extent(list(per_host.keys()), list(per_host.values())), setup, False + name, Extent(list(per_host.keys()), list(per_host.values())), setup, True ) def _spawn_nonblocking( diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index 0c7544c5b..e5836634f 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -28,7 +28,7 @@ from weakref import WeakSet from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared -from monarch._rust_bindings.monarch_hyperactor.shape import Region, Shape, Slice +from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Shape, Slice from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import ( ProcMesh as HyProcMesh, @@ -400,7 +400,9 @@ def get_or_spawn( if name not in self._controllers: from monarch._src.actor.v1.host_mesh import this_proc - self._controllers[name] = this_proc().spawn(name, Class, *args, **kwargs) + proc = this_proc() + proc._controller_controller = _get_controller_controller()[1] + self._controllers[name] = proc.spawn(name, Class, *args, **kwargs) return cast(TActor, self._controllers[name]) @@ -422,11 +424,16 @@ def _get_controller_controller() -> "Tuple[ProcMesh, _ControllerController]": _cc_proc_mesh = fake_in_process_host( "controller_controller_host" - ).spawn_procs(name="controller_controller_proc") + )._spawn_nonblocking( + name="controller_controller_proc", + per_host=Extent([], []), + setup=None, + _attach_controller_controller=False, + ) _controller_controller = _cc_proc_mesh.spawn( "controller_controller", _ControllerController ) - assert _cc_proc_mesh is not None + assert _cc_proc_mesh is not None and _controller_controller is not None return _cc_proc_mesh, _controller_controller @@ -447,7 +454,11 @@ def get_or_spawn_controller( A Future that resolves to a reference to the actor. """ cc = context().actor_instance._controller_controller - if not isinstance(cc, _ControllerController): + if ( + cc is not None + and cast(ActorMesh[_ControllerController], cc)._class + is not _ControllerController + ): # This can happen in the client process cc = _get_controller_controller()[1] return cc.get_or_spawn.call_one(name, Class, *args, **kwargs) diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index bfa4192ae..37b633ff9 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -9,6 +9,7 @@ Monarch Actor API - Public interface for actor functionality. """ +from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport from monarch._rust_bindings.monarch_hyperactor.shape import Extent from monarch._src.actor.actor_mesh import ( Accumulator, @@ -20,6 +21,7 @@ current_actor_name, current_rank, current_size, + enable_transport, Endpoint, Point, Port, @@ -77,4 +79,5 @@ "Extent", "run_worker_loop_forever", "attach_to_workers", + "enable_transport", ] diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py index 623cc91fe..f62b8709c 100644 --- a/python/tests/test_debugger.py +++ b/python/tests/test_debugger.py @@ -6,6 +6,7 @@ # pyre-unsafe import asyncio +import enum import functools import importlib.resources import os @@ -14,7 +15,7 @@ import signal import subprocess import sys -from typing import cast, List, Optional, Tuple +from typing import cast, List, Optional, Tuple, Type, TypeVar from unittest.mock import AsyncMock, patch import cloudpickle @@ -25,7 +26,13 @@ import pytest import torch -from monarch._src.actor.actor_mesh import Actor, ActorError, current_rank, IN_PAR +from monarch._src.actor.actor_mesh import ( + Actor, + ActorError, + context, + current_rank, + IN_PAR, +) from monarch._src.actor.debugger.debug_command import ( Attach, Cast, @@ -42,9 +49,21 @@ DebugSessionInfo, DebugSessions, ) -from monarch._src.actor.endpoint import endpoint -from monarch._src.actor.proc_mesh import proc_mesh +from monarch._src.actor.endpoint import endpoint, Extent +from monarch._src.actor.future import Future +from monarch._src.actor.proc_mesh import ( + proc_mesh as proc_mesh_v0, + ProcMesh as ProcMeshV0, +) from monarch._src.actor.source_loader import SourceLoaderController +from monarch._src.actor.v1.host_mesh import ( + create_local_host_mesh, + ProcMesh as ProcMeshV1, + this_host as this_host_v1, +) +from monarch._src.actor.v1.proc_mesh import ( + get_or_spawn_controller as get_or_spawn_controller_v1, +) from monarch.tools.debug_env import ( _MONARCH_DEBUG_SERVER_HOST_ENV_VAR, _MONARCH_DEBUG_SERVER_PORT_ENV_VAR, @@ -52,6 +71,44 @@ from pyre_extensions import none_throws + +class ApiVersion(enum.Enum): + V0 = "v0" + V1 = "v1" + + +TActor = TypeVar("TActor", bound=Actor) + + +def get_or_spawn_controller( + api: ApiVersion, name: str, klass: Type[TActor], *args, **kwargs +) -> Future[TActor]: + match api: + case ApiVersion.V0: + return actor.get_or_spawn_controller(name, klass, *args, **kwargs) + case ApiVersion.V1: + return get_or_spawn_controller_v1(name, klass, *args, **kwargs) + case _: + raise ValueError(f"Unknown API version: {api}") + + +def proc_mesh( + api: ApiVersion, + *, + gpus: int = 1, + hosts: int = 1, +) -> ProcMeshV0 | ProcMeshV1: + match api: + case ApiVersion.V0: + return proc_mesh_v0(gpus=gpus, hosts=hosts) + case ApiVersion.V1: + return create_local_host_mesh( + "hosts", extent=Extent(["hosts"], [hosts]) + ).spawn_procs(per_host={"gpus": gpus}) + case _: + raise ValueError(f"Unknown API version: {api}") + + needs_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA not available", @@ -155,6 +212,18 @@ async def to_debug(self): rank = current_rank().rank return _debugee_actor_internal(rank) + @endpoint + async def name(self) -> str: + return context().actor_instance.actor_id.actor_name + + @endpoint + async def nested(self) -> "DebugeeActor": + return ( + this_host_v1() + .spawn_procs(per_host={"hosts": 2, "gpus": 2}) + .spawn("debugee_nested", DebugeeActor) + ) + class DebugControllerForTesting(DebugController): def __init__(self): @@ -186,42 +255,43 @@ async def _wait_for_breakpoints( raise RuntimeError("timed out waiting for breakpoints") -# We have to run this test in a separate process because there is only one -# debug controller per process, and we don't want this to interfere with -# the other two tests that access the debug controller. -@isolate_in_subprocess(env=debug_env) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Not enough GPUs, this test requires at least 2 GPUs", -) -@pytest.mark.timeout(60) -async def test_debug() -> None: +async def _test_debug(api: ApiVersion, nested: bool) -> None: + if not nested: + proc = proc_mesh(api, hosts=2, gpus=2) + debugee = proc.spawn("debugee", DebugeeActor) + else: + proc = create_local_host_mesh( + "host", extent=Extent(["hosts"], [1]) + ).spawn_procs() + debugee = proc.spawn("debugee", DebugeeActor).nested.choose().get() + name = debugee.name.choose().get() + input_mock = AsyncMock() input_mock.side_effect = [ - "attach debugee 1", + f"attach {name} 1", "n", "n", "n", "n", "detach", - "attach debugee 1", + f"attach {name} 1", "detach", "quit", - "cast debugee ranks(0,3) n", - "cast debugee ranks(0,3) n", + f"cast {name} ranks(0,3) n", + f"cast {name} ranks(0,3) n", # Attaching to 0 and 3 ensures that when we call "list" # the next time, their function/lineno info will be # up-to-date. - "attach debugee 0", + f"attach {name} 0", "detach", - "attach debugee 3", + f"attach {name} 3", "detach", "quit", - "attach debugee 2", + f"attach {name} 2", "c", "detach", "quit", - "attach debugee 2", + f"attach {name} 2", "bt", "c", "quit", @@ -241,10 +311,8 @@ def _patch_output(msg): with patch( "monarch._src.actor.debugger.debug_io.DebugStdIO.input", new=input_mock ), patch("monarch._src.actor.debugger.debug_io.DebugStdIO.output", new=output_mock): - proc = proc_mesh(hosts=2, gpus=2) - debugee = proc.spawn("debugee", DebugeeActor) - debug_controller = await actor.get_or_spawn_controller( - "debug_controller", DebugControllerForTesting + debug_controller = await get_or_spawn_controller( + api, "debug_controller", DebugControllerForTesting ) fut = debugee.to_debug.call() @@ -357,26 +425,60 @@ def _patch_output(msg): await fut -# See earlier comment +# We have to run this test in a separate process because there is only one +# debug controller per process, and we don't want this to interfere with +# the other tests that access the debug controller. +@isolate_in_subprocess(env=debug_env) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_v0(): + await _test_debug(ApiVersion.V0, nested=False) + + +# See earlier comment. +@isolate_in_subprocess(env=debug_env) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_v1(): + await _test_debug(ApiVersion.V1, nested=False) + + +# See earlier comment. @isolate_in_subprocess(env=debug_env) @pytest.mark.skipif( torch.cuda.device_count() < 2, reason="Not enough GPUs, this test requires at least 2 GPUs", ) @pytest.mark.timeout(60) -async def test_debug_multi_actor() -> None: +async def test_debug_v1_nested(): + await _test_debug(ApiVersion.V1, nested=True) + + +async def _test_debug_multi_actor(api: ApiVersion) -> None: + proc = proc_mesh(api, hosts=2, gpus=2) + debugee_1 = proc.spawn("debugee_1", DebugeeActor) + debugee_2 = proc.spawn("debugee_2", DebugeeActor) + name_1 = debugee_1.name.choose().get() + name_2 = debugee_2.name.choose().get() + input_mock = AsyncMock() input_mock.side_effect = [ - "attach debugee_2 2", + f"attach {name_2} 2", "n", "detach", - "attach debugee_1 1", + f"attach {name_1} 1", "n", "detach", "quit", - "cast debugee_1 ranks(:) c", - "cast debugee_2 ranks(:) c", - "attach debugee_2 2", + f"cast {name_1} ranks(:) c", + f"cast {name_2} ranks(:) c", + f"attach {name_2} 2", "c", "quit", "continue", @@ -386,11 +488,8 @@ async def test_debug_multi_actor() -> None: with patch( "monarch._src.actor.debugger.debug_io.DebugStdIO.input", side_effect=input_mock ): - proc = proc_mesh(hosts=2, gpus=2) - debugee_1 = proc.spawn("debugee_1", DebugeeActor) - debugee_2 = proc.spawn("debugee_2", DebugeeActor) - debug_controller = await actor.get_or_spawn_controller( - "debug_controller", DebugControllerForTesting + debug_controller = await get_or_spawn_controller( + api, "debug_controller", DebugControllerForTesting ) fut_1 = debugee_1.to_debug.call() @@ -404,7 +503,7 @@ async def test_debug_multi_actor() -> None: info = breakpoints[i] initial_linenos[info.rank] = info.lineno assert info.rank == i % 4 - assert info.actor_name == "debugee_1" if i < 4 else "debugee_2" + assert info.actor_name == name_1 if i < 4 else name_2 assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2} assert info.function == "test_debugger._debugee_actor_internal" assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank @@ -414,17 +513,15 @@ async def test_debug_multi_actor() -> None: breakpoints = await _wait_for_breakpoints(debug_controller, 8) for i in range(len(breakpoints)): if i == 1: - assert breakpoints[i].actor_name == "debugee_1" + assert breakpoints[i].actor_name == name_1 assert breakpoints[i].rank == 1 assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 elif i == 6: - assert breakpoints[i].actor_name == "debugee_2" + assert breakpoints[i].actor_name == name_2 assert breakpoints[i].rank == 2 assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 else: - assert ( - breakpoints[i].actor_name == "debugee_1" if i < 4 else "debugee_2" - ) + assert breakpoints[i].actor_name == name_1 if i < 4 else name_2 assert breakpoints[i].rank == i % 4 assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] @@ -433,7 +530,7 @@ async def test_debug_multi_actor() -> None: breakpoints = await _wait_for_breakpoints(debug_controller, 1) with pytest.raises(ActorError, match="ValueError: bad rank"): await fut_2 - assert breakpoints[0].actor_name == "debugee_1" + assert breakpoints[0].actor_name == name_1 assert breakpoints[0].rank == 2 assert breakpoints[0].function == "test_debugger._bad_rank" @@ -444,6 +541,28 @@ async def test_debug_multi_actor() -> None: await fut_1 +# See earlier comment +@isolate_in_subprocess(env=debug_env) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_multi_actor_v0(): + await _test_debug_multi_actor(ApiVersion.V0) + + +# See earlier comment +@isolate_in_subprocess(env=debug_env) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_multi_actor_v1(): + await _test_debug_multi_actor(ApiVersion.V1) + + async def test_debug_sessions_insert_get_remove() -> None: mock_sessions = [] for actor_name in ("actor_a", "actor_b"): @@ -778,18 +897,12 @@ async def test_debug_command_parser_invalid_inputs(invalid_input): assert await DebugCommand.parse(DebugStdIO(), invalid_input) is None -# See earlier comment -@isolate_in_subprocess(env={"MONARCH_CLI_BIN": cli_bin, **debug_env}) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Not enough GPUs, this test requires at least 2 GPUs", -) -@pytest.mark.timeout(60) -async def test_debug_cli(): - proc = proc_mesh(hosts=2, gpus=2) +async def _test_debug_cli(api: ApiVersion): + proc = proc_mesh(api, hosts=2, gpus=2) debugee = proc.spawn("debugee", DebugeeActor) - debug_controller = actor.get_or_spawn_controller( - "debug_controller", DebugControllerForTesting + name = debugee.name.choose().get() + debug_controller = get_or_spawn_controller( + api, "debug_controller", DebugControllerForTesting ).get() fut = debugee.to_debug.call() @@ -860,13 +973,13 @@ async def create_debug_cli_proc() -> ( debug_cli_stdin.writelines( [ - b"attach debugee 1\n", + f"attach {name} 1\n".encode(), b"n\n", b"n\n", b"n\n", b"n\n", b"detach\n", - b"attach debugee 1\n", + f"attach {name} 1\n".encode(), b"print('test separator')\n", b"detach\n", ] @@ -911,14 +1024,14 @@ async def create_debug_cli_proc() -> ( debug_cli_stdin.writelines( [ - b"cast debugee ranks(0,3) n\n", - b"cast debugee ranks(0,3) n\n", + f"cast {name} ranks(0,3) n\n".encode(), + f"cast {name} ranks(0,3) n\n".encode(), # Attaching to 0 and 3 ensures that when we call "list" # the next time, their function/lineno info will be # up-to-date. - b"attach debugee 0\n", + f"attach {name} 0\n".encode(), b"detach\n", - b"attach debugee 3\n", + f"attach {name} 3\n".encode(), b"detach\n", ] ) @@ -926,7 +1039,9 @@ async def create_debug_cli_proc() -> ( # Make sure we have run all the commands before killing the CLI, otherwise # the commands may not actually be sent to the debug controller. - await debug_cli_stdout.readuntil(b"Detached from debug session for debugee 3") + await debug_cli_stdout.readuntil( + f"Detached from debug session for {name} 3".encode() + ) if debug_cli_proc: # Even if we kill the proc using a signal, we should be able to reconnect # without issue. @@ -953,7 +1068,7 @@ async def create_debug_cli_proc() -> ( debug_cli_stdout, ) = await create_debug_cli_proc() - debug_cli_stdin.writelines([b"attach debugee 2\n", b"c\n"]) + debug_cli_stdin.writelines([f"attach {name} 2\n".encode(), b"c\n"]) await debug_cli_stdin.drain() # Make sure we have run all the commands before killing the CLI, otherwise @@ -980,7 +1095,7 @@ async def create_debug_cli_proc() -> ( debug_cli_stdout, ) = await create_debug_cli_proc() - debug_cli_stdin.writelines([b"attach debugee 2\n", b"bt\n", b"c\n"]) + debug_cli_stdin.writelines([f"attach {name} 2\n".encode(), b"bt\n", b"c\n"]) await debug_cli_stdin.drain() expected_output = ( @@ -990,7 +1105,9 @@ async def create_debug_cli_proc() -> ( ) output = ( - await debug_cli_stdout.readuntil(b"Detached from debug session for debugee 2") + await debug_cli_stdout.readuntil( + f"Detached from debug session for {name} 2".encode() + ) ).decode() assert len(re.findall(expected_output, output)) == 1 @@ -1029,6 +1146,28 @@ async def create_debug_cli_proc() -> ( await fut +# See earlier comment +@isolate_in_subprocess(env={"MONARCH_CLI_BIN": cli_bin, **debug_env}) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_cli_v0(): + await _test_debug_cli(ApiVersion.V0) + + +# See earlier comment +@isolate_in_subprocess(env={"MONARCH_CLI_BIN": cli_bin, **debug_env}) +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +@pytest.mark.timeout(60) +async def test_debug_cli_v1(): + await _test_debug_cli(ApiVersion.V1) + + class_closure_source = """class ClassClosure: def __init__(self, arg): self.arg = arg @@ -1088,12 +1227,12 @@ def debug_class_closure(self, class_closure) -> int: def debug_func(self, func, class_closure) -> int: return func(class_closure) + @endpoint + async def name(self) -> str: + return context().actor_instance.actor_id.actor_name -# We have to run this test in a subprocess because it requires a special -# instantiation of the debug controller singleton. -@isolate_in_subprocess(env=debug_env) -@pytest.mark.timeout(60) -async def test_debug_with_pickle_by_value(): + +async def _test_debug_with_pickle_by_value(api: ApiVersion): """ This test tests debugger functionality when there are breakpoints in code that has been pickled by value (as opposed to pickling by reference, @@ -1118,22 +1257,25 @@ async def test_debug_with_pickle_by_value(): The test unpickles these and sends them to an actor endpoint, in which breakpoints will be hit and we can test the special pdb handling logic. """ + pm = proc_mesh(api, gpus=1, hosts=1) + debugee = pm.spawn("debugee", ClosureDebugeeActor) + name = debugee.name.choose().get() input_mock = AsyncMock() input_mock.side_effect = [ - "attach debugee 0", + f"attach {name} 0", "c", "quit", - "attach debugee 0", + f"attach {name} 0", "bt", "c", "quit", - "attach debugee 0", + f"attach {name} 0", "b /tmp/monarch_test/class_closure:10", "c", "detach", "quit", - "attach debugee 0", + f"attach {name} 0", "c", "detach", "quit", @@ -1153,21 +1295,17 @@ def _patch_output(msg): with patch( "monarch._src.actor.debugger.debug_io.DebugStdIO.input", new=input_mock ), patch("monarch._src.actor.debugger.debug_io.DebugStdIO.output", new=output_mock): - pm = proc_mesh(gpus=1, hosts=1) - - debug_controller = actor.get_or_spawn_controller( - "debug_controller", DebugControllerForTesting + debug_controller = get_or_spawn_controller( + api, "debug_controller", DebugControllerForTesting ).get() # Spawn a special source loader that knows how to retrieve the source code # for /tmp/monarch_test/class_closure.py and # /tmp/monarch_test/function_closure.py - actor.get_or_spawn_controller( - "source_loader", SourceLoaderControllerWithMockedSource + get_or_spawn_controller( + api, "source_loader", SourceLoaderControllerWithMockedSource ).get() - debugee = pm.spawn("debugee", ClosureDebugeeActor) - class_closure = load_class_closure() func_bp_true, func_bp_false = load_func_closure() @@ -1240,3 +1378,19 @@ def _patch_output(msg): await fut await pm.stop() + + +# We have to run this test in a subprocess because it requires a special +# instantiation of the debug controller singleton. +@isolate_in_subprocess(env=debug_env) +@pytest.mark.timeout(60) +async def test_debug_with_pickle_by_value_v0(): + await _test_debug_with_pickle_by_value(ApiVersion.V0) + + +# We have to run this test in a subprocess because it requires a special +# instantiation of the debug controller singleton. +@isolate_in_subprocess(env=debug_env) +@pytest.mark.timeout(60) +async def test_debug_with_pickle_by_value_v1(): + await _test_debug_with_pickle_by_value(ApiVersion.V1) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 1aa33cff1..13b1614aa 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -63,7 +63,10 @@ this_host as this_host_v1, this_proc as this_proc_v1, ) -from monarch._src.actor.v1.proc_mesh import ProcMesh as ProcMeshV1 +from monarch._src.actor.v1.proc_mesh import ( + get_or_spawn_controller as get_or_spawn_controller_v1, + ProcMesh as ProcMeshV1, +) from monarch.actor import ( Accumulator, @@ -1697,15 +1700,18 @@ def return_root(self): return self._root @endpoint - async def spawning_from_endpoint(self, name, root) -> None: - await get_or_spawn_controller(name, SpawningActorFromEndpointActor, root=root) + async def spawning_from_endpoint(self, name, root, get_or_spawn) -> None: + await get_or_spawn(name, SpawningActorFromEndpointActor, root=root) +@pytest.mark.parametrize( + "get_or_spawn", [get_or_spawn_controller, get_or_spawn_controller_v1] +) @pytest.mark.timeout(60) -def test_get_or_spawn_controller_inside_actor_endpoint(): - actor_1 = get_or_spawn_controller("actor_1", SpawningActorFromEndpointActor).get() - actor_1.spawning_from_endpoint.call_one("actor_2", root="actor_1").get() - actor_2 = get_or_spawn_controller("actor_2", SpawningActorFromEndpointActor).get() +def test_get_or_spawn_controller_inside_actor_endpoint(get_or_spawn): + actor_1 = get_or_spawn("actor_1", SpawningActorFromEndpointActor).get() + actor_1.spawning_from_endpoint.call_one("actor_2", "actor_1", get_or_spawn).get() + actor_2 = get_or_spawn("actor_2", SpawningActorFromEndpointActor).get() # verify that actor_2 was spawned from actor_1 with the correct root assert actor_2.return_root.call_one().get() == "actor_1"