diff --git a/Cargo.toml b/Cargo.toml
index 8334230937..9688c7be9e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -66,6 +66,7 @@ tokio = { version = "1", features = [
 ] }
 tokio-test = "0.4"
 tokio-util = "0.7.10"
+tracing-subscriber = "0.3"
 
 [features]
 # Nothing by default
@@ -239,6 +240,11 @@ name = "integration"
 path = "tests/integration.rs"
 required-features = ["full"]
 
+[[test]]
+name = "ready_stream"
+path = "tests/ready_stream.rs"
+required-features = ["full", "tracing"]
+
 [[test]]
 name = "server"
 path = "tests/server.rs"
diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs
index 5daeb5ebf6..3ee88e17d9 100644
--- a/src/proto/h1/dispatch.rs
+++ b/src/proto/h1/dispatch.rs
@@ -171,8 +171,13 @@ where
         for _ in 0..16 {
             let _ = self.poll_read(cx)?;
             let _ = self.poll_write(cx)?;
-            let _ = self.poll_flush(cx)?;
+            let conn_ready = self.poll_flush(cx)?.is_ready();
+            // If we can write more body and the connection is ready, we should
+            // write again. If we return `Ready(Ok(()))` here, we will yield
+            // without a guaranteed wakeup from the write side of the connection.
+            // This would lead to a deadlock if we also don't expect reads.
+            let wants_write_again = self.can_write_again() && conn_ready;
 
             // This could happen if reading paused before blocking on IO,
             // such as getting to the end of a framed message, but then
             // writing/flushing set the state back to Init. In that case,
@@ -181,7 +186,10 @@ where
             //
             // Using this instead of task::current() and notify() inside
             // the Conn is noticeably faster in pipelined benchmarks.
-            if !self.conn.wants_read_again() {
+            let wants_read_again = self.conn.wants_read_again();
+            // If we can neither write nor read again, we yield and rely on a
+            // wakeup from the connection futures.
+            if !(wants_write_again || wants_read_again) {
                 //break;
                 return Poll::Ready(Ok(()));
             }
@@ -433,6 +441,11 @@ where
         self.conn.close_write();
     }
 
+    /// If there is pending data in `body_rx`, we can make progress writing once the connection is ready.
+    fn can_write_again(&mut self) -> bool {
+        self.body_rx.is_some()
+    }
+
     fn is_done(&self) -> bool {
         if self.is_closing {
             return true;
diff --git a/tests/ready_stream.rs b/tests/ready_stream.rs
new file mode 100644
index 0000000000..c90e2176ae
--- /dev/null
+++ b/tests/ready_stream.rs
@@ -0,0 +1,249 @@
+use http_body_util::StreamBody;
+use hyper::body::Bytes;
+use hyper::body::Frame;
+use hyper::rt::{Read, ReadBufCursor, Write};
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper::{Response, StatusCode};
+use pin_project_lite::pin_project;
+use std::convert::Infallible;
+use std::io;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+use tokio::sync::mpsc;
+use tracing::{error, info};
+
+pin_project! {
+    #[derive(Debug)]
+    pub struct TxReadyStream {
+        #[pin]
+        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+        read_buffer: Vec<u8>,
+        poll_since_write: bool,
+        flush_count: usize,
+        panic_task: Option<tokio::task::JoinHandle<()>>,
+    }
+}
+
+impl TxReadyStream {
+    fn new(
+        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+    ) -> Self {
+        Self {
+            read_rx,
+            write_tx,
+            read_buffer: Vec::new(),
+            poll_since_write: true,
+            flush_count: 0,
+            panic_task: None,
+        }
+    }
+
+    /// Create a pair of connected streams: bytes written to one can be read from the other.
+    fn new_pair() -> (Self, Self) {
+        let (s1_tx, s2_rx) = mpsc::unbounded_channel();
+        let (s2_tx, s1_rx) = mpsc::unbounded_channel();
+        let s1 = Self::new(s1_rx, s1_tx);
+        let s2 = Self::new(s2_rx, s2_tx);
+        (s1, s2)
+    }
+
+    /// Send data to the other end of the stream (it will be available for reading there).
+    fn send(&self, data: &[u8]) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
+        self.write_tx.send(data.to_vec())
+    }
+
+    /// Receive data written to this stream by the other end.
+    async fn recv(&mut self) -> Option<Vec<u8>> {
+        self.read_rx.recv().await
+    }
+}
+
+impl Read for TxReadyStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: ReadBufCursor<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        let mut this = self.as_mut().project();
+
+        // First, try to satisfy the read request from the internal buffer.
+        if !this.read_buffer.is_empty() {
+            let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
+            // Copy data from the internal buffer to the read buffer.
+            buf.put_slice(&this.read_buffer[..to_read]);
+            // Remove the consumed data from the internal buffer.
+            this.read_buffer.drain(..to_read);
+            return Poll::Ready(Ok(()));
+        }
+
+        // If the internal buffer is empty, try to get data from the channel.
+        match this.read_rx.try_recv() {
+            Ok(data) => {
+                // Copy as much data as we can fit in the buffer.
+                let to_read = std::cmp::min(data.len(), buf.remaining());
+                buf.put_slice(&data[..to_read]);
+
+                // Store any remaining data in the internal buffer for next time.
+                if to_read < data.len() {
+                    let remaining = &data[to_read..];
+                    this.read_buffer.extend_from_slice(remaining);
+                }
+                Poll::Ready(Ok(()))
+            }
+            Err(mpsc::error::TryRecvError::Empty) => {
+                match ready!(this.read_rx.poll_recv(cx)) {
+                    Some(data) => {
+                        // Copy as much data as we can fit in the buffer.
+                        let to_read = std::cmp::min(data.len(), buf.remaining());
+                        buf.put_slice(&data[..to_read]);
+
+                        // Store any remaining data in the internal buffer for next time.
+                        if to_read < data.len() {
+                            let remaining = &data[to_read..];
+                            this.read_buffer.extend_from_slice(remaining);
+                        }
+                        Poll::Ready(Ok(()))
+                    }
+                    None => Poll::Ready(Ok(())),
+                }
+            }
+            Err(mpsc::error::TryRecvError::Disconnected) => {
+                // Channel closed, return EOF.
+                Poll::Ready(Ok(()))
+            }
+        }
+    }
+}
+
+impl Write for TxReadyStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        // Refuse new data until poll_flush has completed once, forcing the
+        // dispatcher to flush between writes.
+        if !self.poll_since_write {
+            return Poll::Pending;
+        }
+        self.poll_since_write = false;
+        let this = self.project();
+        let buf = buf.to_vec();
+        let len = buf.len();
+
+        // Send data through the channel - unbounded senders are always ready.
+        match this.write_tx.send(buf) {
+            Ok(_) => Poll::Ready(Ok(len)),
+            Err(_) => {
+                error!("TxReadyStream::poll_write failed - channel closed");
+                Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "Write channel closed",
+                )))
+            }
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        self.flush_count += 1;
+        // We require two flushes to complete each chunk, simulating a success at the end of
+        // the old poll loop. After all chunks are written, we always succeed on flush to finish.
+        if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
+            // Spawn the panic task if not already spawned; it fires if this
+            // pending flush is never polled again.
+            if self.panic_task.is_none() {
+                let task = tokio::spawn(async {
+                    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+                    panic!("pending flush was never polled again - the connection deadlocked");
+                });
+                self.panic_task = Some(task);
+            }
+            return Poll::Pending;
+        }
+
+        // Abort the panic task if it exists.
+        if let Some(task) = self.panic_task.take() {
+            info!("Task polled to completion. Aborting the panic task (our waker stand-in).");
+            task.abort();
+        }
+
+        self.poll_since_write = true;
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+fn init_tracing() {
+    use std::sync::Once;
+    static INIT: Once = Once::new();
+    INIT.call_once(|| {
+        tracing_subscriber::fmt()
+            .with_max_level(tracing::Level::INFO)
+            .with_target(true)
+            .with_thread_ids(true)
+            .with_thread_names(true)
+            .init();
+    });
+}
+
+const TOTAL_CHUNKS: usize = 16;
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn body_test() {
+    init_tracing();
+    // Create a pair of connected streams.
+    let (server_stream, mut client_stream) = TxReadyStream::new_pair();
+
+    const CHUNK_SIZE: usize = 64 * 1024;
+    let mut http_builder = http1::Builder::new();
+    http_builder.max_buf_size(CHUNK_SIZE);
+    let service = service_fn(|_| async move {
+        info!(
+            "Creating payload of {} chunks of {} KiB each ({} MiB total)...",
+            TOTAL_CHUNKS,
+            CHUNK_SIZE / 1024,
+            TOTAL_CHUNKS * CHUNK_SIZE / (1024 * 1024)
+        );
+        let bytes = Bytes::from(vec![0; CHUNK_SIZE]);
+        let data = vec![bytes; TOTAL_CHUNKS];
+        let stream = futures_util::stream::iter(
+            data.into_iter()
+                .map(|b| Ok::<_, Infallible>(Frame::data(b))),
+        );
+        let body = StreamBody::new(stream);
+        info!("Server: Sending data response...");
+        Ok::<_, hyper::Error>(
+            Response::builder()
+                .status(StatusCode::OK)
+                .header("content-type", "application/octet-stream")
+                .header("content-length", (TOTAL_CHUNKS * CHUNK_SIZE).to_string())
+                .body(body)
+                .unwrap(),
+        )
+    });
+
+    let server_task = tokio::spawn(async move {
+        let conn = http_builder.serve_connection(server_stream, service);
+        if let Err(e) = conn.await {
+            error!("Server connection error: {}", e);
+        }
+    });
+
+    let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
+    client_stream.send(get_request.as_bytes()).unwrap();
+
+    info!("Client is reading response...");
+    let mut bytes_received = 0;
+    while let Some(chunk) = client_stream.recv().await {
+        bytes_received += chunk.len();
+    }
+    // Clean up.
+    server_task.abort();
+
+    info!(bytes_received, "Client done receiving bytes");
+}
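For context, a minimal, self-contained sketch of the failure mode the dispatch.rs change guards against; this is not part of the patch. `PendingOnce` is a hypothetical stand-in for a flush that is not ready while the dispatcher still has body data queued; only the `wants_write_again` idea comes from the diff above.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::task::{Context, Poll};

use futures_util::task::noop_waker;

/// Returns Pending on its first poll and never registers the waker,
/// like a flush that is not ready while body data is still queued.
struct PendingOnce {
    polled: bool,
}

impl Future for PendingOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        if std::mem::replace(&mut self.polled, true) {
            Poll::Ready(())
        } else {
            // No `_cx.waker().wake_by_ref()` here: whoever polled us
            // will never be woken on our behalf.
            Poll::Pending
        }
    }
}

fn main() {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(PendingOnce { polled: false });

    // First poll returns Pending with no wakeup arranged. An executor
    // that yields here sleeps forever - the deadlock described in the
    // dispatch.rs comments.
    assert!(fut.as_mut().poll(&mut cx).is_pending());

    // The patched loop instead polls once more while it knows it can
    // still write, so progress is made.
    assert!(fut.as_mut().poll(&mut cx).is_ready());
}
```

With the patch, the dispatcher treats `can_write_again() && conn_ready` the same way it already treated `wants_read_again`: both mean "loop and poll again instead of yielding without a registered wakeup".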