Skip to content

Commit cab5441

Browse files
authored
Merge pull request #77 from Fishrock123/fix-h1-connection-pooling
fix: properly check if connections can be recycled
2 parents e62f1ef + 70d1805 commit cab5441

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

src/h1/tcp.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,15 @@ impl Manager<TcpStream, std::io::Error> for TcpConnection {
6363

6464
async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
6565
let mut buf = [0; 4];
66-
conn.peek(&mut buf[..]).await?;
66+
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
67+
match Pin::new(conn).poll_read(&mut cx, &mut buf) {
68+
Poll::Ready(Err(error)) => Err(error),
69+
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(
70+
std::io::ErrorKind::UnexpectedEof,
71+
"connection appeared to be closed (EoF)",
72+
)),
73+
_ => Ok(()),
74+
}?;
6775
Ok(())
6876
}
6977
}

src/h1/tls.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,16 @@ impl Manager<TlsStream<TcpStream>, Error> for TlsConnection {
7676

7777
async fn recycle(&self, conn: &mut TlsStream<TcpStream>) -> RecycleResult<Error> {
7878
let mut buf = [0; 4];
79-
conn.get_ref()
80-
.peek(&mut buf[..])
81-
.await
82-
.map_err(Error::from)?;
79+
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
80+
match Pin::new(conn).poll_read(&mut cx, &mut buf) {
81+
Poll::Ready(Err(error)) => Err(error),
82+
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(
83+
std::io::ErrorKind::UnexpectedEof,
84+
"connection appeared to be closed (EoF)",
85+
)),
86+
_ => Ok(()),
87+
}
88+
.map_err(Error::from)?;
8389
Ok(())
8490
}
8591
}

tests/test.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,17 @@ SOFTWARE.
135135

136136
Ok(())
137137
}
138+
139+
#[atest]
140+
async fn keep_alive() {
141+
let _mock_guard = mockito::mock("GET", "/report")
142+
.with_status(200)
143+
.expect_at_least(2)
144+
.create();
145+
146+
let client = DefaultClient::new();
147+
let url: Url = format!("{}/report", mockito::server_url()).parse().unwrap();
148+
let req = Request::new(http_types::Method::Get, url);
149+
client.send(req.clone()).await.unwrap();
150+
client.send(req.clone()).await.unwrap();
151+
}

0 commit comments

Comments
 (0)