@@ -13,6 +13,7 @@ use std::sync::Arc;
13
13
use std:: sync:: OnceLock ;
14
14
15
15
use async_trait:: async_trait;
16
+ use bytes:: Bytes ;
16
17
use hyperactor:: Actor ;
17
18
use hyperactor:: ActorHandle ;
18
19
use hyperactor:: ActorId ;
@@ -32,6 +33,7 @@ use monarch_types::SerializablePyErr;
32
33
use pyo3:: IntoPyObjectExt ;
33
34
use pyo3:: exceptions:: PyBaseException ;
34
35
use pyo3:: exceptions:: PyRuntimeError ;
36
+ use pyo3:: exceptions:: PyTypeError ;
35
37
use pyo3:: exceptions:: PyValueError ;
36
38
use pyo3:: prelude:: * ;
37
39
use pyo3:: types:: PyBytes ;
@@ -41,12 +43,15 @@ use pyo3::types::PyType;
41
43
use serde:: Deserialize ;
42
44
use serde:: Serialize ;
43
45
use serde_bytes:: ByteBuf ;
46
+ use serde_multipart:: Part ;
44
47
use tokio:: sync:: Mutex ;
45
48
use tokio:: sync:: mpsc:: UnboundedReceiver ;
46
49
use tokio:: sync:: mpsc:: UnboundedSender ;
47
50
use tokio:: sync:: oneshot;
48
51
use tracing:: Instrument ;
49
52
53
+ use crate :: buffers:: Buffer ;
54
+ use crate :: buffers:: FrozenBuffer ;
50
55
use crate :: config:: SHARED_ASYNCIO_RUNTIME ;
51
56
use crate :: local_state_broker:: BrokerId ;
52
57
use crate :: local_state_broker:: LocalStateBrokerMessage ;
@@ -236,22 +241,24 @@ fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, Py
236
241
#[ derive( Clone , Serialize , Deserialize , Named , PartialEq , Default ) ]
237
242
pub struct PythonMessage {
238
243
pub kind : PythonMessageKind ,
239
- #[ serde( with = "serde_bytes" ) ]
240
- pub message : Vec < u8 > ,
244
+ pub message : Part ,
241
245
}
242
246
243
247
struct ResolvedCallMethod {
244
248
method : MethodSpecifier ,
245
- bytes : Vec < u8 > ,
249
+ bytes : FrozenBuffer ,
246
250
local_state : PyObject ,
247
251
/// Implements PortProtocol
248
252
/// Concretely either a Port, DroppingPort, or LocalPort
249
253
response_port : PyObject ,
250
254
}
251
255
252
256
impl PythonMessage {
253
- pub fn new_from_buf ( kind : PythonMessageKind , message : Vec < u8 > ) -> Self {
254
- Self { kind, message }
257
+ pub fn new_from_buf ( kind : PythonMessageKind , message : impl Into < Part > ) -> Self {
258
+ Self {
259
+ kind,
260
+ message : message. into ( ) ,
261
+ }
255
262
}
256
263
257
264
pub fn into_rank ( self , rank : usize ) -> Self {
@@ -305,7 +312,9 @@ impl PythonMessage {
305
312
. unwrap ( ) ;
306
313
Ok ( ResolvedCallMethod {
307
314
method : name,
308
- bytes : self . message ,
315
+ bytes : FrozenBuffer {
316
+ inner : self . message . into_inner ( ) ,
317
+ } ,
309
318
local_state,
310
319
response_port,
311
320
} )
@@ -341,7 +350,9 @@ impl PythonMessage {
341
350
. unbind ( ) ;
342
351
Ok ( ResolvedCallMethod {
343
352
method : name,
344
- bytes : self . message ,
353
+ bytes : FrozenBuffer {
354
+ inner : self . message . into_inner ( ) ,
355
+ } ,
345
356
local_state,
346
357
response_port,
347
358
} )
@@ -359,7 +370,7 @@ impl std::fmt::Debug for PythonMessage {
359
370
. field ( "kind" , & self . kind )
360
371
. field (
361
372
"message" ,
362
- & hyperactor:: data:: HexFmt ( self . message . as_slice ( ) ) . to_string ( ) ,
373
+ & hyperactor:: data:: HexFmt ( & ( * self . message ) [ .. ] ) . to_string ( ) ,
363
374
)
364
375
. finish ( )
365
376
}
@@ -387,8 +398,20 @@ impl Bind for PythonMessage {
387
398
impl PythonMessage {
388
399
#[ new]
389
400
#[ pyo3( signature = ( kind, message) ) ]
390
- pub fn new ( kind : PythonMessageKind , message : & [ u8 ] ) -> Self {
391
- PythonMessage :: new_from_buf ( kind, message. to_vec ( ) )
401
+ pub fn new < ' py > ( kind : PythonMessageKind , message : Bound < ' py , PyAny > ) -> PyResult < Self > {
402
+ if let Ok ( buff) = message. extract :: < Bound < ' py , FrozenBuffer > > ( ) {
403
+ let frozen = buff. borrow_mut ( ) ;
404
+ return Ok ( PythonMessage :: new_from_buf ( kind, frozen. inner . clone ( ) ) ) ;
405
+ } else if let Ok ( buff) = message. extract :: < Bound < ' py , PyBytes > > ( ) {
406
+ return Ok ( PythonMessage :: new_from_buf (
407
+ kind,
408
+ Vec :: from ( buff. as_bytes ( ) ) ,
409
+ ) ) ;
410
+ }
411
+
412
+ Err ( PyTypeError :: new_err (
413
+ "PythonMessage(buff) takes Buffer or bytes objects only" ,
414
+ ) )
392
415
}
393
416
394
417
#[ getter]
@@ -397,8 +420,10 @@ impl PythonMessage {
397
420
}
398
421
399
422
#[ getter]
400
- fn message < ' a > ( & self , py : Python < ' a > ) -> Bound < ' a , PyBytes > {
401
- PyBytes :: new ( py, self . message . as_ref ( ) )
423
+ fn message ( & self ) -> FrozenBuffer {
424
+ FrozenBuffer {
425
+ inner : self . message . clone ( ) . into_inner ( ) ,
426
+ }
402
427
}
403
428
}
404
429
@@ -842,7 +867,7 @@ mod tests {
842
867
} ,
843
868
response_port : Some ( EitherPortRef :: Unbounded ( port_ref. clone ( ) . into ( ) ) ) ,
844
869
} ,
845
- message : vec ! [ 1 , 2 , 3 ] ,
870
+ message : Part :: from ( vec ! [ 1 , 2 , 3 ] ) ,
846
871
} ;
847
872
{
848
873
let mut erased = ErasedUnbound :: try_from_message ( message. clone ( ) ) . unwrap ( ) ;
0 commit comments