Skip to content

Commit 94e9993

Browse files
committed
tests: ready_stream as pathological example
1 parent f9f8f44 commit 94e9993

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ tokio = { version = "1", features = [
6666
] }
6767
tokio-test = "0.4"
6868
tokio-util = "0.7.10"
69+
tracing-subscriber = "0.3"
6970

7071
[features]
7172
# Nothing by default
@@ -239,6 +240,11 @@ name = "integration"
239240
path = "tests/integration.rs"
240241
required-features = ["full"]
241242

243+
[[test]]
244+
name = "ready_stream"
245+
path = "tests/ready_stream.rs"
246+
required-features = ["full", "tracing"]
247+
242248
[[test]]
243249
name = "server"
244250
path = "tests/server.rs"

tests/ready_stream.rs

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
use http_body_util::StreamBody;
2+
use hyper::body::Bytes;
3+
use hyper::body::Frame;
4+
use hyper::rt::{Read, ReadBufCursor, Write};
5+
use hyper::server::conn::http1;
6+
use hyper::service::service_fn;
7+
use hyper::{Response, StatusCode};
8+
use pin_project_lite::pin_project;
9+
use std::convert::Infallible;
10+
use std::io;
11+
use std::pin::Pin;
12+
use std::task::{ready, Context, Poll};
13+
use tokio::sync::mpsc;
14+
use tracing::{error, info};
15+
16+
pin_project! {
17+
#[derive(Debug)]
18+
pub struct TxReadyStream {
19+
#[pin]
20+
read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
21+
write_tx: mpsc::UnboundedSender<Vec<u8>>,
22+
read_buffer: Vec<u8>,
23+
poll_since_write:bool,
24+
flush_count: usize,
25+
}
26+
}
27+
28+
impl TxReadyStream {
29+
fn new(
30+
read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
31+
write_tx: mpsc::UnboundedSender<Vec<u8>>,
32+
) -> Self {
33+
Self {
34+
read_rx,
35+
write_tx,
36+
read_buffer: Vec::new(),
37+
poll_since_write: true,
38+
flush_count: 0,
39+
}
40+
}
41+
42+
/// Create a new pair of connected ReadyStreams. Returns two streams that are connected to each other.
43+
fn new_pair() -> (Self, Self) {
44+
let (s1_tx, s2_rx) = mpsc::unbounded_channel();
45+
let (s2_tx, s1_rx) = mpsc::unbounded_channel();
46+
let s1 = Self::new(s1_rx, s1_tx);
47+
let s2 = Self::new(s2_rx, s2_tx);
48+
(s1, s2)
49+
}
50+
51+
/// Send data to the other end of the stream (this will be available for reading on the other stream)
52+
fn send(&self, data: &[u8]) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
53+
self.write_tx.send(data.to_vec())
54+
}
55+
56+
57+
/// Receive data written to this stream by the other end (async)
58+
async fn recv(&mut self) -> Option<Vec<u8>> {
59+
self.read_rx.recv().await
60+
}
61+
}
62+
63+
impl Read for TxReadyStream {
64+
fn poll_read(
65+
mut self: Pin<&mut Self>,
66+
cx: &mut Context<'_>,
67+
mut buf: ReadBufCursor<'_>,
68+
) -> Poll<io::Result<()>> {
69+
let mut this = self.as_mut().project();
70+
71+
// First, try to satisfy the read request from the internal buffer
72+
if !this.read_buffer.is_empty() {
73+
let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
74+
// Copy data from internal buffer to the read buffer
75+
buf.put_slice(&this.read_buffer[..to_read]);
76+
// Remove the consumed data from the internal buffer
77+
this.read_buffer.drain(..to_read);
78+
return Poll::Ready(Ok(()));
79+
}
80+
81+
// If internal buffer is empty, try to get data from the channel
82+
match this.read_rx.try_recv() {
83+
Ok(data) => {
84+
// Copy as much data as we can fit in the buffer
85+
let to_read = std::cmp::min(data.len(), buf.remaining());
86+
buf.put_slice(&data[..to_read]);
87+
88+
// Store any remaining data in the internal buffer for next time
89+
if to_read < data.len() {
90+
let remaining = &data[to_read..];
91+
this.read_buffer.extend_from_slice(remaining);
92+
}
93+
Poll::Ready(Ok(()))
94+
}
95+
Err(mpsc::error::TryRecvError::Empty) => {
96+
match ready!(this.read_rx.poll_recv(cx)) {
97+
Some(data) => {
98+
// Copy as much data as we can fit in the buffer
99+
let to_read = std::cmp::min(data.len(), buf.remaining());
100+
buf.put_slice(&data[..to_read]);
101+
102+
// Store any remaining data in the internal buffer for next time
103+
if to_read < data.len() {
104+
let remaining = &data[to_read..];
105+
this.read_buffer.extend_from_slice(remaining);
106+
}
107+
Poll::Ready(Ok(()))
108+
}
109+
None => Poll::Ready(Ok(())),
110+
}
111+
}
112+
Err(mpsc::error::TryRecvError::Disconnected) => {
113+
// Channel closed, return EOF
114+
Poll::Ready(Ok(()))
115+
}
116+
}
117+
}
118+
}
119+
120+
impl Write for TxReadyStream {
121+
fn poll_write(
122+
mut self: Pin<&mut Self>,
123+
_cx: &mut Context<'_>,
124+
buf: &[u8],
125+
) -> Poll<io::Result<usize>> {
126+
if !self.poll_since_write {
127+
return Poll::Pending;
128+
}
129+
self.poll_since_write = false;
130+
let this = self.project();
131+
let buf = Vec::from(&buf[..buf.len()]);
132+
let len = buf.len();
133+
134+
// Send data through the channel - this should always be ready for unbounded channels
135+
match this.write_tx.send(buf) {
136+
Ok(_) => {
137+
// Increment write count
138+
Poll::Ready(Ok(len))
139+
}
140+
Err(_) => {
141+
error!("ReadyStream::poll_write failed - channel closed");
142+
Poll::Ready(Err(io::Error::new(
143+
io::ErrorKind::BrokenPipe,
144+
"Write channel closed",
145+
)))
146+
}
147+
}
148+
}
149+
150+
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
151+
self.flush_count += 1;
152+
// We require two flushes to complete each chunk, simulating a success at the end of the old
153+
// poll loop. After all chunks are written, we always succeed on flush to allow for finish.
154+
if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
155+
return Poll::Pending;
156+
}
157+
self.poll_since_write = true;
158+
Poll::Ready(Ok(()))
159+
}
160+
161+
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162+
Poll::Ready(Ok(()))
163+
}
164+
}
165+
166+
fn init_tracing() {
167+
use std::sync::Once;
168+
static INIT: Once = Once::new();
169+
INIT.call_once(|| {
170+
tracing_subscriber::fmt()
171+
.with_max_level(tracing::Level::INFO)
172+
.with_target(true)
173+
.with_thread_ids(true)
174+
.with_thread_names(true)
175+
.init();
176+
});
177+
}
178+
179+
const TOTAL_CHUNKS: usize = 16;
180+
181+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
182+
async fn body_test() {
183+
init_tracing();
184+
// Create a pair of connected streams
185+
let (server_stream, mut client_stream) = TxReadyStream::new_pair();
186+
187+
let mut http_builder = http1::Builder::new();
188+
http_builder.max_buf_size(CHUNK_SIZE);
189+
const CHUNK_SIZE: usize = 64 * 1024;
190+
let service = service_fn(|_| async move {
191+
info!(
192+
"Creating payload of {} chunks of {} KiB each ({} MiB total)...",
193+
TOTAL_CHUNKS,
194+
CHUNK_SIZE / 1024,
195+
TOTAL_CHUNKS * CHUNK_SIZE / (1024 * 1024)
196+
);
197+
let bytes = Bytes::from(vec![0; CHUNK_SIZE]);
198+
let data = vec![bytes.clone(); TOTAL_CHUNKS];
199+
let stream = futures_util::stream::iter(
200+
data.into_iter()
201+
.map(|b| Ok::<_, Infallible>(Frame::data(b))),
202+
);
203+
let body = StreamBody::new(stream);
204+
info!("Server: Sending data response...");
205+
Ok::<_, hyper::Error>(
206+
Response::builder()
207+
.status(StatusCode::OK)
208+
.header("content-type", "application/octet-stream")
209+
.header("content-length", (TOTAL_CHUNKS * CHUNK_SIZE).to_string())
210+
.body(body)
211+
.unwrap(),
212+
)
213+
});
214+
215+
let server_task = tokio::spawn(async move {
216+
let conn = http_builder.serve_connection(server_stream, service);
217+
if let Err(e) = conn.await {
218+
error!("Server connection error: {}", e);
219+
}
220+
});
221+
222+
let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
223+
client_stream.send(get_request.as_bytes()).unwrap();
224+
225+
info!("Client is reading response...");
226+
let mut bytes_received = 0;
227+
while let Some(chunk) = client_stream.recv().await {
228+
bytes_received += chunk.len();
229+
}
230+
// Clean up
231+
server_task.abort();
232+
233+
info!(bytes_received, "Client done receiving bytes");
234+
}

0 commit comments

Comments
 (0)