Skip to content

Commit fc61a39

Browse files
committed
fix(quic): fix quic accept issues
1 parent bfdd7e5 commit fc61a39

File tree

5 files changed

+52
-28
lines changed

5 files changed

+52
-28
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ rand = "0.8"
5858

5959
# networking
6060
quinn = "0.11.9"
61-
rcgen = "0.12"
61+
rcgen = "0.14"
6262

6363
# benchmarking & profiling
6464
criterion = { version = "0.5", features = ["async_tokio"] }

msg-transport/src/lib.rs

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,23 @@ pub trait TransportExt<A: Address>: Transport<A> {
7878
}
7979
}
8080

81-
pub struct Acceptor<'a, T, A> {
81+
pub struct Acceptor<'a, T, A>
82+
where
83+
T: Transport<A>,
84+
A: Address,
85+
{
8286
inner: &'a mut T,
87+
pending: Option<T::Accept>,
8388
_marker: PhantomData<A>,
8489
}
8590

86-
impl<'a, T, A> Acceptor<'a, T, A> {
91+
impl<'a, T, A> Acceptor<'a, T, A>
92+
where
93+
T: Transport<A>,
94+
A: Address,
95+
{
8796
fn new(inner: &'a mut T) -> Self {
88-
Self { inner, _marker: PhantomData }
97+
Self { inner, pending: None, _marker: PhantomData }
8998
}
9099
}
91100

@@ -97,13 +106,26 @@ where
97106
type Output = Result<T::Io, T::Error>;
98107

99108
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100-
match Pin::new(&mut *self.get_mut().inner).poll_accept(cx) {
101-
Poll::Ready(mut accept) => match accept.poll_unpin(cx) {
102-
Poll::Ready(Ok(output)) => Poll::Ready(Ok(output)),
103-
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
104-
Poll::Pending => Poll::Pending,
105-
},
106-
Poll::Pending => Poll::Pending,
109+
let this = self.get_mut();
110+
111+
loop {
112+
if let Some(pending) = this.pending.as_mut() {
113+
match pending.poll_unpin(cx) {
114+
Poll::Ready(res) => {
115+
this.pending = None;
116+
return Poll::Ready(res);
117+
}
118+
Poll::Pending => return Poll::Pending,
119+
}
120+
}
121+
122+
match Pin::new(&mut *this.inner).poll_accept(cx) {
123+
Poll::Ready(accept) => {
124+
this.pending = Some(accept);
125+
continue;
126+
}
127+
Poll::Pending => return Poll::Pending,
128+
}
107129
}
108130
}
109131
}

msg-transport/src/quic/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,24 +113,24 @@ impl Transport<SocketAddr> for Quic {
113113
let endpoint = if let Some(endpoint) = self.endpoint.clone() {
114114
endpoint
115115
} else {
116-
let Ok(endpoint) = self.new_endpoint(None, None) else {
116+
let Ok(mut endpoint) = self.new_endpoint(None, None) else {
117117
return async_error(Error::ClosedEndpoint);
118118
};
119119

120+
endpoint.set_default_client_config(self.config.client_config.clone());
121+
120122
self.endpoint = Some(endpoint.clone());
121123

122124
endpoint
123125
};
124126

125-
let client_config = self.config.client_config.clone();
126-
127127
Box::pin(async move {
128128
debug!(target = %addr, "Initiating connection");
129129

130130
// This `"l"` seems necessary because an empty string is an invalid domain
131131
// name. While we don't use domain names, the underlying rustls library
132132
// is based upon the assumption that we do.
133-
let connection = endpoint.connect_with(client_config, addr, "localhost")?.await?;
133+
let connection = endpoint.connect(addr, "l")?.await?;
134134

135135
debug!(target = %addr, "Connected, opening stream...");
136136

@@ -152,15 +152,15 @@ impl Transport<SocketAddr> for Quic {
152152
if let Some(ref mut incoming) = this.incoming {
153153
// Incoming channel and task are spawned, so we can poll it.
154154
match ready!(incoming.poll_recv(cx)) {
155-
Some(Ok(connecting)) => {
156-
let peer = connecting.remote_address();
155+
Some(Ok(incoming)) => {
156+
let peer = incoming.remote_address();
157157

158158
debug!("New incoming connection from {}", peer);
159159

160160
// Return a future that resolves to the output.
161161
return Poll::Ready(Box::pin(async move {
162162
debug!(client = %peer, "Accepting connection...");
163-
let connection = connecting.await?;
163+
let connection = incoming.accept()?.await?;
164164
debug!(
165165
"Accepted connection from {}, opening stream",
166166
connection.remote_address()

msg-transport/src/quic/tls.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use quinn::{
55
rustls::{
66
self, SignatureScheme,
77
client::danger::{ServerCertVerified, ServerCertVerifier},
8-
pki_types::{CertificateDer, PrivateKeyDer},
8+
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
99
},
1010
};
1111

@@ -81,7 +81,7 @@ pub(crate) fn unsafe_client_config() -> QuicClientConfig {
8181

8282
let mut rustls_config = rustls::ClientConfig::builder_with_provider(provider)
8383
.with_protocol_versions(&[&rustls::version::TLS13])
84-
.expect("aws_lc_rs provider supports TLS 1.3")
84+
.expect("ring provider supports TLS 1.3")
8585
.dangerous()
8686
.with_custom_certificate_verifier(SkipServerVerification::new())
8787
.with_no_client_auth();
@@ -100,7 +100,7 @@ pub(crate) fn tls_server_config() -> QuicServerConfig {
100100

101101
let mut rustls_config = rustls::ServerConfig::builder_with_provider(provider)
102102
.with_protocol_versions(&[&rustls::version::TLS13])
103-
.expect("aws_lc_rs provider supports TLS 1.3")
103+
.expect("ring provider supports TLS 1.3")
104104
.with_no_client_auth()
105105
.with_single_cert(cert_chain, key_der)
106106
.expect("Valid rustls config");
@@ -113,12 +113,13 @@ pub(crate) fn tls_server_config() -> QuicServerConfig {
113113

114114
/// Generates a self-signed certificate chain and private key.
115115
pub(crate) fn self_signed_certificate() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
116-
let cert = rcgen::generate_simple_self_signed(vec![]).expect("Generates valid certificate");
117-
let cert_der = cert.serialize_der().expect("Serializes certificate");
118-
let priv_key =
119-
PrivateKeyDer::try_from(cert.serialize_private_key_der()).expect("Serializes private key");
116+
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
117+
.expect("Generates valid certificate");
120118

121-
(vec![CertificateDer::from(cert_der)], priv_key)
119+
let cert_der = CertificateDer::from(cert.cert);
120+
let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
121+
122+
(vec![CertificateDer::from(cert_der)], priv_key.into())
122123
}
123124

124125
#[cfg(test)]

0 commit comments

Comments
 (0)