Skip to content

Commit 227ce9a

Browse files
channel: net: framed: test cancellation (#929)
Summary: Pull Request resolved: #929 add `test_writer_cancellation_resume` in net/framed.rs to verify `FrameWrite::send` preserves progress across cancellation: throttle to write only the 8-byte length, cancel, assert no full frame is readable, then resume and complete the frame; clean shutdown yields boundary `EOF`. introduce a small `Throttled<W: AsyncWrite>` helper for this test to limit bytes per poll (returns `Poll::Pending` when budget is `0` and wakes the task). Reviewed By: mariusae Differential Revision: D80552262 fbshipit-source-id: 411579eb737b02e5da3df2e3a3477a41f69c846d
1 parent 84a0c2f commit 227ce9a

File tree

1 file changed

+111
-1
lines changed

1 file changed

+111
-1
lines changed

hyperactor/src/channel/net/framed.rs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,14 @@ impl<W: AsyncWrite + Unpin> FrameWrite<W> {
209209

210210
#[cfg(test)]
211211
mod tests {
212+
use std::pin::Pin;
213+
use std::task::Context;
214+
use std::task::Poll;
215+
216+
use bytes::Bytes;
212217
use rand::Rng;
213218
use rand::thread_rng;
219+
use tokio::io::AsyncWrite;
214220
use tokio::io::AsyncWriteExt;
215221

216222
use super::*;
@@ -308,5 +314,109 @@ mod tests {
308314
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
309315
}
310316

311-
// todo: test cancellation, frame size
317+
/// A wrapper around an `AsyncWrite` that throttles how many bytes
318+
/// may be written per poll.
319+
///
320+
/// We are going to use this to simulate partial writes to test
321+
/// cancellation safety: when the budget is 0, `poll_write`
322+
/// returns `Poll::Pending` and calls the waker so the task is
323+
/// scheduled to be polled again later.
324+
struct Throttled<W> {
325+
inner: W,
326+
// Number of bytes allowed to be written in the next poll. If
327+
// 0, writes return `Poll::Pending`.
328+
budget: usize,
329+
}
330+
331+
impl<W> Throttled<W> {
332+
fn new(inner: W) -> Self {
333+
Self {
334+
inner,
335+
budget: usize::MAX,
336+
}
337+
}
338+
339+
fn set_budget(&mut self, n: usize) {
340+
self.budget = n;
341+
}
342+
}
343+
344+
impl<W: AsyncWrite + Unpin> AsyncWrite for Throttled<W> {
345+
fn poll_write(
346+
mut self: Pin<&mut Self>,
347+
cx: &mut Context<'_>,
348+
buf: &[u8],
349+
) -> Poll<std::io::Result<usize>> {
350+
// No budget left this poll. Return "not ready" and ask to
351+
// be polled again later.
352+
if self.budget == 0 {
353+
cx.waker().wake_by_ref();
354+
return Poll::Pending;
355+
}
356+
let n = buf.len().min(self.budget);
357+
self.budget -= n;
358+
// Delegate a write of the first `n` bytes to the inner
359+
// writer.
360+
Pin::new(&mut self.inner).poll_write(cx, &buf[..n])
361+
}
362+
363+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
364+
// Delegate to `inner` for flushing.
365+
Pin::new(&mut self.inner).poll_flush(cx)
366+
}
367+
368+
fn poll_shutdown(
369+
mut self: Pin<&mut Self>,
370+
cx: &mut Context<'_>,
371+
) -> Poll<std::io::Result<()>> {
372+
// Delegate to `inner` (ensure resources are released and
373+
// `EOF` is signaled downstream).
374+
Pin::new(&mut self.inner).poll_shutdown(cx)
375+
}
376+
}
377+
378+
#[tokio::test]
379+
#[allow(clippy::disallowed_methods)]
380+
async fn test_writer_cancellation_resume() {
381+
let (a, b) = tokio::io::duplex(4096);
382+
let (r, _wu) = tokio::io::split(a);
383+
let (_ru, w) = tokio::io::split(b);
384+
385+
let w = Throttled::new(w);
386+
// 256 bytes, all = 0x2A ('*'), "the answer"
387+
let body = Bytes::from_static(&[42u8; 256]);
388+
let mut reader = FrameReader::new(r, 1024 * 1024);
389+
let mut fw = FrameWrite::new(w, body.clone());
390+
391+
// Allow only the 8-byte length to be written, then cancel.
392+
fw.writer.set_budget(8);
393+
let fut = fw.send();
394+
tokio::select! {
395+
_ = fut => panic!("send unexpectedly completed"),
396+
_ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {}
397+
}
398+
// The `fut` is dropped here i.e. "cancellation".
399+
assert!(
400+
tokio::time::timeout(std::time::Duration::from_millis(20), async {
401+
reader.next().await
402+
})
403+
.await
404+
.is_err(),
405+
"a full frame isn't available yet, so reader.next().await should block"
406+
);
407+
408+
// Now allow the remaining body to flush and complete the
409+
// frame.
410+
fw.writer.set_budget(usize::MAX);
411+
fw.send().await.unwrap();
412+
let mut w = fw.complete();
413+
let got = reader.next().await.unwrap().unwrap();
414+
assert_eq!(got, body);
415+
416+
// Shutdown and test for EOF on boundary.
417+
w.shutdown().await.unwrap();
418+
assert!(reader.next().await.unwrap().is_none());
419+
}
420+
421+
// todo: frame size
312422
}

0 commit comments

Comments
 (0)