Skip to content

Commit 242c726

Browse files
channel: net: framed: test frame size limits (#931)
Summary: Pull Request resolved: #931 document frame reader fatal error semantics in `next()` and add tests for frame size edges: accepts frames of exactly max length, rejects frames over max (`InvalidData`, reader must be dropped), and round-trips zero-length frames. Reviewed By: mariusae Differential Revision: D80561321 fbshipit-source-id: a8cfb6748a9e0cf412f01a886c1be05b7dbc90a7
1 parent 227ce9a commit 242c726

File tree

1 file changed

+78
-8
lines changed

1 file changed

+78
-8
lines changed

hyperactor/src/channel/net/framed.rs

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,20 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
4747
}
4848
}
4949

50-
/// Read the next frame from the underlying reader. If the frame exceeds
51-
/// the configured maximum length, `next` returns an `io::ErrorKind::InvalidData`
52-
/// error.
50+
/// Read the next frame from the underlying reader. If the frame
51+
/// exceeds the configured maximum length, `next` returns an
52+
/// `io::ErrorKind::InvalidData` error.
5353
///
54-
/// The method is cancellation safe in the sense that, if it is used in a branch
55-
/// of a `tokio::select!` block, frames are never dropped.
54+
/// The method is cancellation safe in the sense that, if it is
55+
/// used in a branch of a `tokio::select!` block, frames are never
56+
/// dropped.
57+
///
58+
/// # Errors
59+
///
60+
/// * Returns `io::ErrorKind::InvalidData` if a frame exceeds
61+
/// `max_frame_length`. **This error is fatal:** once returned,
62+
/// the `FrameReader` must be dropped; the underlying connection
63+
/// is no longer valid.
5664
pub async fn next(&mut self) -> io::Result<Option<Bytes>> {
5765
loop {
5866
match &mut self.state {
@@ -106,9 +114,9 @@ impl<R: AsyncRead + Unpin> FrameReader<R> {
106114
}
107115
}
108116

109-
/// A Writer for message frames. FrameWrite requires the user to drive
117+
/// A Writer for message frames. `FrameWrite` requires the user to drive
110118
/// the underlying state machines through (possibly) successive calls to
111-
/// `send`, retaining cancellation safety. The FrameWrite owns the underlying
119+
/// `send`, retaining cancellation safety. The `FrameWrite` owns the underlying
112120
/// writer until the frame has been written to completion.
113121
pub struct FrameWrite<W> {
114122
writer: W,
@@ -418,5 +426,67 @@ mod tests {
418426
assert!(reader.next().await.unwrap().is_none());
419427
}
420428

421-
// todo: frame size
429+
#[tokio::test]
430+
async fn test_reader_accepts_exact_max_len_frames() {
431+
const MAX: usize = 1024;
432+
const BUFSIZ: usize = 8 + MAX; // BUFSIZ (bytes) = 8 (len) + MAX (body)
433+
let (a, b) = tokio::io::duplex(BUFSIZ);
434+
let (r, _wu) = tokio::io::split(a);
435+
let (_ru, mut w) = tokio::io::split(b);
436+
let mut reader = FrameReader::new(r, MAX);
437+
438+
let bytes_written = Bytes::from(vec![0xAB; MAX]);
439+
w = FrameWrite::write_frame(w, bytes_written.clone())
440+
.await
441+
.unwrap();
442+
443+
let bytes_read = reader.next().await.unwrap().unwrap();
444+
assert_eq!(bytes_read.len(), MAX);
445+
assert_eq!(bytes_read, bytes_written);
446+
447+
w.shutdown().await.unwrap();
448+
assert!(reader.next().await.unwrap().is_none());
449+
}
450+
451+
#[tokio::test]
452+
async fn test_reader_rejects_over_max_len_frames() {
453+
const MAX: usize = 1024;
454+
const BUFSIZ: usize = 8 + MAX; // BUFSIZ (bytes) = 8 (len) + MAX (body)
455+
let (a, b) = tokio::io::duplex(BUFSIZ);
456+
let (r, _wu) = tokio::io::split(a);
457+
let (_ru, mut w) = tokio::io::split(b);
458+
let mut reader = FrameReader::new(r, MAX - 1);
459+
460+
let bytes_written = Bytes::from(vec![0xAB; MAX]);
461+
w = FrameWrite::write_frame(w, bytes_written).await.unwrap();
462+
463+
let err = reader.next().await.unwrap_err();
464+
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
465+
466+
// Do NOT try to use `reader` beyond this point! There has
467+
// been a protocol violation: `InvalidData` means the stream
468+
// is corrupted and the only valid thing you can do with it is
469+
// `drop` it.
470+
drop(reader);
471+
472+
w.shutdown().await.unwrap();
473+
}
474+
475+
#[tokio::test]
476+
async fn test_reader_accepts_zero_len_frames() {
477+
const MAX: usize = 0;
478+
const BUFSIZ: usize = 8 + MAX; // BUFSIZ (bytes) = 8 (len) + MAX (body)
479+
let (a, b) = tokio::io::duplex(BUFSIZ);
480+
let (r, _wu) = tokio::io::split(a);
481+
let (_ru, mut w) = tokio::io::split(b);
482+
let mut reader = FrameReader::new(r, MAX);
483+
484+
w = FrameWrite::write_frame(w, Bytes::new()).await.unwrap();
485+
assert_eq!(reader.next().await.unwrap().unwrap().len(), 0);
486+
w = FrameWrite::write_frame(w, Bytes::new()).await.unwrap();
487+
assert_eq!(reader.next().await.unwrap().unwrap().len(), 0);
488+
489+
w.shutdown().await.unwrap();
490+
assert!(reader.next().await.unwrap().is_none());
491+
}
422492
}

0 commit comments

Comments
 (0)