Skip to content

Commit 0c3360b

Browse files
authored
fix(tls): avoid infinite loop (#499)
* fix(tls): avoid infinite loop * feat(tls): bump version for fix * fix(tls): handle eof * feat(tls,io): use futures-rustls and fix AsyncStream * fix(tls): handle UnexpectedEof from rustls * feat(tls): remove Unpin bound * feat: update version in the workspace
1 parent 10158d5 commit 0c3360b

File tree

9 files changed

+277
-570
lines changed

9 files changed

+277
-570
lines changed

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ compio-driver = { path = "./compio-driver", version = "0.9.0", default-features
3030
compio-runtime = { path = "./compio-runtime", version = "0.9.0" }
3131
compio-macros = { path = "./compio-macros", version = "0.1.2" }
3232
compio-fs = { path = "./compio-fs", version = "0.9.0" }
33-
compio-io = { path = "./compio-io", version = "0.8.0" }
33+
compio-io = { path = "./compio-io", version = "0.8.2" }
3434
compio-net = { path = "./compio-net", version = "0.9.0" }
3535
compio-signal = { path = "./compio-signal", version = "0.7.0" }
3636
compio-dispatcher = { path = "./compio-dispatcher", version = "0.8.0" }
3737
compio-log = { path = "./compio-log", version = "0.1.0" }
38-
compio-tls = { path = "./compio-tls", version = "0.7.0", default-features = false }
38+
compio-tls = { path = "./compio-tls", version = "0.7.1", default-features = false }
3939
compio-process = { path = "./compio-process", version = "0.6.0" }
4040
compio-quic = { path = "./compio-quic", version = "0.5.0", default-features = false }
4141

@@ -46,6 +46,7 @@ criterion = "0.7.0"
4646
crossbeam-queue = "0.3.8"
4747
flume = { version = "0.11.0", default-features = false }
4848
futures-channel = "0.3.29"
49+
futures-rustls = { version = "0.26.0", default-features = false }
4950
futures-util = "0.3.29"
5051
libc = "0.2.164"
5152
nix = "0.30.1"

compio-io/Cargo.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "compio-io"
3-
version = "0.8.1"
3+
version = "0.8.2"
44
description = "IO traits for completion based async IO"
55
categories = ["asynchronous"]
66
keywords = ["async", "io"]
@@ -15,7 +15,6 @@ compio-buf = { workspace = true, features = ["arrayvec", "bytes"] }
1515
futures-util = { workspace = true, features = ["sink"] }
1616
paste = { workspace = true }
1717
thiserror = { workspace = true, optional = true }
18-
pin-project-lite = { workspace = true, optional = true }
1918
serde = { version = "1.0.219", optional = true }
2019
serde_json = { version = "1.0.140", optional = true }
2120

@@ -29,7 +28,7 @@ futures-executor = "0.3.30"
2928

3029
[features]
3130
default = []
32-
compat = ["dep:pin-project-lite", "futures-util/io"]
31+
compat = ["futures-util/io"]
3332

3433
# Codecs
3534
# Serde json codec

compio-io/src/compat.rs

Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
//! Compat wrappers for interop with other crates.
22
33
use std::{
4+
fmt::Debug,
45
io::{self, BufRead, Read, Write},
56
mem::MaybeUninit,
67
pin::Pin,
78
task::{Context, Poll},
89
};
910

1011
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
11-
use pin_project_lite::pin_project;
1212

1313
use crate::{PinBoxFuture, buffer::Buffer, util::DEFAULT_BUF_SIZE};
1414

@@ -176,15 +176,14 @@ impl<S: crate::AsyncWrite> SyncStream<S> {
176176
}
177177
}
178178

179-
pin_project! {
180-
/// A stream wrapper for [`futures_util::io`] traits.
181-
pub struct AsyncStream<S> {
182-
#[pin]
183-
inner: SyncStream<S>,
184-
read_future: Option<PinBoxFuture<io::Result<usize>>>,
185-
write_future: Option<PinBoxFuture<io::Result<usize>>>,
186-
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
187-
}
179+
/// A stream wrapper for [`futures_util::io`] traits.
180+
pub struct AsyncStream<S> {
181+
// The futures keep the reference to the inner stream, so we need to pin
182+
// the inner stream to make sure the reference is valid.
183+
inner: Pin<Box<SyncStream<S>>>,
184+
read_future: Option<PinBoxFuture<io::Result<usize>>>,
185+
write_future: Option<PinBoxFuture<io::Result<usize>>>,
186+
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
188187
}
189188

190189
impl<S> AsyncStream<S> {
@@ -200,7 +199,7 @@ impl<S> AsyncStream<S> {
200199

201200
fn new_impl(inner: SyncStream<S>) -> Self {
202201
Self {
203-
inner,
202+
inner: Box::pin(inner),
204203
read_future: None,
205204
write_future: None,
206205
shutdown_future: None,
@@ -253,20 +252,18 @@ macro_rules! poll_future_would_block {
253252

254253
impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
255254
fn poll_read(
256-
self: Pin<&mut Self>,
255+
mut self: Pin<&mut Self>,
257256
cx: &mut Context<'_>,
258257
buf: &mut [u8],
259258
) -> Poll<io::Result<usize>> {
260-
let this = self.project();
261259
// Safety:
262260
// - The futures won't live longer than the stream.
263-
// - `self` is pinned.
264-
// - The inner stream won't be moved.
261+
// - The inner stream is pinned.
265262
let inner: &'static mut SyncStream<S> =
266-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
263+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
267264

268265
poll_future_would_block!(
269-
this.read_future,
266+
self.read_future,
270267
cx,
271268
inner.fill_read_buf(),
272269
io::Read::read(inner, buf)
@@ -279,16 +276,14 @@ impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
279276
///
280277
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
281278
pub fn poll_read_uninit(
282-
self: Pin<&mut Self>,
279+
mut self: Pin<&mut Self>,
283280
cx: &mut Context<'_>,
284281
buf: &mut [MaybeUninit<u8>],
285282
) -> Poll<io::Result<usize>> {
286-
let this = self.project();
287-
288283
let inner: &'static mut SyncStream<S> =
289-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
284+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
290285
poll_future_would_block!(
291-
this.read_future,
286+
self.read_future,
292287
cx,
293288
inner.fill_read_buf(),
294289
inner.read_buf_uninit(buf)
@@ -297,79 +292,75 @@ impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
297292
}
298293

299294
impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
300-
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
301-
let this = self.project();
302-
295+
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
303296
let inner: &'static mut SyncStream<S> =
304-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
297+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
305298
poll_future_would_block!(
306-
this.read_future,
299+
self.read_future,
307300
cx,
308301
inner.fill_read_buf(),
309302
// Safety: anyway the slice won't be used after free.
310303
io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
311304
)
312305
}
313306

314-
fn consume(self: Pin<&mut Self>, amt: usize) {
315-
let this = self.project();
316-
317-
let inner: &'static mut SyncStream<S> =
318-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
319-
inner.consume(amt)
307+
fn consume(mut self: Pin<&mut Self>, amt: usize) {
308+
unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) }
320309
}
321310
}
322311

323312
impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
324313
fn poll_write(
325-
self: Pin<&mut Self>,
314+
mut self: Pin<&mut Self>,
326315
cx: &mut Context<'_>,
327316
buf: &[u8],
328317
) -> Poll<io::Result<usize>> {
329-
let this = self.project();
330-
331-
if this.shutdown_future.is_some() {
332-
debug_assert!(this.write_future.is_none());
318+
if self.shutdown_future.is_some() {
319+
debug_assert!(self.write_future.is_none());
333320
return Poll::Pending;
334321
}
335322

336323
let inner: &'static mut SyncStream<S> =
337-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
324+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
338325
poll_future_would_block!(
339-
this.write_future,
326+
self.write_future,
340327
cx,
341328
inner.flush_write_buf(),
342329
io::Write::write(inner, buf)
343330
)
344331
}
345332

346-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
347-
let this = self.project();
348-
349-
if this.shutdown_future.is_some() {
350-
debug_assert!(this.write_future.is_none());
333+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
334+
if self.shutdown_future.is_some() {
335+
debug_assert!(self.write_future.is_none());
351336
return Poll::Pending;
352337
}
353338

354339
let inner: &'static mut SyncStream<S> =
355-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
356-
let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
340+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
341+
let res = poll_future!(self.write_future, cx, inner.flush_write_buf());
357342
Poll::Ready(res.map(|_| ()))
358343
}
359344

360-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
361-
let this = self.project();
362-
345+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
363346
// Avoid shutdown on flush because the inner buffer might be passed to the
364347
// driver.
365-
if this.write_future.is_some() {
366-
debug_assert!(this.shutdown_future.is_none());
348+
if self.write_future.is_some() {
349+
debug_assert!(self.shutdown_future.is_none());
367350
return Poll::Pending;
368351
}
369352

370353
let inner: &'static mut SyncStream<S> =
371-
unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) };
372-
let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
354+
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
355+
let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown());
373356
Poll::Ready(res)
374357
}
375358
}
359+
360+
impl<S: Debug> Debug for AsyncStream<S> {
361+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362+
f.debug_struct("AsyncStream")
363+
.field("inner", &self.inner)
364+
.finish_non_exhaustive()
365+
}
366+
}

compio-tls/Cargo.toml

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "compio-tls"
3-
version = "0.7.0"
3+
version = "0.7.1"
44
description = "TLS adaptor with compio"
55
categories = ["asynchronous", "network-programming"]
66
keywords = ["async", "net", "tls"]
@@ -25,6 +25,12 @@ rustls = { workspace = true, default-features = false, optional = true, features
2525
"tls12",
2626
] }
2727

28+
futures-rustls = { workspace = true, default-features = false, optional = true, features = [
29+
"logging",
30+
"tls12",
31+
] }
32+
futures-util = { workspace = true, optional = true }
33+
2834
[dev-dependencies]
2935
compio-net = { workspace = true }
3036
compio-runtime = { workspace = true }
@@ -33,14 +39,18 @@ compio-macros = { workspace = true }
3339
rustls = { workspace = true, default-features = false, features = ["ring"] }
3440
rustls-native-certs = { workspace = true }
3541

42+
futures-rustls = { workspace = true, default-features = false, features = [
43+
"ring",
44+
] }
45+
3646
[features]
3747
default = ["native-tls"]
3848
all = ["native-tls", "rustls"]
39-
rustls = ["dep:rustls"]
49+
rustls = ["dep:rustls", "dep:futures-rustls", "dep:futures-util"]
4050

41-
ring = ["rustls", "rustls/ring"]
42-
aws-lc-rs = ["rustls", "rustls/aws-lc-rs"]
43-
aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips"]
51+
ring = ["rustls", "rustls/ring", "futures-rustls/ring"]
52+
aws-lc-rs = ["rustls", "rustls/aws-lc-rs", "futures-rustls/aws-lc-rs"]
53+
aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips", "futures-rustls/fips"]
4454

4555
read_buf = ["compio-buf/read_buf", "compio-io/read_buf", "rustls?/read_buf"]
4656
nightly = ["read_buf"]

0 commit comments

Comments
 (0)