Skip to content

Commit 1cca93b

Browse files
authored
fix(s2n-quic-dc): send connection close frames with pruned streams (#2831)
1 parent 9b89de7 commit 1cca93b

File tree

24 files changed

+713
-181
lines changed

24 files changed

+713
-181
lines changed

dc/s2n-quic-dc/src/clock.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use core::{fmt, pin::Pin, task::Poll, time::Duration};
5-
use s2n_quic_core::{ensure, time};
5+
use s2n_quic_core::{
6+
ensure, time,
7+
time::{timer, timer::Provider},
8+
};
69
use tracing::trace;
710

811
#[macro_use]
@@ -62,7 +65,7 @@ pub trait Sleep: Clock + core::future::Future<Output = ()> {
6265

6366
pub struct Timer {
6467
/// The `Instant` at which the timer should expire
65-
target: Option<Timestamp>,
68+
target: timer::Timer,
6669
/// The handle to the timer entry in the tokio runtime
6770
sleep: Pin<Box<dyn Sleep>>,
6871
}
@@ -88,16 +91,18 @@ impl Timer {
8891
#[inline]
8992
pub fn new_with_timeout(clock: &dyn Clock, timeout: Duration) -> Self {
9093
let (sleep, target) = clock.sleep(timeout);
94+
let mut timer = timer::Timer::default();
95+
timer.set(target);
9196
Self {
92-
target: Some(target),
97+
target: timer,
9398
sleep,
9499
}
95100
}
96101

97102
#[inline]
98103
pub fn cancel(&mut self) {
99104
trace!(cancel = ?self.target);
100-
self.target = None;
105+
self.target.cancel();
101106
}
102107

103108
pub async fn sleep(&mut self, target: Timestamp) {
@@ -110,13 +115,13 @@ impl Timer {
110115
impl time::clock::Timer for Timer {
111116
#[inline]
112117
fn poll_ready(&mut self, cx: &mut core::task::Context) -> Poll<()> {
113-
ensure!(self.target.is_some(), Poll::Ready(()));
118+
ensure!(self.target.is_armed(), Poll::Ready(()));
114119

115120
let res = self.sleep.as_mut().poll(cx);
116121

117122
if res.is_ready() {
118123
// clear the target after it fires, otherwise we'll endlessly wake up the task
119-
self.target = None;
124+
self.target.cancel();
120125
}
121126

122127
res
@@ -125,9 +130,15 @@ impl time::clock::Timer for Timer {
125130
#[inline]
126131
fn update(&mut self, target: Timestamp) {
127132
// no need to update if it hasn't changed
128-
ensure!(self.target != Some(target));
133+
ensure!(self.target.next_expiration() != Some(target));
129134

130135
self.sleep.as_mut().update(target);
131-
self.target = Some(target);
136+
self.target.set(target);
137+
}
138+
}
139+
140+
impl timer::Provider for Timer {
141+
fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
142+
self.target.timers(query)
132143
}
133144
}

dc/s2n-quic-dc/src/packet/control/decoder.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,13 @@ pub struct ControlFramesMut<'a> {
279279
buffer: &'a mut [u8],
280280
}
281281

282+
impl<'a> ControlFramesMut<'a> {
283+
#[inline]
284+
pub(crate) fn new(buffer: &'a mut [u8]) -> Self {
285+
Self { buffer }
286+
}
287+
}
288+
282289
impl<'a> Iterator for ControlFramesMut<'a> {
283290
type Item = Result<FrameMut<'a>, s2n_codec::DecoderError>;
284291

dc/s2n-quic-dc/src/packet/stream/decoder.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
credentials::Credentials,
66
crypto,
77
packet::{
8+
control::decoder::ControlFramesMut,
89
stream::{self, RelativeRetransmissionOffset, Tag},
910
WireVersion,
1011
},
@@ -202,6 +203,11 @@ impl Packet<'_> {
202203
self.control_data.get(self.header)
203204
}
204205

206+
#[inline]
207+
pub fn control_frames_mut(&mut self) -> ControlFramesMut<'_> {
208+
ControlFramesMut::new(self.control_data.get_mut(self.header))
209+
}
210+
205211
#[inline]
206212
pub fn header(&self) -> &[u8] {
207213
self.header

dc/s2n-quic-dc/src/psk/io.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,19 @@ impl Server {
6565

6666
let server = s2n_quic::Server::builder().with_io(io)?;
6767

68+
let initial_max_data = builder.initial_data_window.unwrap_or_else(|| {
69+
// default to only receive 10 packet worth before the application accepts the connection
70+
builder.mtu as u64 * 10
71+
});
72+
6873
let connection_limits = s2n_quic::provider::limits::Limits::new()
6974
.with_max_idle_timeout(builder.max_idle_timeout)?
70-
.with_data_window(builder.data_window)?
75+
.with_data_window(initial_max_data)?
76+
// After the connection is established we increase the data window to the configured value
7177
.with_bidirectional_local_data_window(builder.data_window)?
72-
.with_bidirectional_remote_data_window(builder.data_window)?
78+
.with_bidirectional_remote_data_window(initial_max_data)?
7379
.with_initial_round_trip_time(DEFAULT_INITIAL_RTT)?;
80+
7481
let event = (ConfirmComplete, subscriber);
7582

7683
let server = server
@@ -191,6 +198,7 @@ impl Client {
191198
.with_bidirectional_local_data_window(builder.data_window)?
192199
.with_bidirectional_remote_data_window(builder.data_window)?
193200
.with_initial_round_trip_time(DEFAULT_INITIAL_RTT)?;
201+
194202
let event = (ConfirmComplete, subscriber);
195203

196204
let client = client

dc/s2n-quic-dc/src/psk/server/builder.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct Builder<
1818
#[allow(dead_code)]
1919
pub(crate) event_subscriber: Event,
2020
pub(crate) data_window: u64,
21+
pub(crate) initial_data_window: Option<u64>,
2122
pub(crate) mtu: u16,
2223
pub(crate) max_idle_timeout: Duration,
2324
pub(crate) pto_jitter_percentage: u8,
@@ -28,6 +29,7 @@ impl Default for Builder<s2n_quic::provider::event::default::Subscriber> {
2829
Self {
2930
event_subscriber: Default::default(),
3031
data_window: DEFAULT_MAX_DATA,
32+
initial_data_window: None,
3133
mtu: DEFAULT_MTU,
3234
max_idle_timeout: DEFAULT_IDLE_TIMEOUT,
3335
pto_jitter_percentage: DEFAULT_PTO_JITTER_PERCENTAGE,
@@ -44,6 +46,7 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
4446
Builder {
4547
event_subscriber,
4648
data_window: self.data_window,
49+
initial_data_window: self.initial_data_window,
4750
mtu: self.mtu,
4851
max_idle_timeout: self.max_idle_timeout,
4952
pto_jitter_percentage: self.pto_jitter_percentage,
@@ -56,6 +59,15 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
5659
self
5760
}
5861

62+
/// Sets the initial amount of data that the peer is allowed to send before the application
63+
/// accepts the stream
64+
///
65+
/// This defaults to 10x the MTU if not set.
66+
pub fn with_initial_data_window(mut self, initial_data_window: u64) -> Self {
67+
self.initial_data_window = Some(initial_data_window);
68+
self
69+
}
70+
5971
/// Sets the largest maximum transmission unit (MTU) that will be used for transmission
6072
pub fn with_mtu(mut self, mtu: u16) -> Self {
6173
self.mtu = mtu;

dc/s2n-quic-dc/src/socket/recv/router.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ use s2n_quic_core::{
1313
varint::VarInt,
1414
};
1515

16+
// Use `debug` logging for unhandled packets in non-test builds to reduce noise
17+
#[cfg(not(test))]
18+
use tracing::debug as warn;
19+
#[cfg(test)]
20+
use tracing::warn;
21+
1622
mod with_map;
1723
mod zero_router;
1824

@@ -139,7 +145,7 @@ pub trait Router {
139145
credentials: Credentials,
140146
segment: descriptor::Filled,
141147
) {
142-
tracing::warn!(
148+
warn!(
143149
unhandled_packet = "control",
144150
?tag,
145151
?id,
@@ -168,7 +174,7 @@ pub trait Router {
168174
credentials: Credentials,
169175
segment: descriptor::Filled,
170176
) {
171-
tracing::warn!(
177+
warn!(
172178
unhandled_packet = "stream",
173179
?tag,
174180
?id,
@@ -196,7 +202,7 @@ pub trait Router {
196202
credentials: Credentials,
197203
segment: descriptor::Filled,
198204
) {
199-
tracing::warn!(
205+
warn!(
200206
unhandled_packet = "datagram",
201207
?tag,
202208
?credentials,
@@ -221,7 +227,7 @@ pub trait Router {
221227
credentials: credentials::Id,
222228
segment: descriptor::Filled,
223229
) {
224-
tracing::warn!(
230+
warn!(
225231
unhandled_packet = "stale_key",
226232
?queue_id,
227233
?credentials,
@@ -246,7 +252,7 @@ pub trait Router {
246252
credentials: credentials::Id,
247253
segment: descriptor::Filled,
248254
) {
249-
tracing::warn!(
255+
warn!(
250256
unhandled_packet = "replay_detected",
251257
?queue_id,
252258
?credentials,
@@ -270,7 +276,7 @@ pub trait Router {
270276
credentials: credentials::Id,
271277
segment: descriptor::Filled,
272278
) {
273-
tracing::warn!(
279+
warn!(
274280
unhandled_packet = "unknown_path_secret",
275281
?queue_id,
276282
?credentials,
@@ -281,7 +287,7 @@ pub trait Router {
281287

282288
#[inline]
283289
fn on_unhandled_packet(&mut self, remote_address: SocketAddress, packet: packet::Packet) {
284-
tracing::warn!(unhandled_packet = ?packet, ?remote_address)
290+
warn!(unhandled_packet = ?packet, ?remote_address)
285291
}
286292

287293
#[inline]
@@ -291,7 +297,7 @@ pub trait Router {
291297
remote_address: SocketAddress,
292298
segment: descriptor::Filled,
293299
) {
294-
tracing::warn!(
300+
warn!(
295301
?error,
296302
?remote_address,
297303
packet_len = segment.len(),

dc/s2n-quic-dc/src/stream/application.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ where
9999
reason,
100100
});
101101

102-
// TODO emit event
102+
self.shared.receiver.on_prune();
103+
self.shared.sender.on_prune();
103104
}
104105
}
105106

@@ -170,6 +171,14 @@ where
170171
self.write.write_from(buf).await
171172
}
172173

174+
#[inline]
175+
pub async fn write_all_from(
176+
&mut self,
177+
buf: &mut impl buffer::reader::storage::Infallible,
178+
) -> io::Result<usize> {
179+
self.write.write_all_from(buf).await
180+
}
181+
173182
#[inline]
174183
pub async fn write_from_fin(
175184
&mut self,
@@ -178,6 +187,14 @@ where
178187
self.write.write_from_fin(buf).await
179188
}
180189

190+
#[inline]
191+
pub async fn write_all_from_fin(
192+
&mut self,
193+
buf: &mut impl buffer::reader::storage::Infallible,
194+
) -> io::Result<usize> {
195+
self.write.write_all_from_fin(buf).await
196+
}
197+
181198
#[inline]
182199
pub async fn read_into(
183200
&mut self,

dc/s2n-quic-dc/src/stream/endpoint.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ where
289289
let shared = shared.clone();
290290

291291
let task = async move {
292-
let mut reader = recv::worker::Worker::new(socket, shared, endpoint_type);
292+
let mut reader = recv::worker::Worker::new(socket, shared, endpoint_type, &parameters);
293293

294294
let mut prev_waker: Option<core::task::Waker> = None;
295295
core::future::poll_fn(|cx| {

dc/s2n-quic-dc/src/stream/environment/bach.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::{
1414
use bach::ext::*;
1515
use s2n_quic_platform::features;
1616
use std::{io, net::SocketAddr, sync::Arc};
17+
use tracing::{info_span, Instrument};
1718

1819
mod pool;
1920
pub mod udp;
@@ -188,8 +189,9 @@ where
188189
}
189190

190191
#[inline]
192+
#[track_caller]
191193
fn spawn_reader<F: 'static + Send + std::future::Future<Output = ()>>(&self, f: F) {
192-
self.rt.spawn(f.primary());
194+
self.rt.spawn(f.instrument(info_span!("reader")).primary());
193195
}
194196

195197
#[inline]
@@ -198,7 +200,8 @@ where
198200
}
199201

200202
#[inline]
203+
#[track_caller]
201204
fn spawn_writer<F: 'static + Send + std::future::Future<Output = ()>>(&self, f: F) {
202-
self.rt.spawn(f.primary());
205+
self.rt.spawn(f.instrument(info_span!("writer")).primary());
203206
}
204207
}

dc/s2n-quic-dc/src/stream/recv/application.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ use crate::{
55
clock::Timer,
66
event::{self, ConnectionPublisher as _},
77
msg,
8-
stream::{recv, runtime, shared::ArcShared, socket, Actor},
8+
stream::{
9+
recv, runtime,
10+
shared::{AcceptState, ArcShared, ShutdownKind},
11+
socket, Actor,
12+
},
913
};
1014
use core::{
1115
fmt,
@@ -227,7 +231,8 @@ where
227231

228232
loop {
229233
// try to process any bytes we have in the recv buffer
230-
reader.process_recv_buffer(out_buf, shared, transport_features);
234+
// - use the `Accepted` state since this is the application interface
235+
reader.process_recv_buffer(out_buf, shared, transport_features, AcceptState::Accepted);
231236

232237
// if we still have remaining capacity in the `out_buf` make sure the reassembler is
233238
// fully drained
@@ -239,6 +244,12 @@ where
239244
if let Err(err) = reader.receiver.check_error() {
240245
self.local_state
241246
.transition(LocalState::Errored(err), &self.shared);
247+
248+
if out_buf.written_len() > 0 {
249+
// if we've written something to the buffer then return that first
250+
break;
251+
}
252+
242253
return Err(err.into()).into();
243254
}
244255

@@ -370,9 +381,13 @@ where
370381
.transition(LocalState::Drained, &self.shared);
371382

372383
// let the peer know if we shut down cleanly
373-
let is_panicking = std::thread::panicking();
384+
let kind = if std::thread::panicking() {
385+
ShutdownKind::Panicking
386+
} else {
387+
ShutdownKind::Normal
388+
};
374389

375-
self.shared.receiver.shutdown(is_panicking);
390+
self.shared.receiver.shutdown(kind);
376391
}
377392

378393
#[inline(always)]

0 commit comments

Comments
 (0)