Skip to content

Commit 7b11852

Browse files
eliothedemanmeta-codesync[bot]
authored andcommitted
re-land "[monarch] Pickle directly into rust buffers" after fixing typing issues (#1365)
Summary: Pull Request resolved: #1365 there were some typing issues in the tensor engine code which this fixes. Reviewed By: thomasywang, pablorfb-meta Differential Revision: D83490534 fbshipit-source-id: b14ba29980ccb220d3e471b4113baf1fb4a348e3
1 parent 82ec519 commit 7b11852

File tree

16 files changed

+624
-45
lines changed

16 files changed

+624
-45
lines changed

monarch_extension/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
7474
runtime.handle().clone(),
7575
Some(::hyperactor_mesh::bootstrap::BOOTSTRAP_INDEX_ENV.to_string()),
7676
);
77+
monarch_hyperactor::buffers::register_python_bindings(&get_or_add_new_module(
78+
module,
79+
"monarch_hyperactor.buffers",
80+
)?)?;
7781

7882
monarch_hyperactor::shape::register_python_bindings(&get_or_add_new_module(
7983
module,

monarch_extension/src/mesh_controller.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use hyperactor_mesh::shared_cell::SharedCell;
3838
use hyperactor_mesh::shared_cell::SharedCellRef;
3939
use monarch_hyperactor::actor::PythonMessage;
4040
use monarch_hyperactor::actor::PythonMessageKind;
41+
use monarch_hyperactor::buffers::FrozenBuffer;
4142
use monarch_hyperactor::local_state_broker::LocalStateBrokerActor;
4243
use monarch_hyperactor::mailbox::PyPortId;
4344
use monarch_hyperactor::ndslice::PySlice;
@@ -532,8 +533,11 @@ impl History {
532533
let exe = remote_exception
533534
.call1((exception.backtrace, traceback, rank))
534535
.unwrap();
535-
let data: Vec<u8> = pickle.call1((exe,)).unwrap().extract().unwrap();
536-
PythonMessage::new_from_buf(PythonMessageKind::Exception { rank: Some(rank) }, data)
536+
let data: FrozenBuffer = pickle.call1((exe,)).unwrap().extract().unwrap();
537+
PythonMessage::new_from_buf(
538+
PythonMessageKind::Exception { rank: Some(rank) },
539+
data.inner,
540+
)
537541
}));
538542

539543
let mut invocation = invocation.lock().unwrap();

monarch_hyperactor/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ anyhow = "1.0.98"
2020
async-once-cell = "0.4.2"
2121
async-trait = "0.1.86"
2222
bincode = "1.3.3"
23+
bytes = { version = "1.10", features = ["serde"] }
2324
clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] }
2425
erased-serde = "0.3.27"
2526
fastrand = "2.1.1"
@@ -43,6 +44,7 @@ pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods", "py-clone
4344
pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] }
4445
serde = { version = "1.0.219", features = ["derive", "rc"] }
4546
serde_bytes = "0.11"
47+
serde_multipart = { version = "0.0.0", path = "../serde_multipart" }
4648
tempfile = "3.22"
4749
thiserror = "2.0.12"
4850
tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] }

monarch_hyperactor/src/actor.rs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::sync::Arc;
1313
use std::sync::OnceLock;
1414

1515
use async_trait::async_trait;
16+
use bytes::Bytes;
1617
use hyperactor::Actor;
1718
use hyperactor::ActorHandle;
1819
use hyperactor::ActorId;
@@ -32,6 +33,7 @@ use monarch_types::SerializablePyErr;
3233
use pyo3::IntoPyObjectExt;
3334
use pyo3::exceptions::PyBaseException;
3435
use pyo3::exceptions::PyRuntimeError;
36+
use pyo3::exceptions::PyTypeError;
3537
use pyo3::exceptions::PyValueError;
3638
use pyo3::prelude::*;
3739
use pyo3::types::PyBytes;
@@ -41,12 +43,15 @@ use pyo3::types::PyType;
4143
use serde::Deserialize;
4244
use serde::Serialize;
4345
use serde_bytes::ByteBuf;
46+
use serde_multipart::Part;
4447
use tokio::sync::Mutex;
4548
use tokio::sync::mpsc::UnboundedReceiver;
4649
use tokio::sync::mpsc::UnboundedSender;
4750
use tokio::sync::oneshot;
4851
use tracing::Instrument;
4952

