
Commit 7ed95af

committed
tests(ready_stream): ready_stream as pathological example
1 parent f9f8f44 commit 7ed95af

File tree

2 files changed: +255 -0

Cargo.toml
tests/ready_stream.rs


Cargo.toml

Lines changed: 6 additions & 0 deletions
@@ -66,6 +66,7 @@ tokio = { version = "1", features = [
 ] }
 tokio-test = "0.4"
 tokio-util = "0.7.10"
+tracing-subscriber = "0.3"
 
 [features]
 # Nothing by default
@@ -239,6 +240,11 @@ name = "integration"
 path = "tests/integration.rs"
 required-features = ["full"]
 
+[[test]]
+name = "ready_stream"
+path = "tests/ready_stream.rs"
+required-features = ["full", "tracing"]
+
 [[test]]
 name = "server"
 path = "tests/server.rs"

tests/ready_stream.rs

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
use http_body_util::StreamBody;
use hyper::body::Bytes;
use hyper::body::Frame;
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Response, StatusCode};
use pin_project_lite::pin_project;
use std::convert::Infallible;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::sync::mpsc;
use tracing::{error, info};

pin_project! {
    #[derive(Debug)]
    pub struct TxReadyStream {
        #[pin]
        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
        write_tx: mpsc::UnboundedSender<Vec<u8>>,
        read_buffer: Vec<u8>,
        // Set by poll_flush; while false, poll_write refuses further writes.
        poll_since_write: bool,
        flush_count: usize,
        // Task spawned while a flush is left pending; aborted once the flush completes (see poll_flush).
        panic_task: Option<tokio::task::JoinHandle<()>>,
    }
}
impl TxReadyStream {
    fn new(
        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
        write_tx: mpsc::UnboundedSender<Vec<u8>>,
    ) -> Self {
        Self {
            read_rx,
            write_tx,
            read_buffer: Vec::new(),
            poll_since_write: true,
            flush_count: 0,
            panic_task: None,
        }
    }

    /// Create a new pair of connected streams: data sent on one is readable on the other.
    fn new_pair() -> (Self, Self) {
        let (s1_tx, s2_rx) = mpsc::unbounded_channel();
        let (s2_tx, s1_rx) = mpsc::unbounded_channel();
        let s1 = Self::new(s1_rx, s1_tx);
        let s2 = Self::new(s2_rx, s2_tx);
        (s1, s2)
    }

    /// Send data to the other end of the stream (it becomes available for reading on the peer).
    fn send(&self, data: &[u8]) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
        self.write_tx.send(data.to_vec())
    }

    /// Receive data written to this stream by the other end.
    async fn recv(&mut self) -> Option<Vec<u8>> {
        self.read_rx.recv().await
    }
}
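As a point of reference, the pair behaves like an in-memory duplex: whatever one half sends becomes readable on the other half. A minimal sketch of that contract (not part of this commit; it assumes it sits in the same file, and the test name is made up):

#[tokio::test]
async fn pair_round_trip() {
    let (a, mut b) = TxReadyStream::new_pair();
    a.send(b"ping").unwrap();
    // Bytes sent on `a` arrive on `b`'s receiving half.
    assert_eq!(b.recv().await.as_deref(), Some(&b"ping"[..]));
}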
impl Read for TxReadyStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        let mut this = self.as_mut().project();

        // First, try to satisfy the read request from the internal buffer
        if !this.read_buffer.is_empty() {
            let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
            // Copy data from internal buffer to the read buffer
            buf.put_slice(&this.read_buffer[..to_read]);
            // Remove the consumed data from the internal buffer
            this.read_buffer.drain(..to_read);
            return Poll::Ready(Ok(()));
        }

        // If internal buffer is empty, try to get data from the channel
        match this.read_rx.try_recv() {
            Ok(data) => {
                // Copy as much data as we can fit in the buffer
                let to_read = std::cmp::min(data.len(), buf.remaining());
                buf.put_slice(&data[..to_read]);

                // Store any remaining data in the internal buffer for next time
                if to_read < data.len() {
                    let remaining = &data[to_read..];
                    this.read_buffer.extend_from_slice(remaining);
                }
                Poll::Ready(Ok(()))
            }
            Err(mpsc::error::TryRecvError::Empty) => {
                match ready!(this.read_rx.poll_recv(cx)) {
                    Some(data) => {
                        // Copy as much data as we can fit in the buffer
                        let to_read = std::cmp::min(data.len(), buf.remaining());
                        buf.put_slice(&data[..to_read]);

                        // Store any remaining data in the internal buffer for next time
                        if to_read < data.len() {
                            let remaining = &data[to_read..];
                            this.read_buffer.extend_from_slice(remaining);
                        }
                        Poll::Ready(Ok(()))
                    }
                    None => Poll::Ready(Ok(())),
                }
            }
            Err(mpsc::error::TryRecvError::Disconnected) => {
                // Channel closed, return EOF
                Poll::Ready(Ok(()))
            }
        }
    }
}
impl Write for TxReadyStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Refuse to write again until poll_flush has completed (no waker is registered).
        if !self.poll_since_write {
            return Poll::Pending;
        }
        self.poll_since_write = false;
        let this = self.project();
        let buf = buf.to_vec();
        let len = buf.len();

        // Send data through the channel - this should always be ready for unbounded channels
        match this.write_tx.send(buf) {
            Ok(_) => Poll::Ready(Ok(len)),
            Err(_) => {
                error!("TxReadyStream::poll_write failed - channel closed");
                Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "Write channel closed",
                )))
            }
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.flush_count += 1;
        // We require two flushes to complete each chunk, simulating a success at the end of the old
        // poll loop. After all chunks are written, we always succeed on flush to allow for finish.
        if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
            // Spawn panic task if not already spawned
            if self.panic_task.is_none() {
                let task = tokio::spawn(async {
                    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
                });
                self.panic_task = Some(task);
            }
            return Poll::Pending;
        }

        // Abort the panic task if it exists
        if let Some(task) = self.panic_task.take() {
            info!("Task polled to completion. Aborting panic (aka waker stand-in task).");
            task.abort();
        }

        self.poll_since_write = true;
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
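The Write impl above is what makes the stream pathological: after a successful poll_write, further writes return Pending until a flush completes, and every odd-numbered flush (until all chunks are out) is itself left Pending without registering the waker. A minimal sketch of that contract (hypothetical, not part of the commit; it assumes it sits in the same file so the hyper::rt::Write trait is in scope, and it uses futures_util's noop_waker helper):

#[tokio::test]
async fn flush_requires_two_polls() {
    let waker = futures_util::task::noop_waker();
    let mut cx = Context::from_waker(&waker);
    let (mut s, _peer) = TxReadyStream::new_pair();

    // The first write goes through (poll_since_write starts out true)...
    assert!(matches!(
        Pin::new(&mut s).poll_write(&mut cx, b"x"),
        Poll::Ready(Ok(1))
    ));
    // ...but a second write is refused until a flush has completed.
    assert!(Pin::new(&mut s).poll_write(&mut cx, b"y").is_pending());
    // The first flush is left Pending (and no waker is registered)...
    assert!(Pin::new(&mut s).poll_flush(&mut cx).is_pending());
    // ...while the second flush succeeds and re-arms poll_write.
    assert!(Pin::new(&mut s).poll_flush(&mut cx).is_ready());
}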
fn init_tracing() {
    use std::sync::Once;
    static INIT: Once = Once::new();
    INIT.call_once(|| {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::INFO)
            .with_target(true)
            .with_thread_ids(true)
            .with_thread_names(true)
            .init();
    });
}

const TOTAL_CHUNKS: usize = 16;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn body_test() {
    init_tracing();
    // Create a pair of connected streams
    let (server_stream, mut client_stream) = TxReadyStream::new_pair();

    const CHUNK_SIZE: usize = 64 * 1024;
    let mut http_builder = http1::Builder::new();
    http_builder.max_buf_size(CHUNK_SIZE);
    let service = service_fn(|_| async move {
        info!(
            "Creating payload of {} chunks of {} KiB each ({} MiB total)...",
            TOTAL_CHUNKS,
            CHUNK_SIZE / 1024,
            TOTAL_CHUNKS * CHUNK_SIZE / (1024 * 1024)
        );
        let bytes = Bytes::from(vec![0; CHUNK_SIZE]);
        let data = vec![bytes.clone(); TOTAL_CHUNKS];
        let stream = futures_util::stream::iter(
            data.into_iter()
                .map(|b| Ok::<_, Infallible>(Frame::data(b))),
        );
        let body = StreamBody::new(stream);
        info!("Server: Sending data response...");
        Ok::<_, hyper::Error>(
            Response::builder()
                .status(StatusCode::OK)
                .header("content-type", "application/octet-stream")
                .header("content-length", (TOTAL_CHUNKS * CHUNK_SIZE).to_string())
                .body(body)
                .unwrap(),
        )
    });

    let server_task = tokio::spawn(async move {
        let conn = http_builder.serve_connection(server_stream, service);
        if let Err(e) = conn.await {
            error!("Server connection error: {}", e);
        }
    });

    let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
    client_stream.send(get_request.as_bytes()).unwrap();

    info!("Client is reading response...");
    let mut bytes_received = 0;
    while let Some(chunk) = client_stream.recv().await {
        bytes_received += chunk.len();
    }
    // Clean up
    server_task.abort();

    info!(bytes_received, "Client done receiving bytes");
}

0 commit comments
