Skip to content

Commit b97c82c

Browse files
committed
add a way to write python methods for rust-bound classes
This will be used to generate future python bindings for Context and Instance. The implementation documents how it works. This is the least bad way I could figure out how to do this since pyo3 does not natively let you subclass an existing python class. Differential Revision: [D80374838](https://our.internmc.facebook.com/intern/diff/D80374838/) ghstack-source-id: 303417472 Pull Request resolved: #901
1 parent 44f85b5 commit b97c82c

File tree

21 files changed

+430
-235
lines changed

21 files changed

+430
-235
lines changed

monarch_extension/src/lib.rs

Lines changed: 123 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,21 @@ mod tensor_worker;
2727

2828
mod blocking;
2929
mod panic;
30+
31+
use monarch_types::py_global;
3032
use pyo3::prelude::*;
3133

3234
#[pyfunction]
3335
fn has_tensor_engine() -> bool {
3436
cfg!(feature = "tensor_engine")
3537
}
3638

39+
py_global!(
40+
add_extension_methods,
41+
"monarch._src.actor.python_extension_methods",
42+
"add_extension_methods"
43+
);
44+
3745
fn get_or_add_new_module<'py>(
3846
module: &Bound<'py, PyModule>,
3947
module_name: &str,
@@ -46,22 +54,29 @@ fn get_or_add_new_module<'py>(
4654
if let Some(submodule) = submodule {
4755
current_module = submodule.extract()?;
4856
} else {
49-
let new_module = PyModule::new(current_module.py(), part)?;
50-
current_module.add_submodule(&new_module)?;
57+
let name = format!("monarch._rust_bindings.{}", parts.join("."));
58+
let new_module = PyModule::new(current_module.py(), &name)?;
59+
current_module.add(part, new_module.clone())?;
5160
current_module
5261
.py()
5362
.import("sys")?
5463
.getattr("modules")?
55-
.set_item(
56-
format!("monarch._rust_bindings.{}", parts.join(".")),
57-
new_module.clone(),
58-
)?;
64+
.set_item(name, new_module.clone())?;
5965
current_module = new_module;
6066
}
6167
}
6268
Ok(current_module)
6369
}
6470

