Skip to content

Commit aaac251

Browse files
channel: net: update unit-tests 1/N (#910)
Summary: Pull Request resolved: #910 replace `tokio_util::codec::LengthDelimitedCodec`/`Framed` in `test_persistent_server_session` and `test_ack_from_server_session` with `hyperactor::channel::net`’s `FrameReader`/`FrameWrite`.introduce `serve2` (returns `FrameReader<ReadHalf<DuplexStream>>` and `WriteHalf<DuplexStream>`) and `write_stream2` helper. behavior is unchanged; tests now exercise the production framer (zero-copy, cancellation-safe) directly. Reviewed By: mariusae Differential Revision: D80472238 fbshipit-source-id: e24a858a6d72a74f28e970cfcae1d0abd3d5ae6e
1 parent fb24ab9 commit aaac251

File tree

2 files changed

+102
-46
lines changed

2 files changed

+102
-46
lines changed

hyperactor/benches/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ async fn channel_ping_pong(
230230
});
231231

232232
let start = Instant::now();
233-
client_handle.await.unwrap().unwrap();
233+
let _ = client_handle.await.unwrap().unwrap();
234234
start.elapsed()
235235
}
236236

hyperactor/src/channel/net.rs

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ use dashmap::mapref::entry::Entry;
4343
use enum_as_inner::EnumAsInner;
4444
use serde::de::Error;
4545
use tokio::io::AsyncRead;
46-
use tokio::io::AsyncReadExt;
4746
use tokio::io::AsyncWrite;
4847
use tokio::io::AsyncWriteExt;
4948
use tokio::io::ReadHalf;
@@ -997,7 +996,7 @@ impl<W: AsyncWrite + Unpin, T> WriteState<W, T> {
997996
async fn send(&mut self) -> io::Result<T> {
998997
match self {
999998
Self::Idle(_) => futures::future::pending().await,
1000-
Self::Writing(fw, value) => {
999+
Self::Writing(fw, _value) => {
10011000
fw.send().await?;
10021001
let Ok((fw, value)) = replace(self, Self::Broken).into_writing() else {
10031002
panic!("illegal state");
@@ -2020,13 +2019,15 @@ mod tests {
20202019

20212020
#[cfg(target_os = "linux")] // uses abstract names
20222021
use anyhow::Result;
2022+
use bytes::Bytes;
20232023
use futures::SinkExt;
20242024
use futures::stream::SplitSink;
20252025
use futures::stream::SplitStream;
20262026
use rand::Rng;
20272027
use rand::SeedableRng;
20282028
use rand::distributions::Alphanumeric;
20292029
use timed_test::async_timed_test;
2030+
use tokio::io::AsyncWrite;
20302031
use tokio::io::DuplexStream;
20312032
use tokio_util::codec::Framed;
20322033

@@ -2548,6 +2549,36 @@ mod tests {
25482549
}
25492550
}
25502551

2552+
async fn serve2<M>(
2553+
manager: &SessionManager,
2554+
) -> (
2555+
JoinHandle<std::result::Result<(), anyhow::Error>>,
2556+
FrameReader<ReadHalf<DuplexStream>>,
2557+
WriteHalf<DuplexStream>,
2558+
mpsc::Receiver<M>,
2559+
CancellationToken,
2560+
)
2561+
where
2562+
M: RemoteMessage,
2563+
{
2564+
let cancel_token = CancellationToken::new();
2565+
// When testing ServerConn, we do not need a Link object, but
2566+
// only a duplex stream. Therefore, we create them directly so
2567+
// the test will not have dependence on Link.
2568+
let (sender, receiver) = tokio::io::duplex(5000);
2569+
let source = ChannelAddr::Local(u64::MAX);
2570+
let dest = ChannelAddr::Local(u64::MAX);
2571+
let conn = ServerConn::new(receiver, source, dest);
2572+
let manager1 = manager.clone();
2573+
let cancel_token_1 = cancel_token.child_token();
2574+
let (tx, rx) = mpsc::channel(1);
2575+
let join_handle =
2576+
tokio::spawn(async move { manager1.serve(conn, tx, cancel_token_1).await });
2577+
let (r, writer) = tokio::io::split(sender);
2578+
let reader = FrameReader::new(r, config::global::get(config::CODEC_MAX_FRAME_LENGTH));
2579+
(join_handle, reader, writer, rx, cancel_token)
2580+
}
2581+
25512582
async fn serve<M>(
25522583
manager: &SessionManager,
25532584
) -> (
@@ -2576,6 +2607,33 @@ mod tests {
25762607
(join_handle, framed, rx, cancel_token)
25772608
}
25782609

2610+
async fn write_stream2<M, W>(
2611+
mut writer: W,
2612+
session_id: u64,
2613+
messages: &[(u64, M)],
2614+
init: bool,
2615+
) -> W
2616+
where
2617+
M: RemoteMessage + PartialEq + Clone,
2618+
W: AsyncWrite + Unpin,
2619+
{
2620+
if init {
2621+
let frame = bincode::serialize(&Frame::<u64>::Init(session_id)).unwrap();
2622+
let mut fw = FrameWrite::new(writer, Bytes::from(frame));
2623+
fw.send().await.unwrap();
2624+
writer = fw.complete();
2625+
}
2626+
2627+
for (seq, message) in messages {
2628+
let frame = bincode::serialize(&Frame::<M>::Message(*seq, message.clone())).unwrap();
2629+
let mut fw = FrameWrite::new(writer, Bytes::from(frame));
2630+
fw.send().await.unwrap();
2631+
writer = fw.complete();
2632+
}
2633+
2634+
writer
2635+
}
2636+
25792637
async fn write_stream<M: RemoteMessage + std::cmp::PartialEq + Clone>(
25802638
framed: &mut Framed<DuplexStream, LengthDelimitedCodec>,
25812639
session_id: u64,
@@ -2610,21 +2668,12 @@ mod tests {
26102668
// Use temporary config for this test
26112669
let config = config::global::lock();
26122670
let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2613-
async fn verify_ack(
2614-
framed: &mut Framed<DuplexStream, LengthDelimitedCodec>,
2615-
expected_last: u64,
2616-
) {
2671+
2672+
async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
26172673
let mut last_acked: i128 = -1;
26182674
loop {
2619-
let acked = deserialize_ack(
2620-
tokio_stream::StreamExt::next(framed)
2621-
.await
2622-
.unwrap()
2623-
.unwrap()
2624-
.into(),
2625-
)
2626-
.unwrap();
2627-
2675+
let bytes = reader.next().await.unwrap().unwrap();
2676+
let acked = deserialize_ack(bytes).unwrap();
26282677
assert!(
26292678
acked as i128 > last_acked,
26302679
"acks should be delivered in ascending order"
@@ -2641,11 +2690,17 @@ mod tests {
26412690
let session_id = 123;
26422691

26432692
{
2644-
let (handle, mut framed, mut rx, _cancel_token) = serve(&manager).await;
2645-
write_stream(
2646-
&mut framed,
2693+
let (handle, mut reader, mut writer, mut rx, _cancel_token) =
2694+
serve2::<u64>(&manager).await;
2695+
writer = write_stream2(
2696+
writer,
26472697
session_id,
2648-
&[(0, 100), (1, 101), (2, 102), (3, 103)],
2698+
&[
2699+
(0u64, 100u64),
2700+
(1u64, 101u64),
2701+
(2u64, 102u64),
2702+
(3u64, 103u64),
2703+
],
26492704
/*init*/ true,
26502705
)
26512706
.await;
@@ -2660,10 +2715,11 @@ mod tests {
26602715
// server side might or might not ack seq<3 depending on the order
26612716
// of execution introduced by tokio::select. But it definitely would
26622717
// ack 3.
2663-
verify_ack(&mut framed, 3).await;
2718+
verify_ack(&mut reader, 3).await;
26642719

2665-
// Drop the sender side and cause the connection to close.
2666-
drop(framed);
2720+
// Drop the reader and writer to cause the connection to close.
2721+
drop(reader);
2722+
drop(writer);
26672723
handle.await.unwrap().unwrap();
26682724
// mspc is closed too and there should be no unread message left.
26692725
assert_eq!(rx.recv().await, Some(103));
@@ -2672,17 +2728,23 @@ mod tests {
26722728

26732729
// Now, create a new connection with the same session.
26742730
{
2675-
let (handle, mut framed, mut rx, cancel_token) = serve(&manager).await;
2731+
let (handle, mut reader, mut writer, mut rx, cancel_token) =
2732+
serve2::<u64>(&manager).await;
26762733
let handle = tokio::spawn(async move {
26772734
let result = handle.await.unwrap();
26782735
eprintln!("handle joined with: {:?}", result);
26792736
result
26802737
});
26812738

2682-
write_stream(
2683-
&mut framed,
2739+
writer = write_stream2(
2740+
writer,
26842741
session_id,
2685-
&[(2, 102), (3, 103), (4, 104), (5, 105)],
2742+
&[
2743+
(2u64, 102u64),
2744+
(3u64, 103u64),
2745+
(4u64, 104u64),
2746+
(5u64, 105u64),
2747+
],
26862748
/*init*/ true,
26872749
)
26882750
.await;
@@ -2692,7 +2754,7 @@ mod tests {
26922754
assert_eq!(rx.recv().await, Some(104));
26932755
assert_eq!(rx.recv().await, Some(105));
26942756

2695-
verify_ack(&mut framed, 5).await;
2757+
verify_ack(&mut reader, 5).await;
26962758

26972759
// Wait long enough to ensure server processed everything.
26982760
RealClock.sleep(Duration::from_secs(5)).await;
@@ -2702,7 +2764,7 @@ mod tests {
27022764
// mspc is closed too and there should be no unread message left.
27032765
assert!(rx.recv().await.is_none());
27042766
// No more acks from server.
2705-
assert!(tokio_stream::StreamExt::next(&mut framed).await.is_none());
2767+
assert!(reader.next().await.unwrap().is_none());
27062768
};
27072769
}
27082770

@@ -2711,26 +2773,20 @@ mod tests {
27112773
let config = config::global::lock();
27122774
let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
27132775
let manager = SessionManager::new();
2714-
let session_id = 123;
2776+
let session_id = 123u64;
27152777

2716-
let (handle, mut framed, mut rx, cancel_token) = serve(&manager).await;
2717-
for i in 0..100 {
2718-
write_stream(
2719-
&mut framed,
2778+
let (handle, mut reader, mut writer, mut rx, cancel_token) = serve2::<u64>(&manager).await;
2779+
for i in 0u64..100u64 {
2780+
writer = write_stream2(
2781+
writer,
27202782
session_id,
2721-
&[(i, 100 + i)],
2722-
/*init*/ i == 0,
2783+
&[(i, 100u64 + i)],
2784+
/*init*/ i == 0u64,
27232785
)
27242786
.await;
2725-
assert_eq!(rx.recv().await, Some(100 + i));
2726-
let acked = deserialize_ack(
2727-
tokio_stream::StreamExt::next(&mut framed)
2728-
.await
2729-
.unwrap()
2730-
.unwrap()
2731-
.into(),
2732-
)
2733-
.unwrap();
2787+
assert_eq!(rx.recv().await, Some(100u64 + i));
2788+
let bytes = reader.next().await.unwrap().unwrap();
2789+
let acked = deserialize_ack(bytes).unwrap();
27342790
assert_eq!(acked, i);
27352791
}
27362792

@@ -2742,7 +2798,7 @@ mod tests {
27422798
// mspc is closed too and there should be no unread message left.
27432799
assert!(rx.recv().await.is_none());
27442800
// No more acks from server.
2745-
assert!(tokio_stream::StreamExt::next(&mut framed).await.is_none());
2801+
assert!(reader.next().await.unwrap().is_none());
27462802
}
27472803

27482804
#[tracing_test::traced_test]

0 commit comments

Comments
 (0)