Skip to content

Commit 02dc44f

Browse files
authored
make server auto::Connection use hyper's IO instead of Tokios (#52)
1 parent a89ea05 commit 02dc44f

File tree

2 files changed

+54
-27
lines changed

2 files changed

+54
-27
lines changed

src/common/rewind.rs

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::marker::Unpin;
22
use std::{cmp, io};
33

44
use bytes::{Buf, Bytes};
5-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5+
use hyper::rt::{Read, ReadBufCursor, Write};
66

77
use std::{
88
pin::Pin,
@@ -48,21 +48,21 @@ impl<T> Rewind<T> {
4848
// }
4949
}
5050

51-
impl<T> AsyncRead for Rewind<T>
51+
impl<T> Read for Rewind<T>
5252
where
53-
T: AsyncRead + Unpin,
53+
T: Read + Unpin,
5454
{
5555
fn poll_read(
5656
mut self: Pin<&mut Self>,
5757
cx: &mut task::Context<'_>,
58-
buf: &mut ReadBuf<'_>,
58+
mut buf: ReadBufCursor<'_>,
5959
) -> Poll<io::Result<()>> {
6060
if let Some(mut prefix) = self.pre.take() {
6161
// If there are no remaining bytes, let the bytes get dropped.
6262
if !prefix.is_empty() {
63-
let copy_len = cmp::min(prefix.len(), buf.remaining());
63+
let copy_len = cmp::min(prefix.len(), remaining(&mut buf));
6464
// TODO: There should be a way to do following two lines cleaner...
65-
buf.put_slice(&prefix[..copy_len]);
65+
put_slice(&mut buf, &prefix[..copy_len]);
6666
prefix.advance(copy_len);
6767
// Put back what's left
6868
if !prefix.is_empty() {
@@ -76,9 +76,36 @@ where
7676
}
7777
}
7878

79-
impl<T> AsyncWrite for Rewind<T>
79+
fn remaining(cursor: &mut ReadBufCursor<'_>) -> usize {
80+
// SAFETY:
81+
// We do not uninitialize any set bytes.
82+
unsafe { cursor.as_mut().len() }
83+
}
84+
85+
// Copied from `ReadBufCursor::put_slice`.
86+
// If that becomes public, we could ditch this.
87+
fn put_slice(cursor: &mut ReadBufCursor<'_>, slice: &[u8]) {
88+
assert!(
89+
remaining(cursor) >= slice.len(),
90+
"buf.len() must fit in remaining()"
91+
);
92+
93+
let amt = slice.len();
94+
95+
// SAFETY:
96+
// the length is asserted above
97+
unsafe {
98+
cursor.as_mut()[..amt]
99+
.as_mut_ptr()
100+
.cast::<u8>()
101+
.copy_from_nonoverlapping(slice.as_ptr(), amt);
102+
cursor.advance(amt);
103+
}
104+
}
105+
106+
impl<T> Write for Rewind<T>
80107
where
81-
T: AsyncWrite + Unpin,
108+
T: Write + Unpin,
82109
{
83110
fn poll_write(
84111
mut self: Pin<&mut Self>,
@@ -109,10 +136,9 @@ where
109136
}
110137
}
111138

139+
/*
112140
#[cfg(test)]
113141
mod tests {
114-
// FIXME: re-implement tests with `async/await`, this import should
115-
// trigger a warning to remind us
116142
use super::Rewind;
117143
use bytes::Bytes;
118144
use tokio::io::AsyncReadExt;
@@ -159,3 +185,4 @@ mod tests {
159185
stream.read_exact(&mut buf).await.expect("read1");
160186
}
161187
}
188+
*/

src/server/conn/auto.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@ use http::{Request, Response};
1313
use http_body::Body;
1414
use hyper::{
1515
body::Incoming,
16-
rt::{bounds::Http2ServerConnExec, Timer},
16+
rt::{bounds::Http2ServerConnExec, Read, ReadBuf, Timer, Write},
1717
server::conn::{http1, http2},
1818
service::Service,
1919
};
2020
use pin_project_lite::pin_project;
21-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2221

23-
use crate::{common::rewind::Rewind, rt::TokioIo};
22+
use crate::common::rewind::Rewind;
2423

2524
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
2625

@@ -74,11 +73,10 @@ impl<E> Builder<E> {
7473
B: Body + Send + 'static,
7574
B::Data: Send,
7675
B::Error: Into<Box<dyn StdError + Send + Sync>>,
77-
I: AsyncRead + AsyncWrite + Unpin + 'static,
76+
I: Read + Write + Unpin + 'static,
7877
E: Http2ServerConnExec<S::Future, B>,
7978
{
8079
let (version, io) = read_version(io).await?;
81-
let io = TokioIo::new(io);
8280
match version {
8381
Version::H1 => self.http1.serve_connection(io, service).await?,
8482
Version::H2 => self.http2.serve_connection(io, service).await?,
@@ -98,11 +96,10 @@ impl<E> Builder<E> {
9896
B: Body + Send + 'static,
9997
B::Data: Send,
10098
B::Error: Into<Box<dyn StdError + Send + Sync>>,
101-
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
99+
I: Read + Write + Unpin + Send + 'static,
102100
E: Http2ServerConnExec<S::Future, B>,
103101
{
104102
let (version, io) = read_version(io).await?;
105-
let io = TokioIo::new(io);
106103
match version {
107104
Version::H1 => {
108105
self.http1
@@ -123,12 +120,14 @@ enum Version {
123120
}
124121
async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind<A>)>
125122
where
126-
A: AsyncRead + Unpin,
123+
A: Read + Unpin,
127124
{
128-
let mut buf = [0; 24];
125+
use std::mem::MaybeUninit;
126+
127+
let mut buf = [MaybeUninit::uninit(); 24];
129128
let (version, buf) = ReadVersion {
130129
reader: &mut reader,
131-
buf: ReadBuf::new(&mut buf),
130+
buf: ReadBuf::uninit(&mut buf),
132131
version: Version::H1,
133132
_pin: PhantomPinned,
134133
}
@@ -148,21 +147,21 @@ pin_project! {
148147

149148
impl<A> Future for ReadVersion<'_, A>
150149
where
151-
A: AsyncRead + Unpin + ?Sized,
150+
A: Read + Unpin + ?Sized,
152151
{
153152
type Output = IoResult<(Version, Vec<u8>)>;
154153

155154
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<(Version, Vec<u8>)>> {
156155
let this = self.project();
157156

158-
while this.buf.remaining() != 0 {
157+
while this.buf.filled().len() < H2_PREFACE.len() {
159158
if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] {
160159
return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec())));
161160
}
162161
// if our buffer is empty, then we need to read some data to continue.
163-
let rem = this.buf.remaining();
164-
ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?;
165-
if this.buf.remaining() == rem {
162+
let len = this.buf.filled().len();
163+
ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf.unfilled()))?;
164+
if this.buf.filled().len() == len {
166165
return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into();
167166
}
168167
}
@@ -302,7 +301,7 @@ impl<E> Http1Builder<'_, E> {
302301
B: Body + Send + 'static,
303302
B::Data: Send,
304303
B::Error: Into<Box<dyn StdError + Send + Sync>>,
305-
I: AsyncRead + AsyncWrite + Unpin + 'static,
304+
I: Read + Write + Unpin + 'static,
306305
E: Http2ServerConnExec<S::Future, B>,
307306
{
308307
self.inner.serve_connection(io, service).await
@@ -450,7 +449,7 @@ impl<E> Http2Builder<'_, E> {
450449
B: Body + Send + 'static,
451450
B::Data: Send,
452451
B::Error: Into<Box<dyn StdError + Send + Sync>>,
453-
I: AsyncRead + AsyncWrite + Unpin + 'static,
452+
I: Read + Write + Unpin + 'static,
454453
E: Http2ServerConnExec<S::Future, B>,
455454
{
456455
self.inner.serve_connection(io, service).await
@@ -562,6 +561,7 @@ mod tests {
562561
tokio::spawn(async move {
563562
loop {
564563
let (stream, _) = listener.accept().await.unwrap();
564+
let stream = TokioIo::new(stream);
565565
tokio::task::spawn(async move {
566566
let _ = auto::Builder::new(TokioExecutor::new())
567567
.serve_connection(stream, service_fn(hello))

0 commit comments

Comments
 (0)