Skip to content

Commit d129f07

Browse files
committed
fix: properly check if connections can be recycled
Checks `poll_read`'s status directly, so that we can ensure a socket read actually happens, and then immediately continue if we get `Poll::Pending`. Related to #75
1 parent e62f1ef commit d129f07

File tree

5 files changed

+68
-5
lines changed

5 files changed

+68
-5
lines changed

src/h1/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ cfg_if::cfg_if! {
1919

2020
use super::{async_trait, Error, HttpClient, Request, Response};
2121

22+
pub(crate) mod utils;
23+
2224
mod tcp;
2325
#[cfg(any(feature = "native-tls", feature = "rustls"))]
2426
mod tls;

src/h1/tcp.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use deadpool::managed::{Manager, Object, RecycleResult};
88
use futures::io::{AsyncRead, AsyncWrite};
99
use futures::task::{Context, Poll};
1010

11+
use super::utils::PollRead;
12+
1113
#[derive(Clone, Debug)]
1214
pub(crate) struct TcpConnection {
1315
addr: SocketAddr,
@@ -63,7 +65,14 @@ impl Manager<TcpStream, std::io::Error> for TcpConnection {
6365

6466
async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
6567
let mut buf = [0; 4];
66-
conn.peek(&mut buf[..]).await?;
68+
match PollRead::new(conn, &mut buf).await {
69+
Poll::Ready(Err(error)) => Err(error),
70+
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(
71+
std::io::ErrorKind::UnexpectedEof,
72+
"connection appeared to be closed (EoF)",
73+
)),
74+
_ => Ok(()),
75+
}?;
6776
Ok(())
6877
}
6978
}

src/h1/tls.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use deadpool::managed::{Manager, Object, RecycleResult};
88
use futures::io::{AsyncRead, AsyncWrite};
99
use futures::task::{Context, Poll};
1010

11+
use super::utils::PollRead;
12+
1113
cfg_if::cfg_if! {
1214
if #[cfg(feature = "rustls")] {
1315
use async_tls::client::TlsStream;
@@ -76,10 +78,15 @@ impl Manager<TlsStream<TcpStream>, Error> for TlsConnection {
7678

7779
async fn recycle(&self, conn: &mut TlsStream<TcpStream>) -> RecycleResult<Error> {
7880
let mut buf = [0; 4];
79-
conn.get_ref()
80-
.peek(&mut buf[..])
81-
.await
82-
.map_err(Error::from)?;
81+
match PollRead::new(conn, &mut buf).await {
82+
Poll::Ready(Err(error)) => Err(error),
83+
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(
84+
std::io::ErrorKind::UnexpectedEof,
85+
"connection appeared to be closed (EoF)",
86+
)),
87+
_ => Ok(()),
88+
}
89+
.map_err(Error::from)?;
8390
Ok(())
8491
}
8592
}

src/h1/utils.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use std::io;
2+
use std::pin::Pin;
3+
4+
use futures::future::Future;
5+
use futures::io::AsyncRead;
6+
use futures::task::{Context, Poll};
7+
8+
/// Like the `futures::io::Read` future, but returning the underlying `futures::task::Poll`.
9+
#[derive(Debug)]
10+
#[must_use = "futures do nothing unless you `.await` or poll them"]
11+
pub(crate) struct PollRead<'a, R: ?Sized> {
12+
reader: &'a mut R,
13+
buf: &'a mut [u8],
14+
}
15+
16+
impl<R: ?Sized + Unpin> Unpin for PollRead<'_, R> {}
17+
18+
impl<'a, R: AsyncRead + ?Sized + Unpin> PollRead<'a, R> {
19+
pub(super) fn new(reader: &'a mut R, buf: &'a mut [u8]) -> Self {
20+
Self { reader, buf }
21+
}
22+
}
23+
24+
impl<R: AsyncRead + ?Sized + Unpin> Future for PollRead<'_, R> {
25+
type Output = Poll<io::Result<usize>>;
26+
27+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
28+
let this = &mut *self;
29+
Poll::Ready(Pin::new(&mut this.reader).poll_read(cx, this.buf))
30+
}
31+
}

tests/test.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,17 @@ SOFTWARE.
135135

136136
Ok(())
137137
}
138+
139+
#[atest]
140+
async fn keep_alive() {
141+
let _mock_guard = mockito::mock("GET", "/report")
142+
.with_status(200)
143+
.expect_at_least(2)
144+
.create();
145+
146+
let client = DefaultClient::new();
147+
let url: Url = format!("{}/report", mockito::server_url()).parse().unwrap();
148+
let req = Request::new(http_types::Method::Get, url);
149+
client.send(req.clone()).await.unwrap();
150+
client.send(req.clone()).await.unwrap();
151+
}

0 commit comments

Comments
 (0)