71+
fn register<F>(module: &Bound<'_, PyModule>, module_path: &str, register_fn: F) -> PyResult<()>
72+
where
73+
F: FnOnce(&Bound<'_, PyModule>) -> PyResult<()>,
74+
{
75+
let submodule = get_or_add_new_module(module, module_path)?;
76+
register_fn(&submodule)?;
77+
Ok(())
78+
}
79+
6580
#[pymodule]
6681
#[pyo3(name = "_rust_bindings")]
6782
pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -71,153 +86,188 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
7186
runtime.handle().clone(),
7287
Some(::hyperactor_mesh::bootstrap::BOOTSTRAP_INDEX_ENV.to_string()),
7388
);
74-
75-
monarch_hyperactor::shape::register_python_bindings(&get_or_add_new_module(
89+
register(
7690
module,
7791
"monarch_hyperactor.shape",
78-
)?)?;
79-
80-
monarch_hyperactor::selection::register_python_bindings(&get_or_add_new_module(
92+
monarch_hyperactor::shape::register_python_bindings,
93+
)?;
94+
register(
8195
module,
8296
"monarch_hyperactor.selection",
83-
)?)?;
84-
85-
monarch_hyperactor::supervision::register_python_bindings(&get_or_add_new_module(
97+
monarch_hyperactor::selection::register_python_bindings,
98+
)?;
99+
register(
86100
module,
87101
"monarch_hyperactor.supervision",
88-
)?)?;
102+
monarch_hyperactor::supervision::register_python_bindings,
103+
)?;
89104

90105
#[cfg(feature = "tensor_engine")]
91106
{
92-
client::register_python_bindings(&get_or_add_new_module(
107+
register(
93108
module,
94109
"monarch_extension.client",
95-
)?)?;
96-
tensor_worker::register_python_bindings(&get_or_add_new_module(
110+
client::register_python_bindings,
111+
)?;
112+
register(
97113
module,
98114
"monarch_extension.tensor_worker",
99-
)?)?;
100-
controller::register_python_bindings(&get_or_add_new_module(
115+
tensor_worker::register_python_bindings,
116+
)?;
117+
register(
101118
module,
102119
"monarch_extension.controller",
103-
)?)?;
104-
debugger::register_python_bindings(&get_or_add_new_module(
120+
controller::register_python_bindings,
121+
)?;
122+
register(
105123
module,
106124
"monarch_extension.debugger",
107-
)?)?;
108-
monarch_messages::debugger::register_python_bindings(&get_or_add_new_module(
125+
debugger::register_python_bindings,
126+
)?;
127+
register(
109128
module,
110129
"monarch_messages.debugger",
111-
)?)?;
112-
simulator_client::register_python_bindings(&get_or_add_new_module(
130+
monarch_messages::debugger::register_python_bindings,
131+
)?;
132+
register(
113133
module,
114134
"monarch_extension.simulator_client",
115-
)?)?;
116-
::controller::bootstrap::register_python_bindings(&get_or_add_new_module(
135+
simulator_client::register_python_bindings,
136+
)?;
137+
register(
117138
module,
118139
"controller.bootstrap",
119-
)?)?;
120-
::monarch_tensor_worker::bootstrap::register_python_bindings(&get_or_add_new_module(
140+
::controller::bootstrap::register_python_bindings,
141+
)?;
142+
register(
121143
module,
122144
"monarch_tensor_worker.bootstrap",
123-
)?)?;
124-
crate::convert::register_python_bindings(&get_or_add_new_module(
145+
::monarch_tensor_worker::bootstrap::register_python_bindings,
146+
)?;
147+
register(
125148
module,
126149
"monarch_extension.convert",
127-
)?)?;
128-
crate::mesh_controller::register_python_bindings(&get_or_add_new_module(
150+
crate::convert::register_python_bindings,
151+
)?;
152+
register(
129153
module,
130154
"monarch_extension.mesh_controller",
131-
)?)?;
132-
monarch_rdma_extension::register_python_bindings(&get_or_add_new_module(module, "rdma")?)?;
155+
crate::mesh_controller::register_python_bindings,
156+
)?;
157+
register(
158+
module,
159+
"rdma",
160+
monarch_rdma_extension::register_python_bindings,
161+
)?;
133162
}
134-
simulation_tools::register_python_bindings(&get_or_add_new_module(
163+
register(
135164
module,
136165
"monarch_extension.simulation_tools",
137-
)?)?;
138-
monarch_hyperactor::bootstrap::register_python_bindings(&get_or_add_new_module(
166+
simulation_tools::register_python_bindings,
167+
)?;
168+
register(
139169
module,
140170
"monarch_hyperactor.bootstrap",
141-
)?)?;
171+
monarch_hyperactor::bootstrap::register_python_bindings,
172+
)?;
142173

143-
monarch_hyperactor::proc::register_python_bindings(&get_or_add_new_module(
174+
register(
144175
module,
145176
"monarch_hyperactor.proc",
146-
)?)?;
177+
monarch_hyperactor::proc::register_python_bindings,
178+
)?;
147179

148-
monarch_hyperactor::actor::register_python_bindings(&get_or_add_new_module(
180+
register(
149181
module,
150182
"monarch_hyperactor.actor",
151-
)?)?;
183+
monarch_hyperactor::actor::register_python_bindings,
184+
)?;
152185

153-
monarch_hyperactor::pytokio::register_python_bindings(&get_or_add_new_module(
186+
register(
154187
module,
155188
"monarch_hyperactor.pytokio",
156-
)?)?;
157-
158-
monarch_hyperactor::mailbox::register_python_bindings(&get_or_add_new_module(
189+
monarch_hyperactor::pytokio::register_python_bindings,
190+
)?;
191+
register(
159192
module,
160193
"monarch_hyperactor.mailbox",
161-
)?)?;
194+
monarch_hyperactor::mailbox::register_python_bindings,
195+
)?;
162196

163-
monarch_hyperactor::alloc::register_python_bindings(&get_or_add_new_module(
197+
register(
164198
module,
165199
"monarch_hyperactor.alloc",
166-
)?)?;
167-
monarch_hyperactor::channel::register_python_bindings(&get_or_add_new_module(
200+
monarch_hyperactor::alloc::register_python_bindings,
201+
)?;
202+
register(
168203
module,
169204
"monarch_hyperactor.channel",
170-
)?)?;
171-
monarch_hyperactor::actor_mesh::register_python_bindings(&get_or_add_new_module(
205+
monarch_hyperactor::channel::register_python_bindings,
206+
)?;
207+
register(
172208
module,
173209
"monarch_hyperactor.actor_mesh",
174-
)?)?;
175-
monarch_hyperactor::proc_mesh::register_python_bindings(&get_or_add_new_module(
210+
monarch_hyperactor::actor_mesh::register_python_bindings,
211+
)?;
212+
register(
176213
module,
177214
"monarch_hyperactor.proc_mesh",
178-
)?)?;
215+
monarch_hyperactor::proc_mesh::register_python_bindings,
216+
)?;
179217

180-
monarch_hyperactor::runtime::register_python_bindings(&get_or_add_new_module(
218+
register(
181219
module,
182220
"monarch_hyperactor.runtime",
183-
)?)?;
184-
monarch_hyperactor::telemetry::register_python_bindings(&get_or_add_new_module(
221+
monarch_hyperactor::runtime::register_python_bindings,
222+
)?;
223+
register(
185224
module,
186225
"monarch_hyperactor.telemetry",
187-
)?)?;
188-
code_sync::register_python_bindings(&get_or_add_new_module(
226+
monarch_hyperactor::telemetry::register_python_bindings,
227+
)?;
228+
register(
189229
module,
190230
"monarch_extension.code_sync",
191-
)?)?;
231+
code_sync::register_python_bindings,
232+
)?;
192233

193-
crate::panic::register_python_bindings(&get_or_add_new_module(
234+
register(
194235
module,
195236
"monarch_extension.panic",
196-
)?)?;
237+
crate::panic::register_python_bindings,
238+
)?;
197239

198-
crate::blocking::register_python_bindings(&get_or_add_new_module(
240+
register(
199241
module,
200242
"monarch_extension.blocking",
201-
)?)?;
243+
crate::blocking::register_python_bindings,
244+
)?;
202245

203-
crate::logging::register_python_bindings(&get_or_add_new_module(
246+
register(
204247
module,
205248
"monarch_extension.logging",
206-
)?)?;
249+
crate::logging::register_python_bindings,
250+
)?;
207251

208252
#[cfg(fbcode_build)]
209253
{
210-
monarch_hyperactor::meta::alloc::register_python_bindings(&get_or_add_new_module(
254+
register(
211255
module,
212256
"monarch_hyperactor.meta.alloc",
213-
)?)?;
214-
monarch_hyperactor::meta::alloc_mock::register_python_bindings(&get_or_add_new_module(
257+
monarch_hyperactor::meta::alloc::register_python_bindings,
258+
)?;
259+
register(
215260
module,
216261
"monarch_hyperactor.meta.alloc_mock",
217-
)?)?;
262+
monarch_hyperactor::meta::alloc_mock::register_python_bindings,
263+
)?;
218264
}
219265
// Add feature detection function
220266
module.add_function(wrap_pyfunction!(has_tensor_engine, module)?)?;
221267

268+
// this should be called last. otherwise cross references in pyi files will not have been
269+
// added to sys.modules yet.
270+
add_extension_methods(module.py()).call1((module,))?;
271+
222272
Ok(())
223273
}

monarch_types/src/lib.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,27 @@ pub use pyobject::PickledPyObject;
1616
pub use python::SerializablePyErr;
1717
pub use python::TryIntoPyObjectUnsafe;
1818
pub use pytree::PyTree;
19+
20+
/// Macro to generate a Python object lookup function with caching
21+
///
22+
/// # Arguments
23+
/// * `$fn_name` - Name of the Rust function to generate
24+
/// * `$python_path` - Path to the Python object as a string (e.g., "module.submodule.function")
25+
///
26+
/// # Example
27+
/// ```rust
28+
/// py_global!(sys_modules, "sys.modules");
29+
/// ```
30+
#[macro_export]
31+
macro_rules! py_global {
32+
($fn_name:ident, $python_module:literal, $python_class:literal) => {
33+
fn $fn_name<'py>(py: ::pyo3::Python<'py>) -> ::pyo3::Bound<'py, ::pyo3::PyAny> {
34+
static CACHE: ::pyo3::sync::GILOnceCell<::pyo3::PyObject> =
35+
::pyo3::sync::GILOnceCell::new();
36+
CACHE
37+
.import(py, $python_module, $python_class)
38+
.unwrap()
39+
.clone()
40+
}
41+
};
42+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.

0 commit comments

Comments
 (0)