diff --git a/monarch_extension/src/lib.rs b/monarch_extension/src/lib.rs index b6f0e5c3d..f1382de98 100644 --- a/monarch_extension/src/lib.rs +++ b/monarch_extension/src/lib.rs @@ -27,6 +27,8 @@ mod tensor_worker; mod blocking; mod panic; + +use monarch_types::py_global; use pyo3::prelude::*; #[pyfunction] @@ -34,6 +36,12 @@ fn has_tensor_engine() -> bool { cfg!(feature = "tensor_engine") } +py_global!( + add_extension_methods, + "monarch._src.actor.python_extension_methods", + "add_extension_methods" +); + fn get_or_add_new_module<'py>( module: &Bound<'py, PyModule>, module_name: &str, @@ -46,22 +54,29 @@ fn get_or_add_new_module<'py>( if let Some(submodule) = submodule { current_module = submodule.extract()?; } else { - let new_module = PyModule::new(current_module.py(), part)?; - current_module.add_submodule(&new_module)?; + let name = format!("monarch._rust_bindings.{}", parts.join(".")); + let new_module = PyModule::new(current_module.py(), &name)?; + current_module.add(part, new_module.clone())?; current_module .py() .import("sys")? .getattr("modules")? - .set_item( - format!("monarch._rust_bindings.{}", parts.join(".")), - new_module.clone(), - )?; + .set_item(name, new_module.clone())?; current_module = new_module; } } Ok(current_module) } +fn register(module: &Bound<'_, PyModule>, module_path: &str, register_fn: F) -> PyResult<()> +where + F: FnOnce(&Bound<'_, PyModule>) -> PyResult<()>, +{ + let submodule = get_or_add_new_module(module, module_path)?; + register_fn(&submodule)?; + Ok(()) +} + #[pymodule] #[pyo3(name = "_rust_bindings")] pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> { @@ -71,153 +86,190 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> { runtime.handle().clone(), Some(::hyperactor_mesh::bootstrap::BOOTSTRAP_INDEX_ENV.to_string()), ); - - monarch_hyperactor::shape::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.shape", - )?)?; - - monarch_hyperactor::selection::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::shape::register_python_bindings, + )?; + register( module, "monarch_hyperactor.selection", - )?)?; - - monarch_hyperactor::supervision::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::selection::register_python_bindings, + )?; + register( module, "monarch_hyperactor.supervision", - )?)?; + monarch_hyperactor::supervision::register_python_bindings, + )?; #[cfg(feature = "tensor_engine")] { - client::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_extension.client", - )?)?; - tensor_worker::register_python_bindings(&get_or_add_new_module( + client::register_python_bindings, + )?; + register( module, "monarch_extension.tensor_worker", - )?)?; - controller::register_python_bindings(&get_or_add_new_module( + tensor_worker::register_python_bindings, + )?; + register( module, "monarch_extension.controller", - )?)?; - debugger::register_python_bindings(&get_or_add_new_module( + controller::register_python_bindings, + )?; + register( module, "monarch_extension.debugger", - )?)?; - monarch_messages::debugger::register_python_bindings(&get_or_add_new_module( + debugger::register_python_bindings, + )?; + register( module, "monarch_messages.debugger", - )?)?; - simulator_client::register_python_bindings(&get_or_add_new_module( + monarch_messages::debugger::register_python_bindings, + )?; + register( module, "monarch_extension.simulator_client", - )?)?; - ::controller::bootstrap::register_python_bindings(&get_or_add_new_module( + simulator_client::register_python_bindings, + )?; + register( module, "controller.bootstrap", - )?)?; - ::monarch_tensor_worker::bootstrap::register_python_bindings(&get_or_add_new_module( + ::controller::bootstrap::register_python_bindings, + )?; + register( module, "monarch_tensor_worker.bootstrap", - )?)?; - crate::convert::register_python_bindings(&get_or_add_new_module( + ::monarch_tensor_worker::bootstrap::register_python_bindings, + )?; + register( module, "monarch_extension.convert", - )?)?; - crate::mesh_controller::register_python_bindings(&get_or_add_new_module( + crate::convert::register_python_bindings, + )?; + register( module, "monarch_extension.mesh_controller", - )?)?; - monarch_rdma_extension::register_python_bindings(&get_or_add_new_module(module, "rdma")?)?; + crate::mesh_controller::register_python_bindings, + )?; + register( + module, + "rdma", + monarch_rdma_extension::register_python_bindings, + )?; } - simulation_tools::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_extension.simulation_tools", - )?)?; - monarch_hyperactor::bootstrap::register_python_bindings(&get_or_add_new_module( + simulation_tools::register_python_bindings, + )?; + register( module, "monarch_hyperactor.bootstrap", - )?)?; + monarch_hyperactor::bootstrap::register_python_bindings, + )?; - monarch_hyperactor::proc::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.proc", - )?)?; + monarch_hyperactor::proc::register_python_bindings, + )?; - monarch_hyperactor::actor::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.actor", - )?)?; + monarch_hyperactor::actor::register_python_bindings, + )?; - monarch_hyperactor::pytokio::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.pytokio", - )?)?; - - monarch_hyperactor::mailbox::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::pytokio::register_python_bindings, + )?; + register( module, "monarch_hyperactor.mailbox", - )?)?; + monarch_hyperactor::mailbox::register_python_bindings, + )?; - monarch_hyperactor::alloc::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.alloc", - )?)?; - monarch_hyperactor::channel::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::alloc::register_python_bindings, + )?; + register( module, "monarch_hyperactor.channel", - )?)?; - monarch_hyperactor::actor_mesh::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::channel::register_python_bindings, + )?; + register( module, "monarch_hyperactor.actor_mesh", - )?)?; - monarch_hyperactor::proc_mesh::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::actor_mesh::register_python_bindings, + )?; + register( module, "monarch_hyperactor.proc_mesh", - )?)?; + monarch_hyperactor::proc_mesh::register_python_bindings, + )?; - monarch_hyperactor::runtime::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.runtime", - )?)?; - monarch_hyperactor::telemetry::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::runtime::register_python_bindings, + )?; + register( module, "monarch_hyperactor.telemetry", - )?)?; - code_sync::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::telemetry::register_python_bindings, + )?; + register( module, "monarch_extension.code_sync", - )?)?; + code_sync::register_python_bindings, + )?; - crate::panic::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_extension.panic", - )?)?; + crate::panic::register_python_bindings, + )?; - crate::blocking::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_extension.blocking", - )?)?; + crate::blocking::register_python_bindings, + )?; - crate::logging::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_extension.logging", - )?)?; + crate::logging::register_python_bindings, + )?; #[cfg(fbcode_build)] { - monarch_hyperactor::meta::alloc::register_python_bindings(&get_or_add_new_module( + register( module, "monarch_hyperactor.meta.alloc", - )?)?; - monarch_hyperactor::meta::alloc_mock::register_python_bindings(&get_or_add_new_module( + monarch_hyperactor::meta::alloc::register_python_bindings, + )?; + register( module, "monarch_hyperactor.meta.alloc_mock", - )?)?; + monarch_hyperactor::meta::alloc_mock::register_python_bindings, + )?; } // Add feature detection function module.add_function(wrap_pyfunction!(has_tensor_engine, module)?)?; + // this should be called last. otherwise cross references in pyi files will not have been + // added to sys.modules yet. + let maybe_module = module.py().import("monarch._src"); + if maybe_module.is_ok() { + add_extension_methods(module.py()).call1((module,))?; + } Ok(()) } diff --git a/monarch_types/src/lib.rs b/monarch_types/src/lib.rs index a291d7c7b..8678d02d3 100644 --- a/monarch_types/src/lib.rs +++ b/monarch_types/src/lib.rs @@ -16,3 +16,22 @@ pub use pyobject::PickledPyObject; pub use python::SerializablePyErr; pub use python::TryIntoPyObjectUnsafe; pub use pytree::PyTree; + +/// Macro to generate a Python object lookup function with caching +/// +/// # Arguments +/// * `$fn_name` - Name of the Rust function to generate +/// * `$python_path` - Path to the Python object as a string (e.g., "module.submodule.function") +#[macro_export] +macro_rules! py_global { + ($fn_name:ident, $python_module:literal, $python_class:literal) => { + fn $fn_name<'py>(py: ::pyo3::Python<'py>) -> ::pyo3::Bound<'py, ::pyo3::PyAny> { + static CACHE: ::pyo3::sync::GILOnceCell<::pyo3::PyObject> = + ::pyo3::sync::GILOnceCell::new(); + CACHE + .import(py, $python_module, $python_class) + .unwrap() + .clone() + } + }; +} diff --git a/python/monarch/_rust_bindings/controller.pyi b/python/monarch/_rust_bindings/controller.pyi new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/python/monarch/_rust_bindings/controller.pyi @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/python/monarch/_rust_bindings/controller/__init__.pyi b/python/monarch/_rust_bindings/controller/__init__.pyi new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/python/monarch/_rust_bindings/controller/__init__.pyi @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/python/monarch/_rust_bindings/controller/bootstrap.pyi b/python/monarch/_rust_bindings/controller/bootstrap.pyi index 4b3130fb7..b07c656c2 100644 --- a/python/monarch/_rust_bindings/controller/bootstrap.pyi +++ b/python/monarch/_rust_bindings/controller/bootstrap.pyi @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import final, List, Optional, Tuple +from typing import final, List, Optional, Tuple, TYPE_CHECKING @final class ControllerCommand: @@ -197,32 +197,33 @@ class ControllerServerRequest: Python binding for the Rust ControllerServerRequest enum. """ - @final - class Run(ControllerServerRequest): - """ - Create a Run request variant. + if TYPE_CHECKING: + @final + class Run(ControllerServerRequest): + """ + Create a Run request variant. - Args: - command: The RunCommand to execute + Args: + command: The RunCommand to execute - Returns: - A ControllerServerRequest.Run instance - """ - def __init__( - self, - command: RunCommand, - ) -> None: ... + Returns: + A ControllerServerRequest.Run instance + """ + def __init__( + self, + command: RunCommand, + ) -> None: ... - @final - class Exit(ControllerServerRequest): - """ - Create an Exit request variant. + @final + class Exit(ControllerServerRequest): + """ + Create an Exit request variant. - Returns: - A ControllerServerRequest.Exit instance - """ + Returns: + A ControllerServerRequest.Exit instance + """ - pass + pass def to_json(self) -> str: """ @@ -241,19 +242,20 @@ class ControllerServerResponse: Python binding for the Rust ControllerServerResponse enum. """ - @final - class Finished(ControllerServerResponse): - """ - Create a Finished response variant. + if TYPE_CHECKING: + @final + class Finished(ControllerServerResponse): + """ + Create a Finished response variant. - Args: - error: An optional error message if the operation failed + Args: + error: An optional error message if the operation failed - Returns: - A ControllerServerResponse.Finished instance - """ + Returns: + A ControllerServerResponse.Finished instance + """ - error: Optional[str] + error: Optional[str] @classmethod def from_json(cls, json: str) -> "ControllerServerResponse": diff --git a/python/monarch/_rust_bindings/monarch_extension/client.pyi b/python/monarch/_rust_bindings/monarch_extension/client.pyi index 6c69af037..98121db91 100644 --- a/python/monarch/_rust_bindings/monarch_extension/client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/client.pyi @@ -4,12 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, ClassVar, Dict, final, List, NamedTuple, Union +from typing import Any, ClassVar, Dict, final, List, NamedTuple, TYPE_CHECKING, Union from monarch._rust_bindings.monarch_extension.tensor_worker import Ref from monarch._rust_bindings.monarch_hyperactor.proc import ActorId, Proc, Serialized from monarch._rust_bindings.monarch_hyperactor.shape import Slice as NDSlice -from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType + +if TYPE_CHECKING: + from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType class Exception: """ @@ -113,9 +115,9 @@ class WorkerResponse: @final class LogLevel: - INFO: ClassVar[LogLevel] - WARNING: ClassVar[LogLevel] - ERROR: ClassVar[LogLevel] + INFO: ClassVar["LogLevel"] + WARNING: ClassVar["LogLevel"] + ERROR: ClassVar["LogLevel"] @final class LogMessage: @@ -189,7 +191,7 @@ class ClientActor: def __init__(self, proc: Proc, actor_name: str) -> None: ... @staticmethod - def new_with_parent(proc: Proc, parent_id: ActorId) -> ClientActor: + def new_with_parent(proc: Proc, parent_id: ActorId) -> "ClientActor": """ Create a new client actor with the given parent id. This is used to create a client actor that is a child of another client actor. @@ -235,7 +237,7 @@ class ClientActor: def get_next_message( self, *, timeout_msec: int | None = None - ) -> LogMessage | WorkerResponse | DebuggerMessage | None: + ) -> "LogMessage | WorkerResponse | DebuggerMessage | None": """Get the next message sent to the actor. Arguments: @@ -248,7 +250,7 @@ class ClientActor: """Stop the system.""" ... - def drain_and_stop(self) -> List[LogMessage | WorkerResponse | DebuggerMessage]: + def drain_and_stop(self) -> "List[LogMessage | WorkerResponse | DebuggerMessage]": """Stop the actor and drain all messages.""" ... @@ -276,7 +278,7 @@ class DebuggerMessage: """ def __init__( - self, *, debugger_actor_id: ActorId, action: DebuggerActionType + self, *, debugger_actor_id: ActorId, action: "DebuggerActionType" ) -> None: ... @property def debugger_actor_id(self) -> ActorId: @@ -284,6 +286,6 @@ class DebuggerMessage: ... @property - def action(self) -> DebuggerActionType: + def action(self) -> "DebuggerActionType": """Get the debugger action.""" ... diff --git a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi index 66413522c..5099bec0c 100644 --- a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path -from typing import final +from typing import final, TYPE_CHECKING from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh @@ -15,13 +15,15 @@ class WorkspaceLocation: """ Python binding for the Rust WorkspaceLocation enum. """ - @final - class Constant(WorkspaceLocation): - def __init__(self, path) -> None: ... - @final - class FromEnvVar(WorkspaceLocation): - def __init__(self, var) -> None: ... + if TYPE_CHECKING: + @final + class Constant(WorkspaceLocation): + def __init__(self, path) -> None: ... + + @final + class FromEnvVar(WorkspaceLocation): + def __init__(self, var) -> None: ... def resolve(self) -> Path: """ @@ -44,8 +46,8 @@ class CodeSyncMethod: Python binding for the Rust CodeSyncMethod enum. """ - Rsync: CodeSyncMethod - CondaSync: CodeSyncMethod + Rsync: "CodeSyncMethod" + CondaSync: "CodeSyncMethod" @final class RemoteWorkspace: @@ -75,7 +77,7 @@ class CodeSyncMeshClient: @staticmethod def spawn_blocking( proc_mesh: ProcMesh, - ) -> CodeSyncMeshClient: ... + ) -> "CodeSyncMeshClient": ... async def sync_workspace( self, *, diff --git a/python/monarch/_rust_bindings/monarch_extension/controller.pyi b/python/monarch/_rust_bindings/monarch_extension/controller.pyi index d13a37af6..11c9fc583 100644 --- a/python/monarch/_rust_bindings/monarch_extension/controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/controller.pyi @@ -55,7 +55,7 @@ class Node: ... @staticmethod - def from_serialized(serialized: Serialized) -> Node: + def from_serialized(serialized: Serialized) -> "Node": """Deserialize the message from a Serialized object.""" ... @@ -87,6 +87,6 @@ class Send: ... @staticmethod - def from_serialized(serialized: Serialized) -> Send: + def from_serialized(serialized: Serialized) -> "Send": """Deserialize the message from a Serialized object.""" ... diff --git a/python/monarch/_rust_bindings/monarch_extension/debugger.pyi b/python/monarch/_rust_bindings/monarch_extension/debugger.pyi index c0aafa367..f44c8d83a 100644 --- a/python/monarch/_rust_bindings/monarch_extension/debugger.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/debugger.pyi @@ -4,19 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import final, Optional, Union +from typing import final, Optional, TYPE_CHECKING, Union from monarch._rust_bindings.monarch_hyperactor.proc import Serialized -from monarch._rust_bindings.monarch_messages.debugger import ( - DebuggerAction, - DebuggerActionType, -) +from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction + +if TYPE_CHECKING: + from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType @final class DebuggerMessage: """A message for debugger communication between worker and client.""" - def __init__(self, action: DebuggerActionType) -> None: + def __init__(self, action: "DebuggerActionType") -> None: """ Create a new DebuggerMessage. @@ -26,7 +26,7 @@ class DebuggerMessage: ... @property - def action(self) -> DebuggerActionType: + def action(self) -> "DebuggerActionType": """Get the debugger action contained in this message.""" ... @@ -47,7 +47,7 @@ class PdbActor: """Create a new PdbActor.""" ... - def send(self, action: DebuggerActionType) -> None: + def send(self, action: "DebuggerActionType") -> None: """ Send a debugger action to the worker. @@ -56,7 +56,7 @@ class PdbActor: """ ... - def receive(self) -> Optional[DebuggerActionType]: + def receive(self) -> Optional["DebuggerActionType"]: """ Receive a debugger action from the worker. diff --git a/python/monarch/_rust_bindings/monarch_extension/logging.pyi b/python/monarch/_rust_bindings/monarch_extension/logging.pyi index 5d6f11960..96df742b9 100644 --- a/python/monarch/_rust_bindings/monarch_extension/logging.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/logging.pyi @@ -17,7 +17,7 @@ class LoggingMeshClient: Python binding for the Rust LoggingMeshClient. """ @staticmethod - def spawn(proc_mesh: ProcMesh) -> PythonTask[LoggingMeshClient]: ... + def spawn(proc_mesh: ProcMesh) -> PythonTask["LoggingMeshClient"]: ... def set_mode( self, stream_to_client: bool, aggregate_window_sec: int | None, level: int ) -> None: ... diff --git a/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi b/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi index 9936090bc..aa5bde303 100644 --- a/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, final, Optional, Sequence, Tuple +from typing import Callable, final, Optional, Sequence, TYPE_CHECKING -import torch +if TYPE_CHECKING: + import torch from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.shape import Slice @@ -29,12 +30,12 @@ class Ref: ... def __repr__(self) -> str: ... - def __lt__(self, other: Ref) -> bool: ... - def __le__(self, other: Ref) -> bool: ... - def __eq__(self, value: Ref) -> bool: ... - def __ne__(self, value: Ref) -> bool: ... - def __gt__(self, other: Ref) -> bool: ... - def __ge__(self, other: Ref) -> bool: ... + def __lt__(self, other: "Ref") -> bool: ... + def __le__(self, other: "Ref") -> bool: ... + def __eq__(self, value: "Ref") -> bool: ... + def __ne__(self, value: "Ref") -> bool: ... + def __gt__(self, other: "Ref") -> bool: ... + def __ge__(self, other: "Ref") -> bool: ... def __hash__(self) -> int: ... def __getnewargs_ex__(self) -> tuple[tuple, dict]: ... @@ -80,11 +81,11 @@ class TensorFactory: *, size: Sequence[int], # pyre-ignore - dtype: torch.dtype, + dtype: "torch.dtype", # pyre-ignore - layout: torch.layout, + layout: "torch.layout", # pyre-ignore - device: torch.device, + device: "torch.device", ) -> None: ... @property def size(self) -> tuple[int, ...]: @@ -92,12 +93,12 @@ class TensorFactory: ... @property - def dtype(self) -> torch.dtype: + def dtype(self) -> "torch.dtype": """The data type of the tensor.""" ... @property - def layout(self) -> torch.layout: + def layout(self) -> "torch.layout": """The layout of the tensor.""" ... @@ -149,11 +150,11 @@ class StreamCreationMode: Used to specify what CUDA stream to use for the worker stream creation. """ - UseDefaultStream: StreamCreationMode - CreateNewStream: StreamCreationMode + UseDefaultStream: "StreamCreationMode" + CreateNewStream: "StreamCreationMode" - def __eq__(self, value: StreamCreationMode) -> bool: ... - def __ne__(self, value: StreamCreationMode) -> bool: ... + def __eq__(self, value: "StreamCreationMode") -> bool: ... + def __ne__(self, value: "StreamCreationMode") -> bool: ... def __repr__(self) -> str: ... def __int__(self) -> int: ... @@ -161,15 +162,15 @@ class StreamCreationMode: class ReductionType: """Used to specify the reduction type for the Reduce command.""" - Stack: ReductionType - Sum: ReductionType - Prod: ReductionType - Max: ReductionType - Min: ReductionType - Avg: ReductionType + Stack: "ReductionType" + Sum: "ReductionType" + Prod: "ReductionType" + Max: "ReductionType" + Min: "ReductionType" + Avg: "ReductionType" - def __eq__(self, value: ReductionType) -> bool: ... - def __ne__(self, value: ReductionType) -> bool: ... + def __eq__(self, value: "ReductionType") -> bool: ... + def __ne__(self, value: "ReductionType") -> bool: ... def __repr__(self) -> str: ... def __int__(self) -> int: ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index 1fc4607b2..b6167d177 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -137,10 +137,10 @@ class Exception(PythonMessageKind): class CallMethod(PythonMessageKind): def __init__( - self, name: MethodSpecifier, response_port: PortRef | OncePortRef | None + self, name: "MethodSpecifier", response_port: PortRef | OncePortRef | None ) -> None: ... @property - def name(self) -> MethodSpecifier: ... + def name(self) -> "MethodSpecifier": ... @property def response_port(self) -> PortRef | OncePortRef | None: ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index 7ce339179..9fd0206c0 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -41,7 +41,7 @@ class PythonActorMesh(ActorMeshProtocol): pass class PythonActorMeshImpl: - def get_supervision_event(self) -> ActorSupervisionEvent | None: + def get_supervision_event(self) -> Optional["ActorSupervisionEvent"]: """ Returns supervision event if there is any. """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi index a1ea21473..35207db55 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi @@ -6,12 +6,14 @@ # pyre-strict -from typing import final, Protocol +from typing import final, Optional, Protocol, TYPE_CHECKING -from monarch._rust_bindings.monarch_hyperactor.actor import ( - PythonMessage, - UndeliverableMessageEnvelope, -) +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage + +if TYPE_CHECKING: + from monarch._rust_bindings.monarch_hyperactor.actor import ( + UndeliverableMessageEnvelope, + ) from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask @@ -43,7 +45,7 @@ class PortId: ... @staticmethod - def from_string(port_id_str: str) -> PortId: + def from_string(port_id_str: str) -> "PortId": """ Parse a port id from the provided string. """ @@ -58,7 +60,7 @@ class PortHandle: def send(self, message: PythonMessage) -> None: """Send a message to the port's receiver.""" - def bind(self) -> PortRef: + def bind(self) -> "PortRef": """Bind this port. The returned port ref can be used to reach the port externally.""" ... @@ -68,7 +70,7 @@ class PortRef: A reference to a remote port over which PythonMessages can be sent. """ - def send(self, mailbox: Mailbox, message: PythonMessage) -> None: + def send(self, mailbox: "Mailbox", message: PythonMessage) -> None: """Send a single message to the port's receiver.""" ... @@ -93,10 +95,10 @@ class UndeliverablePortReceiver: """ A receiver to which undeliverable message envelopes are sent. """ - async def recv(self) -> UndeliverableMessageEnvelope: + async def recv(self) -> "UndeliverableMessageEnvelope": """Receive a single undeliverable message from the port's sender.""" ... - def blocking_recv(self) -> UndeliverableMessageEnvelope: + def blocking_recv(self) -> "UndeliverableMessageEnvelope": """Receive a single undeliverable message from the port's sender.""" ... @@ -110,7 +112,7 @@ class OncePortHandle: """Send a single message to the port's receiver.""" ... - def bind(self) -> OncePortRef: + def bind(self) -> "OncePortRef": """Bind this port. The returned port ID can be used to reach the port externally.""" ... @@ -120,7 +122,7 @@ class OncePortRef: A reference to a remote once port over which a single PythonMessages can be sent. """ - def send(self, mailbox: Mailbox, message: PythonMessage) -> None: + def send(self, mailbox: "Mailbox", message: PythonMessage) -> None: """Send a single message to the port's receiver.""" ... @@ -148,7 +150,7 @@ class Mailbox: ... def open_accum_port( - self, accumulator: Accumulator + self, accumulator: "Accumulator" ) -> tuple[PortHandle, PortReceiver]: """Open a accum port.""" ... @@ -202,7 +204,7 @@ class Accumulator(Protocol): Define the initial state of this accumulator. """ @property - def reducer(self) -> Reducer | None: ... + def reducer(self) -> Optional["Reducer"]: ... """ The reducer associated with this accumulator. """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/proc.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/proc.pyi index 0f7c5abe3..217c8ddad 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/proc.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/proc.pyi @@ -6,9 +6,12 @@ # pyre-strict -from typing import final, Optional, Type +from typing import final, Optional, Type, TYPE_CHECKING -from monarch._rust_bindings.monarch_hyperactor.actor import Actor, PythonActorHandle +from monarch._rust_bindings.monarch_hyperactor.actor import PythonActorHandle + +if TYPE_CHECKING: + from monarch._rust_bindings.monarch_hyperactor.actor import Actor from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox def init_proc( @@ -18,7 +21,7 @@ def init_proc( timeout: int = 5, supervision_update_interval: int = 0, listen_addr: Optional[str] = None, -) -> Proc: +) -> "Proc": """ Helper function to bootstrap a new Proc. @@ -86,7 +89,7 @@ class ActorId: ... @staticmethod - def from_string(actor_id_str: str) -> ActorId: + def from_string(actor_id_str: str) -> "ActorId": """ Create an ActorId from a string representation. @@ -120,7 +123,7 @@ class Proc: """Destroy the Proc.""" ... - async def spawn(self, actor: Type[Actor]) -> PythonActorHandle: + async def spawn(self, actor: Type["Actor"]) -> PythonActorHandle: """ Spawn a new actor. diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi index 435f4ed2f..ddb35e31f 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi @@ -6,9 +6,10 @@ # pyre-strict -from typing import AsyncIterator, final, Literal, overload, Type +from typing import AsyncIterator, final, Literal, overload, Type, TYPE_CHECKING -from monarch._rust_bindings.monarch_hyperactor.actor import Actor +if TYPE_CHECKING: + from monarch._rust_bindings.monarch_hyperactor.actor import Actor from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( PythonActorMesh, PythonActorMeshImpl, @@ -23,7 +24,7 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape @final class ProcMesh: @classmethod - def allocate_nonblocking(self, alloc: Alloc) -> PythonTask[ProcMesh]: + def allocate_nonblocking(self, alloc: Alloc) -> PythonTask["ProcMesh"]: """ Allocate a process mesh according to the provided alloc. Returns when the mesh is fully allocated. @@ -34,7 +35,7 @@ class ProcMesh: ... def spawn_nonblocking( - self, name: str, actor: Type[Actor] + self, name: str, actor: Type["Actor"] ) -> PythonTask[PythonActorMesh]: """ Spawn a new actor on this mesh. @@ -47,9 +48,9 @@ class ProcMesh: @staticmethod def spawn_async( - proc_mesh: Shared[ProcMesh], name: str, actor: Type[Actor], emulated: bool + proc_mesh: Shared["ProcMesh"], name: str, actor: Type["Actor"], emulated: bool ) -> PythonActorMesh: ... - async def monitor(self) -> ProcMeshMonitor: + async def monitor(self) -> "ProcMeshMonitor": """ Returns a supervision monitor for this mesh. """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi index 298808f07..79960ef47 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi @@ -34,7 +34,7 @@ class PythonTask(Generic[T], Awaitable[T]): """ ... - def spawn(self) -> Shared[T]: + def spawn(self) -> "Shared[T]": """ Schedule this task to run on concurrently on the tokio runtime. Returns a handle that can be awaited on multiple times so the @@ -43,7 +43,7 @@ class PythonTask(Generic[T], Awaitable[T]): ... @staticmethod - def from_coroutine(coro: Coroutine[Any, Any, T]) -> PythonTask[T]: + def from_coroutine(coro: Coroutine[Any, Any, T]) -> "PythonTask[T]": """ Create a PythonTask from a python coroutine. The coroutine should only await on other PythonTasks created using the pytokio APIs. @@ -53,7 +53,7 @@ class PythonTask(Generic[T], Awaitable[T]): ... @staticmethod - def spawn_blocking(fn: Callable[[], T]) -> Shared[T]: + def spawn_blocking(fn: Callable[[], T]) -> "Shared[T]": """ Concurrently run a python function in a way where it is acceptable for it to make synchronous calls back into Tokio. See tokio::spawn_blocking for more information. @@ -67,7 +67,7 @@ class PythonTask(Generic[T], Awaitable[T]): """ ... - def with_timeout(self, seconds: float) -> PythonTask[T]: + def with_timeout(self, seconds: float) -> "PythonTask[T]": """ Perform the task but throw a TimeoutException if not finished in 'seconds' seconds. """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi index 679bc7597..50ff96120 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi @@ -18,7 +18,7 @@ class Selection: """ def __repr__(self) -> str: ... @classmethod - def from_string(cls, s: str) -> Selection: + def from_string(cls, s: str) -> "Selection": """Parse a selection expression from a string. Accepts a compact string syntax such as `"(*, 0:4)"` or `"0 & (1 | 2)"`, @@ -30,7 +30,7 @@ class Selection: ... @classmethod - def any(cls) -> Selection: + def any(cls) -> "Selection": """Selects one element nondeterministically — use this to mean "route to a single random node". @@ -41,7 +41,7 @@ class Selection: ... @classmethod - def all(cls) -> Selection: + def all(cls) -> "Selection": """Selects all elements in the mesh — use this to mean "route to all nodes". diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi index e979f1e29..cb3588ea5 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi @@ -77,7 +77,7 @@ class Slice: @overload def __getitem__(self, i: int) -> int: ... @overload - def __getitem__(self, i: slice[Any, Any, Any]) -> tuple[int, ...]: ... + def __getitem__(self, i: "slice[Any, Any, Any]") -> tuple[int, ...]: ... def __len__(self) -> int: """Returns the complete size of the slice.""" ... @@ -130,7 +130,7 @@ class Shape: """ ... - def select(self, label: str, slice: slice[Any, Any, Any]) -> "Shape": + def select(self, label: str, slice: "slice[Any, Any, Any]") -> "Shape": """ Restrict this shape along a named dimension using a slice. The dimension is kept but its size may change. diff --git a/python/monarch/_rust_bindings/monarch_tensor_worker/bootstrap.pyi b/python/monarch/_rust_bindings/monarch_tensor_worker/bootstrap.pyi index a1cf235ed..cfc5ac737 100644 --- a/python/monarch/_rust_bindings/monarch_tensor_worker/bootstrap.pyi +++ b/python/monarch/_rust_bindings/monarch_tensor_worker/bootstrap.pyi @@ -4,45 +4,46 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import final, Optional, Tuple +from typing import final, Optional, Tuple, TYPE_CHECKING class WorkerServerRequest: """ Python binding for the Rust WorkerServerRequest enum. """ - @final - class Run(WorkerServerRequest): - """ - Create a Run request variant. - - Args: - world_id: The ID of the world - proc_id: The ID of the process - bootstrap_addr: The bootstrap address - - Returns: - A WorkerServerRequest.Run instance - """ - def __init__( - self, - *, - world_id: str, - proc_id: str, - bootstrap_addr: str, - labels: list[Tuple[str, str]], - ) -> None: ... - - @final - class Exit(WorkerServerRequest): - """ - Create an Exit request variant. - - Returns: - A WorkerServerRequest.Exit instance - """ - - pass + if TYPE_CHECKING: + @final + class Run(WorkerServerRequest): + """ + Create a Run request variant. + + Args: + world_id: The ID of the world + proc_id: The ID of the process + bootstrap_addr: The bootstrap address + + Returns: + A WorkerServerRequest.Run instance + """ + def __init__( + self, + *, + world_id: str, + proc_id: str, + bootstrap_addr: str, + labels: list[Tuple[str, str]], + ) -> None: ... + + @final + class Exit(WorkerServerRequest): + """ + Create an Exit request variant. + + Returns: + A WorkerServerRequest.Exit instance + """ + + pass def to_json(self) -> str: """ @@ -61,19 +62,20 @@ class WorkerServerResponse: Python binding for the Rust WorkerServerResponse enum. """ - @final - class Finished(WorkerServerResponse): - """ - Create a Finished response variant. + if TYPE_CHECKING: + @final + class Finished(WorkerServerResponse): + """ + Create a Finished response variant. - Args: - error: An optional error message if the operation failed + Args: + error: An optional error message if the operation failed - Returns: - A WorkerServerResponse.Finished instance - """ + Returns: + A WorkerServerResponse.Finished instance + """ - error: Optional[str] + error: Optional[str] @classmethod def from_json(cls, json: str) -> "WorkerServerResponse": diff --git a/python/monarch/_src/actor/python_extension_methods.py b/python/monarch/_src/actor/python_extension_methods.py new file mode 100644 index 000000000..40812e55b --- /dev/null +++ b/python/monarch/_src/actor/python_extension_methods.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import types + +from pathlib import Path + +import monarch + + +def load_module_from_path(base_path, module_specifier): + parts = module_specifier.split(".") + file_path = str(Path(base_path).joinpath(*parts).with_suffix(".pyi")) + loader = importlib.machinery.SourceFileLoader(module_specifier, file_path) + spec = importlib.util.spec_from_file_location( + module_specifier, file_path, loader=loader + ) + + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except FileNotFoundError: + return None + return module + + +def patch_class(rust_entry, python_entry): + for name, implementation in python_entry.__dict__.items(): + if hasattr(rust_entry, name): + # do not patch in the stub methods that + # are already defined by the rust implementation + continue + if not callable(implementation): + continue + setattr(rust_entry, name, implementation) + + +def patch_module(rust, python): + for name in dir(rust): + python_entry = getattr(python, name, None) + if python_entry is None: + continue + rust_entry = getattr(rust, name) + if not isinstance(rust_entry, type): + continue + patch_class(rust_entry, python_entry) + + +def add_extension_methods(bindings: types.ModuleType): + """ + When we bind a rust struct into Python, it is sometimes faster to implement + parts of the desired Python API in Python. It is also easier to understand + what the class does in terms of these methods. + + We also want to avoid having to wrap rust objects in another layer of python objects + because: + * wrappers double the python overhead + * it is easy to confuse which level of wrappers and API takes, especially + along the python<->rust boundary. + + To avoid wrappers we first define the class in pyo3. + We then write the python typing stubs in the pyi file for the functions rust defined. + We also add any python extension methods, including their implementation, + to the stub files. + + This function then loads the stub files and patch the real rust implementation + with those typing methods. + + Using the stub files themselves can seem like an odd choice but has a lot of + desirable properties: + + * we get accurate typechecking in: + - the implementation of extension methods + - the use of rust methods + - the use of extension methods + * go to definition in the IDE will go to the stub file, so it is easy to find + the python impelmentations compared to putting them somewhere else + * With no wrappers, any time rust code returns a class defined this way, + it automatically gains its extension methods. + """ + base_path = str(Path(monarch.__file__).parent.parent) + + def scan(module): + for item in dir(module): + value = getattr(module, item, None) + if isinstance(value, types.ModuleType): + scan(value) + + python = load_module_from_path(base_path, module.__name__) + if python is not None: + patch_module(module, python) + + scan(bindings)