Skip to content

Commit 4c2c4cb

Browse files
zdevitofacebook-github-bot
authored andcommitted
add a way to write python methods for rust-bound classes, round 2 (#913)
Summary: Pull Request resolved: #913 This will be used to generate future python bindings for Context and Instance. The implementation documents how it works. This is the least bad way I could figure out how to do this since pyo3 does not natively let you subclass an existing python class. ghstack-source-id: 304046443 exported-using-ghexport Reviewed By: mariusae Differential Revision: D80487153 fbshipit-source-id: 780376778ad7dcc36a10ae6b575d32ad4672914b
1 parent 1691ae0 commit 4c2c4cb

File tree

4 files changed

+120
-2
lines changed

4 files changed

+120
-2
lines changed

monarch_extension/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ mod tensor_worker;
2828
mod blocking;
2929
mod panic;
3030
mod trace;
31+
32+
use monarch_types::py_global;
3133
use pyo3::prelude::*;
3234

3335
#[pyfunction]

monarch_types/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,22 @@ pub use pyobject::PickledPyObject;
1616
pub use python::SerializablePyErr;
1717
pub use python::TryIntoPyObjectUnsafe;
1818
pub use pytree::PyTree;
19+
20+
/// Macro to generate a Python object lookup function with caching
21+
///
22+
/// # Arguments
23+
/// * `$fn_name` - Name of the Rust function to generate
24+
/// * `$python_path` - Path to the Python object as a string (e.g., "module.submodule.function")
25+
#[macro_export]
26+
macro_rules! py_global {
27+
($fn_name:ident, $python_module:literal, $python_class:literal) => {
28+
fn $fn_name<'py>(py: ::pyo3::Python<'py>) -> ::pyo3::Bound<'py, ::pyo3::PyAny> {
29+
static CACHE: ::pyo3::sync::GILOnceCell<::pyo3::PyObject> =
30+
::pyo3::sync::GILOnceCell::new();
31+
CACHE
32+
.import(py, $python_module, $python_class)
33+
.unwrap()
34+
.clone()
35+
}
36+
};
37+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import importlib
8+
9+
from typing import cast, Type, TypeVar
10+
11+
12+
T = TypeVar("T")
13+
14+
15+
class PatchRustClass:
16+
def __init__(self, rust_class: Type):
17+
self.rust_class = rust_class
18+
19+
def __call__(self, python_class: Type[T]) -> Type[T]:
20+
assert self.rust_class.__module__ == python_class.__module__
21+
assert self.rust_class.__name__ == python_class.__name__
22+
for name, implementation in python_class.__dict__.items():
23+
if hasattr(self.rust_class, name):
24+
# do not patch in the stub methods that
25+
# are already defined by the rust implementation
26+
continue
27+
if not callable(implementation):
28+
continue
29+
setattr(self.rust_class, name, implementation)
30+
return cast(Type[T], self.rust_class)
31+
32+
33+
def rust_struct(name: str) -> PatchRustClass:
34+
"""
35+
When we bind a rust struct into Python, it is sometimes faster to implement
36+
parts of the desired Python API in Python. It is also easier to understand
37+
what the class does in terms of these methods.
38+
39+
We also want to avoid having to wrap rust objects in another layer of python objects
40+
because:
41+
* wrappers double the python overhead
42+
* it is easy to confuse which level of wrappers and API takes, especially
43+
along the python<->rust boundary.
44+
45+
46+
To avoid wrappers we first define the class in pyo3. Lets say we add a class
47+
monarch_hyperactor::actor_mesh::TestClass which we will want to extend with python methods in
48+
the monarch/actor/_src/actor_mesh.py. In rust we will define the class as
49+
50+
#[pyclass(name = "TestClass", module = "monarch._src.actor_mesh")]
51+
struct TestClass {}
52+
#[pymethods]
53+
impl TestClass {
54+
fn hello(&self) {
55+
println!("hello");
56+
}
57+
}
58+
59+
Then rather than writing typing stubs in a pyi file we write the stub code directly in
60+
monarch/actor/_src/actor_mesh.py along with any helper methods:
61+
62+
@rust_struct("monarch_hyperactor::actor_mesh::TestClass")
63+
class TestClass:
64+
def hello(self) -> None:
65+
...
66+
def hello_world(self) -> None:
67+
self.hello()
68+
print("world")
69+
70+
This class annotation then merges the python extension methods with the rust
71+
class implementation. Any rust code that returns the TestClass will have the `hello_world`
72+
extension method attached. Python typechecking always things TestClass is the python code,
73+
so typing works.
74+
75+
It is ok to have the pyclass module not match where it is defined because (1) we patch it into the right place
76+
to make sure pickling works, and (2) the rust_struct annotation points directly to where to find the rust code,
77+
and will be discovered by goto line in the IDE.
78+
"""
79+
80+
*modules, name = name.split("::")
81+
module_name = ".".join(modules)
82+
module = importlib.import_module(f"monarch._rust_bindings.{module_name}")
83+
84+
rust_class = getattr(module, name)
85+
86+
return PatchRustClass(rust_class)

python/tests/test_python_actors.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,16 @@ def test_value_mesh() -> None:
246246

247247
@pytest.mark.timeout(60)
248248
def test_rust_binding_modules_correct() -> None:
249+
"""
250+
This tests that rust bindings will survive pickling correctly.
251+
252+
To correctly define a rust binding, either
253+
254+
(1) Set its module to "monarch._rust_bindings.rust_crate.rust_module",
255+
and make sure it is registered in monarch_extension/lib.rs
256+
(2) Set its module to some existing python file, and use @rust_struct to install
257+
the rust struct in that file and patch in any python extension methods.
258+
"""
249259
import monarch._rust_bindings as bindings
250260

251261
def check(module, path):
@@ -255,8 +265,9 @@ def check(module, path):
255265
if isinstance(value, ModuleType):
256266
check(value, f"{path}.{name}")
257267
elif hasattr(value, "__module__"):
258-
assert value.__name__ == name
259-
assert value.__module__ == path
268+
value_module = importlib.import_module(value.__module__)
269+
resolved_value = getattr(value_module, value.__name__)
270+
assert value is resolved_value
260271

261272
check(bindings, "monarch._rust_bindings")
262273

0 commit comments

Comments
 (0)