Skip to content

Commit 29e9bc7

Browse files
pzhan9meta-codesync[bot]
authored andcommitted
Add DirectSend variant to SeqInfo
Differential Revision: D83839619
1 parent f7c80fa commit 29e9bc7

File tree

7 files changed

+127
-96
lines changed

7 files changed

+127
-96
lines changed

hyperactor/src/actor.rs

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,31 +1143,45 @@ mod tests {
11431143
.spawn::<GetSeqActor>("get_seq", tx.bind())
11441144
.await
11451145
.unwrap();
1146+
1147+
// Verify that unbound handle can send message.
1148+
actor_handle.send(&client, "unbound".to_string()).unwrap();
1149+
assert_eq!(
1150+
rx.recv().await.unwrap(),
1151+
("unbound".to_string(), SeqInfo::Unordered)
1152+
);
1153+
11461154
let actor_ref: ActorRef<GetSeqActor> = actor_handle.bind();
11471155

11481156
let session_id = client.sequencer().session_id();
11491157
let mut expected_seq = 0;
11501158
// Interleave messages sent through the handle and the reference.
1151-
for _ in 0..10 {
1152-
actor_handle.send(&client, "".to_string()).unwrap();
1159+
for m in 0..10 {
1160+
actor_handle.send(&client, format!("{m}")).unwrap();
11531161
expected_seq += 1;
11541162
assert_eq!(
1155-
rx.recv().await.unwrap().1,
1156-
SeqInfo {
1157-
session_id,
1158-
seq: expected_seq,
1159-
}
1163+
rx.recv().await.unwrap(),
1164+
(
1165+
format!("{m}"),
1166+
SeqInfo::Session {
1167+
session_id,
1168+
seq: expected_seq,
1169+
}
1170+
)
11601171
);
11611172

1162-
for _ in 0..2 {
1163-
actor_ref.port().send(&client, "".to_string()).unwrap();
1173+
for n in 0..2 {
1174+
actor_ref.port().send(&client, format!("{m}-{n}")).unwrap();
11641175
expected_seq += 1;
11651176
assert_eq!(
1166-
rx.recv().await.unwrap().1,
1167-
SeqInfo {
1168-
session_id,
1169-
seq: expected_seq,
1170-
}
1177+
rx.recv().await.unwrap(),
1178+
(
1179+
format!("{m}-{n}"),
1180+
SeqInfo::Session {
1181+
session_id,
1182+
seq: expected_seq,
1183+
}
1184+
)
11711185
);
11721186
}
11731187
}
@@ -1200,7 +1214,10 @@ mod tests {
12001214
let session_id = client.sequencer().session_id();
12011215
assert_eq!(
12021216
rx.recv().await.unwrap(),
1203-
("finally".to_string(), SeqInfo { session_id, seq: 1 })
1217+
(
1218+
"finally".to_string(),
1219+
SeqInfo::Session { session_id, seq: 1 }
1220+
)
12041221
);
12051222
}
12061223

@@ -1280,7 +1297,7 @@ mod tests {
12801297
for expect in expected {
12811298
let expected = (
12821299
expect.0,
1283-
SeqInfo {
1300+
SeqInfo::Session {
12841301
session_id,
12851302
seq: expect.1,
12861303
},

hyperactor/src/mailbox.rs

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@
6666
//! implementation to avoid a serialization roundtrip when passing
6767
//! messages locally.
6868
69-
#![allow(dead_code)] // Allow until this is used outside of tests.
70-
7169
use std::any::Any;
7270
use std::collections::BTreeMap;
7371
use std::collections::BTreeSet;
@@ -946,43 +944,6 @@ impl Future for MailboxServerHandle {
946944
}
947945
}
948946

949-
// A `MailboxServer` (such as a router) can receive a message
950-
// that couldn't reach its destination. We can use the fact that
951-
// servers are `MailboxSender`s to attempt to forward them back to
952-
// their senders.
953-
fn server_return_handle<T: MailboxServer>(server: T) -> PortHandle<Undeliverable<MessageEnvelope>> {
954-
let (return_handle, mut rx) = undeliverable::new_undeliverable_port();
955-
956-
tokio::task::spawn(async move {
957-
while let Ok(Undeliverable(mut envelope)) = rx.recv().await {
958-
if let Ok(Undeliverable(e)) = envelope.deserialized::<Undeliverable<MessageEnvelope>>()
959-
{
960-
// A non-returnable undeliverable.
961-
UndeliverableMailboxSender.post(e, monitored_return_handle());
962-
continue;
963-
}
964-
envelope.set_error(DeliveryError::BrokenLink(
965-
"message was undeliverable".to_owned(),
966-
));
967-
server.post(
968-
MessageEnvelope::new(
969-
envelope.sender().clone(),
970-
PortRef::<Undeliverable<MessageEnvelope>>::attest_message_port(
971-
envelope.sender(),
972-
)
973-
.port_id()
974-
.clone(),
975-
Serialized::serialize(&Undeliverable(envelope)).unwrap(),
976-
Attrs::new(),
977-
),
978-
monitored_return_handle(),
979-
);
980-
}
981-
});
982-
983-
return_handle
984-
}
985-
986947
/// Serve a port on the provided [`channel::Rx`]. This dispatches all
987948
/// channel messages directly to the port.
988949
pub trait MailboxServer: MailboxSender + Clone + Sized + 'static {
@@ -1011,6 +972,9 @@ pub trait MailboxServer: MailboxSender + Clone + Sized + 'static {
1011972
envelope.set_error(DeliveryError::BrokenLink(
1012973
"message was undeliverable".to_owned(),
1013974
));
975+
let mut headers = Attrs::new();
976+
// Ordering is not required when returning Undeliverable.
977+
headers.set(SEQ_INFO, SeqInfo::Unordered);
1014978
server.post(
1015979
MessageEnvelope::new(
1016980
envelope.sender().clone(),
@@ -1020,7 +984,7 @@ pub trait MailboxServer: MailboxSender + Clone + Sized + 'static {
1020984
.port_id()
1021985
.clone(),
1022986
Serialized::serialize(&Undeliverable(envelope)).unwrap(),
1023-
Attrs::new(),
987+
headers,
1024988
),
1025989
monitored_return_handle(),
1026990
);
@@ -1589,19 +1553,30 @@ impl<M: Message> PortHandle<M> {
15891553
let mut headers = Attrs::new();
15901554

15911555
crate::mailbox::headers::set_send_timestamp(&mut headers);
1592-
// Message sent from handle is delivered immediately. It could race with
1593-
// messages from refs. So we need to assign seq if the handle is bound.
1594-
if let Some(bound_port) = self.bound.get()
1595-
&& bound_port.is_actor_port()
1596-
{
1597-
let sequencer = cx.instance().sequencer();
1598-
let seq = sequencer.assign_seq(self.mailbox.actor_id());
1599-
let seq_info = SeqInfo {
1600-
session_id: sequencer.session_id(),
1601-
seq,
1602-
};
1603-
headers.set(SEQ_INFO, seq_info);
1556+
1557+
match self.bound.get() {
1558+
Some(bound_port) => {
1559+
// Message sent from handle is delivered immediately. It could
1560+
// race with messages from refs. So we need to assign seq to
1561+
// preserve the ordering.
1562+
if bound_port.is_actor_port() {
1563+
let sequencer = cx.instance().sequencer();
1564+
let seq = sequencer.assign_seq(self.mailbox.actor_id());
1565+
let seq_info = SeqInfo::Session {
1566+
session_id: sequencer.session_id(),
1567+
seq,
1568+
};
1569+
headers.set(SEQ_INFO, seq_info);
1570+
}
1571+
}
1572+
None => {
1573+
// we do not have info to know whether this handle is used for
1574+
// enqueue port or not. Since enqueue port requires the SEQ_INFO
1575+
// header, we set it in for all messages sent from unbound handles.
1576+
headers.set(SEQ_INFO, SeqInfo::Unordered);
1577+
}
16041578
}
1579+
16051580
// Encountering error means the port is closed. So we do not need to
16061581
// rollback the seq, because no message can be delivered to it, and
16071582
// subsequently do not need to worry about out-of-sequence for messages
@@ -1619,6 +1594,7 @@ impl<M: Message> PortHandle<M> {
16191594
pub fn anon_send(&self, message: M) -> Result<(), MailboxSenderError> {
16201595
let mut headers = Attrs::new();
16211596
crate::mailbox::headers::set_send_timestamp(&mut headers);
1597+
headers.set(SEQ_INFO, SeqInfo::Unordered);
16221598
self.sender.send(headers, message).map_err(|err| {
16231599
MailboxSenderError::new_unbound::<M>(
16241600
self.mailbox.actor_id().clone(),

hyperactor/src/ordering.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ impl<T> OrderedSender<T> {
164164
}
165165

166166
pub(crate) fn direct_send(&self, msg: T) -> Result<(), SendError<T>> {
167-
assert!(!self.enable_buffering);
168167
self.tx.send(msg)
169168
}
170169
}

hyperactor/src/proc.rs

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
use std::any::Any;
1515
use std::any::TypeId;
16+
use std::any::type_name;
1617
use std::collections::HashMap;
1718
use std::fmt;
1819
use std::future::Future;
@@ -1847,30 +1848,42 @@ pub struct Ports<A: Actor> {
18471848

18481849
/// A message's sequencer number infomation.
18491850
#[derive(Debug, Serialize, Deserialize, Clone, Named, AttrValue, PartialEq)]
1850-
pub struct SeqInfo {
1851-
/// Message's session ID
1852-
pub session_id: Uuid,
1853-
/// Message's sequence number in the given session.
1854-
pub seq: u64,
1851+
pub enum SeqInfo {
1852+
/// Messages with the same session ID should be delivered in order.
1853+
Session {
1854+
/// Message's session ID
1855+
session_id: Uuid,
1856+
/// Message's sequence number in the given session.
1857+
seq: u64,
1858+
},
1859+
/// This message does not require ordering and thus have no sequence number.
1860+
Unordered,
18551861
}
18561862

18571863
impl fmt::Display for SeqInfo {
18581864
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1859-
write!(f, "{}:{}", self.session_id, self.seq)
1865+
match self {
1866+
Self::Unordered => write!(f, "unordered"),
1867+
Self::Session { session_id, seq } => write!(f, "{}:{}", session_id, seq),
1868+
}
18601869
}
18611870
}
18621871

18631872
impl std::str::FromStr for SeqInfo {
18641873
type Err = anyhow::Error;
18651874

18661875
fn from_str(s: &str) -> Result<Self, Self::Err> {
1876+
if s == "unordered" {
1877+
return Ok(SeqInfo::Unordered);
1878+
}
1879+
18671880
let parts: Vec<_> = s.split(':').collect();
18681881
if parts.len() != 2 {
18691882
return Err(anyhow::anyhow!("invalid SeqInfo: {}", s));
18701883
}
18711884
let session_id: Uuid = parts[0].parse()?;
18721885
let seq: u64 = parts[1].parse()?;
1873-
Ok(SeqInfo { session_id, seq })
1886+
Ok(SeqInfo::Session { session_id, seq })
18741887
}
18751888
}
18761889

@@ -1926,18 +1939,37 @@ impl<A: Actor> Ports<A> {
19261939
hyperactor_telemetry::kv_pairs!("actor_id" => actor_id.clone()),
19271940
);
19281941
if workq.enable_buffering {
1929-
let SeqInfo { session_id, seq } =
1930-
seq_info.expect("SEQ_INFO must be set when buffering is enabled");
1931-
1932-
// TODO: return the message contained in the error instead of dropping them when converting
1933-
// to anyhow::Error. In that way, the message can be picked up by mailbox and returned to sender.
1934-
workq.send(session_id, seq, work).map_err(|e| match e {
1935-
OrderedSenderError::InvalidZeroSeq(_) => {
1936-
anyhow::anyhow!("seq must be greater than 0")
1942+
match seq_info {
1943+
Some(SeqInfo::Session { session_id, seq }) => {
1944+
// TODO: return the message contained in the error instead of dropping them when converting
1945+
// to anyhow::Error. In that way, the message can be picked up by mailbox and returned to sender.
1946+
workq.send(session_id, seq, work).map_err(|e| match e {
1947+
OrderedSenderError::InvalidZeroSeq(_) => {
1948+
let error_msg = format!(
1949+
"in enqueue func for {}, got seq 0 for message type {}",
1950+
actor_id,
1951+
std::any::type_name::<M>(),
1952+
);
1953+
tracing::error!(error_msg);
1954+
anyhow::anyhow!(error_msg)
1955+
}
1956+
OrderedSenderError::SendError(e) => anyhow::Error::from(e),
1957+
OrderedSenderError::FlushError(e) => e,
1958+
})
19371959
}
1938-
OrderedSenderError::SendError(e) => anyhow::Error::from(e),
1939-
OrderedSenderError::FlushError(e) => e,
1940-
})
1960+
Some(SeqInfo::Unordered) => {
1961+
workq.direct_send(work).map_err(anyhow::Error::from)
1962+
}
1963+
None => {
1964+
let error_msg = format!(
1965+
"in enqueue func for {}, buffering is enabled, but SEQ_INFO is not set for message type {}",
1966+
actor_id,
1967+
std::any::type_name::<M>(),
1968+
);
1969+
tracing::error!(error_msg);
1970+
anyhow::bail!(error_msg);
1971+
}
1972+
}
19411973
} else {
19421974
workq.direct_send(work).map_err(anyhow::Error::from)
19431975
}

hyperactor/src/reference.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ impl PortId {
918918
// without worrying about rollback.
919919
let sequencer = cx.instance().sequencer();
920920
let seq = sequencer.assign_seq(self.actor_id());
921-
let seq_info = SeqInfo {
921+
let seq_info = SeqInfo::Session {
922922
session_id: sequencer.session_id(),
923923
seq,
924924
};
@@ -1563,10 +1563,13 @@ mod tests {
15631563
let port_ref = PortRef::attest(port_id.clone());
15641564

15651565
port_handle.send(&client, ()).unwrap();
1566-
let SeqInfo {
1566+
let SeqInfo::Session {
15671567
session_id,
15681568
mut seq,
1569-
} = rx.try_recv().unwrap().unwrap();
1569+
} = rx.try_recv().unwrap().unwrap()
1570+
else {
1571+
panic!("expected session info");
1572+
};
15701573
assert_eq!(session_id, client.sequencer().session_id());
15711574
assert_eq!(seq, 1);
15721575

@@ -1576,10 +1579,13 @@ mod tests {
15761579
seq: &mut u64,
15771580
) {
15781581
*seq += 1;
1579-
let SeqInfo {
1582+
let SeqInfo::Session {
15801583
session_id: rcved_session_id,
15811584
seq: rcved_seq,
1582-
} = rx.try_recv().unwrap().unwrap();
1585+
} = rx.try_recv().unwrap().unwrap()
1586+
else {
1587+
panic!("expected session info");
1588+
};
15831589
assert_eq!(rcved_session_id, session_id);
15841590
assert_eq!(rcved_seq, *seq);
15851591
}
@@ -1617,8 +1623,9 @@ mod tests {
16171623
Ok(())
16181624
});
16191625
port_handle.send(&client, ()).unwrap();
1620-
// No seq will be assigned for unbound port handle.
1621-
assert!(rx.try_recv().unwrap().is_none());
1626+
// Unordered be set for unbound port handle since handler's ordered
1627+
// channel is expecting the SEQ_INFO header to be set.
1628+
assert_eq!(rx.try_recv().unwrap().unwrap(), SeqInfo::Unordered);
16221629

16231630
// Bind to the allocated port.
16241631
port_handle.bind();

hyperactor_mesh/src/comm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ impl Handler<ForwardMessageV1> for CommActor {
521521
&mut headers,
522522
cast_point,
523523
message.cast_headers.sender.clone(),
524-
Some(SeqInfo {
524+
Some(SeqInfo::Session {
525525
session_id: message.cast_headers.session_id,
526526
seq,
527527
}),

hyperactor_mesh/src/v1/testactor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ pub async fn assert_casting_correctness(
281281
.iter()
282282
.zip(
283283
seqs.into_iter()
284-
.map(|seq| Some(SeqInfo { session_id, seq })),
284+
.map(|seq| Some(SeqInfo::Session { session_id, seq })),
285285
)
286286
.collect(),
287287
};

0 commit comments

Comments
 (0)