53+
use crate::buffers::Buffer;
54+
use crate::buffers::FrozenBuffer;
5055
use crate::config::SHARED_ASYNCIO_RUNTIME;
5156
use crate::local_state_broker::BrokerId;
5257
use crate::local_state_broker::LocalStateBrokerMessage;
@@ -236,22 +241,24 @@ fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, Py
236241
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
237242
pub struct PythonMessage {
238243
pub kind: PythonMessageKind,
239-
#[serde(with = "serde_bytes")]
240-
pub message: Vec<u8>,
244+
pub message: Part,
241245
}
242246

243247
struct ResolvedCallMethod {
244248
method: MethodSpecifier,
245-
bytes: Vec<u8>,
249+
bytes: FrozenBuffer,
246250
local_state: PyObject,
247251
/// Implements PortProtocol
248252
/// Concretely either a Port, DroppingPort, or LocalPort
249253
response_port: PyObject,
250254
}
251255

252256
impl PythonMessage {
253-
pub fn new_from_buf(kind: PythonMessageKind, message: Vec<u8>) -> Self {
254-
Self { kind, message }
257+
pub fn new_from_buf(kind: PythonMessageKind, message: impl Into<Part>) -> Self {
258+
Self {
259+
kind,
260+
message: message.into(),
261+
}
255262
}
256263

257264
pub fn into_rank(self, rank: usize) -> Self {
@@ -305,7 +312,9 @@ impl PythonMessage {
305312
.unwrap();
306313
Ok(ResolvedCallMethod {
307314
method: name,
308-
bytes: self.message,
315+
bytes: FrozenBuffer {
316+
inner: self.message.into_inner(),
317+
},
309318
local_state,
310319
response_port,
311320
})
@@ -341,7 +350,9 @@ impl PythonMessage {
341350
.unbind();
342351
Ok(ResolvedCallMethod {
343352
method: name,
344-
bytes: self.message,
353+
bytes: FrozenBuffer {
354+
inner: self.message.into_inner(),
355+
},
345356
local_state,
346357
response_port,
347358
})
@@ -359,7 +370,7 @@ impl std::fmt::Debug for PythonMessage {
359370
.field("kind", &self.kind)
360371
.field(
361372
"message",
362-
&hyperactor::data::HexFmt(self.message.as_slice()).to_string(),
373+
&hyperactor::data::HexFmt(&(*self.message)[..]).to_string(),
363374
)
364375
.finish()
365376
}
@@ -387,8 +398,20 @@ impl Bind for PythonMessage {
387398
impl PythonMessage {
388399
#[new]
389400
#[pyo3(signature = (kind, message))]
390-
pub fn new(kind: PythonMessageKind, message: &[u8]) -> Self {
391-
PythonMessage::new_from_buf(kind, message.to_vec())
401+
pub fn new<'py>(kind: PythonMessageKind, message: Bound<'py, PyAny>) -> PyResult<Self> {
402+
if let Ok(buff) = message.extract::<Bound<'py, FrozenBuffer>>() {
403+
let frozen = buff.borrow_mut();
404+
return Ok(PythonMessage::new_from_buf(kind, frozen.inner.clone()));
405+
} else if let Ok(buff) = message.extract::<Bound<'py, PyBytes>>() {
406+
return Ok(PythonMessage::new_from_buf(
407+
kind,
408+
Vec::from(buff.as_bytes()),
409+
));
410+
}
411+
412+
Err(PyTypeError::new_err(
413+
"PythonMessage(buff) takes Buffer or bytes objects only",
414+
))
392415
}
393416

394417
#[getter]
@@ -397,8 +420,10 @@ impl PythonMessage {
397420
}
398421

399422
#[getter]
400-
fn message<'a>(&self, py: Python<'a>) -> Bound<'a, PyBytes> {
401-
PyBytes::new(py, self.message.as_ref())
423+
fn message(&self) -> FrozenBuffer {
424+
FrozenBuffer {
425+
inner: self.message.clone().into_inner(),
426+
}
402427
}
403428
}
404429

@@ -842,7 +867,7 @@ mod tests {
842867
},
843868
response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())),
844869
},
845-
message: vec![1, 2, 3],
870+
message: Part::from(vec![1, 2, 3]),
846871
};
847872
{
848873
let mut erased = ErasedUnbound::try_from_message(message.clone()).unwrap();

0 commit comments

Comments
 (0)