Skip to content

Commit 5ed1955

Browse files
authored
feat(s2n-quic-transport): close connection during handshake (#2792)
1 parent 5fbf7cb commit 5ed1955

File tree

5 files changed

+394
-3
lines changed

5 files changed

+394
-3
lines changed

quic/s2n-quic-core/src/transport/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ impl Code {
112112
}
113113

114114
#[inline]
115-
pub(crate) fn as_varint(self) -> VarInt {
115+
pub fn as_varint(self) -> VarInt {
116116
self.0
117117
}
118118
}

quic/s2n-quic-tests/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ mod buffer_limit;
2828
mod connection_limits;
2929
mod connection_migration;
3030
mod deduplicate;
31+
mod endpoint_limits;
3132
mod exporter;
3233
mod handshake_cid_rotation;
3334
mod initial_rtt;
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use super::*;
5+
use s2n_quic::provider::{
6+
connection_id,
7+
endpoint_limits::{ConnectionAttempt, Limiter, Outcome},
8+
};
9+
use s2n_quic_core::{
10+
connection::{error::Error, id},
11+
endpoint,
12+
};
13+
14+
/// A custom limiter that allows the first connection but closes subsequent ones
15+
#[derive(Default)]
16+
struct AllowFirstThenCloseLimiter {
17+
connection_count: usize,
18+
}
19+
20+
impl Limiter for AllowFirstThenCloseLimiter {
21+
fn on_connection_attempt(&mut self, _info: &ConnectionAttempt) -> Outcome {
22+
if self.connection_count == 0 {
23+
self.connection_count += 1;
24+
// Allow the first connection
25+
Outcome::allow()
26+
} else {
27+
// Close subsequent connections
28+
Outcome::close()
29+
}
30+
}
31+
}
32+
33+
// We've allocated 150 bytes for the connection close packet.
34+
// Testing with the maximum length of a connection ID ensures that we've allocated enough to store packet.
35+
const MAX_CID_LEN: usize = 20;
36+
struct MaxSizeIdFormat;
37+
38+
impl connection_id::Generator for MaxSizeIdFormat {
39+
fn generate(
40+
&mut self,
41+
_connection_info: &id::ConnectionInfo,
42+
) -> s2n_quic_core::connection::LocalId {
43+
let mut id = [0u8; MAX_CID_LEN];
44+
::rand::rng().fill_bytes(&mut id);
45+
connection_id::LocalId::try_from_bytes(&id[..]).unwrap()
46+
}
47+
}
48+
49+
impl connection_id::Validator for MaxSizeIdFormat {
50+
fn validate(&self, _connection_info: &id::ConnectionInfo, _buffer: &[u8]) -> Option<usize> {
51+
Some(MAX_CID_LEN)
52+
}
53+
}
54+
55+
// This test verifies that the server sends a CONNECTION_CLOSE frame with
56+
// error code CONNECTION_REFUSED when the server's limiter returns Outcome::close().
57+
#[test]
58+
fn endpoint_limits_close_test() {
59+
let model = Model::default();
60+
61+
let connection_close_subscriber = recorder::ConnectionClosed::new();
62+
let connection_close_event = connection_close_subscriber.events();
63+
64+
test(model, |handle| {
65+
let server = Server::builder()
66+
.with_io(handle.builder().build()?)?
67+
.with_tls(SERVER_CERTS)?
68+
.with_event(tracing_events())?
69+
.with_connection_id(MaxSizeIdFormat)?
70+
.with_random(Random::with_seed(456))?
71+
.with_endpoint_limits(AllowFirstThenCloseLimiter::default())?
72+
.start()?;
73+
74+
let server_addr = start_server(server)?;
75+
76+
let client1 = Client::builder()
77+
.with_io(handle.builder().build()?)?
78+
.with_tls(certificates::CERT_PEM)?
79+
.with_event(tracing_events())?
80+
.with_connection_id(MaxSizeIdFormat)?
81+
.with_random(Random::with_seed(456))?
82+
.start()?;
83+
84+
let client2 = Client::builder()
85+
.with_io(handle.builder().build()?)?
86+
.with_tls(certificates::CERT_PEM)?
87+
.with_event((tracing_events(), connection_close_subscriber))?
88+
.with_connection_id(MaxSizeIdFormat)?
89+
.with_random(Random::with_seed(789))?
90+
.start()?;
91+
92+
primary::spawn(async move {
93+
// First client should connect successfully
94+
let connect1 = Connect::new(server_addr).with_server_name("localhost");
95+
client1.connect(connect1).await.unwrap();
96+
97+
// Second client should fail to connect, since the server's endpoint limiter
98+
// will refuse all connections after the first one.
99+
let connect2 = Connect::new(server_addr).with_server_name("localhost");
100+
let result = client2.connect(connect2).await;
101+
assert!(matches!(result.unwrap_err(), Error::Transport { code, .. } if code == s2n_quic_core::transport::Error::CONNECTION_REFUSED.code));
102+
});
103+
104+
Ok(())
105+
})
106+
.unwrap();
107+
108+
// Verify that the client received a CONNECTION_CLOSE frame with error code CONNECTION_REFUSED,
109+
// and the CONNECTION_CLOSE frame is sent by the server (remote from client's perspectives).
110+
let connection_close_status = connection_close_event.lock().unwrap();
111+
assert_eq!(connection_close_status.len(), 1);
112+
assert!(matches!(
113+
connection_close_status[0],
114+
Error::Transport {
115+
code,
116+
initiator,
117+
..
118+
} if (code == s2n_quic_core::transport::Error::CONNECTION_REFUSED.code && initiator == endpoint::Location::Remote)
119+
));
120+
}
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use crate::endpoint;
5+
use alloc::collections::VecDeque;
6+
use core::ops::Range;
7+
use s2n_codec::{DecoderBufferMut, EncoderValue};
8+
use s2n_quic_core::{
9+
connection,
10+
crypto::{InitialHeaderKey, InitialKey},
11+
event,
12+
frame::ConnectionClose,
13+
inet::ExplicitCongestionNotification,
14+
io::tx,
15+
packet::{initial::CleartextInitial, number::PacketNumberSpace},
16+
path, time, transport,
17+
varint::VarInt,
18+
};
19+
20+
// Size for the Initial packet with the CONNECTION_CLOSE frame in this scenario
21+
22+
// CONNECTION_CLOSE Frame {
23+
// Type (8) = 0x1c,
24+
// Error Code (8) = 0x02,
25+
// Frame Type (8)= 0x00,
26+
// Reason Phrase Length (8) =0x21,
27+
// Reason Phrase (264) = "The server refused the connection",
28+
// }
29+
30+
// Initial Packet {
31+
// Header Form (1) = 1,
32+
// Fixed Bit (1) = 1,
33+
// Long Packet Type (2) = 0,
34+
// Reserved Bits (2),
35+
// Packet Number Length (2),
36+
// Version (32),
37+
// Destination Connection ID Length (8),
38+
// Destination Connection ID (160), # assuming max length for CID
39+
// Source Connection ID Length (8),
40+
// Source Connection ID (160), # assuming max length for CID
41+
// Token Length (1),
42+
// Token (0) = no token,
43+
// Length (12) = length for packet number + payload is 304 bits which needs 12 bits to encode,
44+
// Packet Number (8),
45+
// Packet Payload (296) = CONNECTION_CLOSE Frame,
46+
// }
47+
48+
// As shown above, the total size of the initial packet is 693 bits which is 87 bytes.
49+
// We use a slightly larger buffer to ensure the buffer is large enough to hold the packet.
50+
const DEFAULT_PAYLOAD_SIZE: usize = 150;
51+
52+
#[derive(Debug)]
53+
pub struct Dispatch<Path: path::Handle> {
54+
transmissions: VecDeque<Transmission<Path>>,
55+
}
56+
57+
impl<Path: path::Handle> Dispatch<Path> {
58+
pub fn new(max_peers: usize, endpoint_type: endpoint::Type) -> Self {
59+
// Only the server endpoint can send CONNECTION_CLOSE frame to drop connection request
60+
let capacity = if endpoint_type.is_server() {
61+
max_peers
62+
} else {
63+
0
64+
};
65+
Self {
66+
transmissions: VecDeque::with_capacity(capacity),
67+
}
68+
}
69+
70+
pub fn queue<C: InitialKey>(
71+
&mut self,
72+
path_handle: Path,
73+
packet: &s2n_quic_core::packet::initial::ProtectedInitial,
74+
local_connection_id: connection::LocalId,
75+
) where
76+
<C as InitialKey>::HeaderKey: InitialHeaderKey,
77+
{
78+
if let Some(transmission) = Transmission::new::<C>(path_handle, packet, local_connection_id)
79+
{
80+
self.transmissions.push_back(transmission);
81+
}
82+
}
83+
84+
pub fn on_transmit<Tx: tx::Queue<Handle = Path>, Pub: event::EndpointPublisher>(
85+
&mut self,
86+
queue: &mut Tx,
87+
publisher: &mut Pub,
88+
) {
89+
while let Some(transmission) = self.transmissions.pop_front() {
90+
match queue.push(&transmission) {
91+
Ok(tx::Outcome { len, .. }) => {
92+
publisher.on_endpoint_packet_sent(event::builder::EndpointPacketSent {
93+
packet_header: event::builder::PacketHeader::Initial {
94+
number: transmission.packet_number,
95+
version: transmission.version,
96+
},
97+
});
98+
99+
publisher.on_endpoint_datagram_sent(event::builder::EndpointDatagramSent {
100+
len: len as u16,
101+
gso_offset: 0,
102+
});
103+
104+
publisher.on_endpoint_connection_attempt_failed(
105+
event::builder::EndpointConnectionAttemptFailed {
106+
error: transport::Error::CONNECTION_REFUSED.into(),
107+
},
108+
);
109+
}
110+
Err(_) => {
111+
self.transmissions.push_front(transmission);
112+
return;
113+
}
114+
}
115+
}
116+
}
117+
}
118+
119+
pub struct Transmission<Path: path::Handle> {
120+
path: Path,
121+
packet: [u8; DEFAULT_PAYLOAD_SIZE],
122+
packet_range: Range<usize>,
123+
version: u32,
124+
packet_number: u64,
125+
}
126+
127+
impl<Path: path::Handle> core::fmt::Debug for Transmission<Path> {
128+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
129+
f.debug_struct("Transmission")
130+
.field("remote_address", &self.path.remote_address())
131+
.field("local_address", &self.path.local_address())
132+
.field("packet", &&self.packet[self.packet_range.clone()])
133+
.finish()
134+
}
135+
}
136+
137+
impl<Path: path::Handle> Transmission<Path> {
138+
pub fn new<C: InitialKey>(
139+
path: Path,
140+
packet: &s2n_quic_core::packet::initial::ProtectedInitial,
141+
local_connection_id: connection::LocalId,
142+
) -> Option<Self>
143+
where
144+
<C as InitialKey>::HeaderKey: InitialHeaderKey,
145+
{
146+
use s2n_quic_core::packet::encoding::PacketEncoder;
147+
148+
let mut packet_buf = [0u8; DEFAULT_PAYLOAD_SIZE];
149+
150+
//= https://www.rfc-editor.org/rfc/rfc9000#section-5.2.2
151+
//# If a server refuses to accept a new connection, it SHOULD send an
152+
//# Initial packet containing a CONNECTION_CLOSE frame with error code
153+
//# CONNECTION_REFUSED.
154+
155+
// We need to ensure that the packet is at least 22 bytes longer than the the minimum connection ID length,
156+
// that it requests the peer to include in its packets
157+
// Hewnce, we need to use a reason that's more than 15 bytes to ensure the packet will be sent.
158+
let connection_close = ConnectionClose {
159+
error_code: transport::Error::CONNECTION_REFUSED.code.as_varint(),
160+
frame_type: Some(VarInt::ZERO),
161+
reason: Some(b"The server refused the connection"),
162+
};
163+
164+
let mut encoded_frame = connection_close.encode_to_vec();
165+
166+
// Generate a new packet number for the connection close initial packet
167+
let packet_number = PacketNumberSpace::Initial.new_packet_number(Default::default());
168+
169+
// Create an initial packet with the connection close frame
170+
let initial_packet = CleartextInitial {
171+
version: packet.version,
172+
destination_connection_id: packet.source_connection_id(),
173+
source_connection_id: local_connection_id.as_bytes(),
174+
token: &[],
175+
packet_number,
176+
payload: DecoderBufferMut::new(&mut encoded_frame),
177+
};
178+
179+
let (mut initial_key, initial_header_key) =
180+
C::new_server(packet.destination_connection_id());
181+
182+
// There is no packet acknowledged yet, since no packet is taken from the peer.
183+
// The endpoint just close the connection immediately.
184+
let largest_acknowledged_packet_number =
185+
PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO);
186+
187+
// Use the PacketEncoder trait to encode, encrypt, and protect the packet
188+
let encrypted_initial_packet = initial_packet
189+
.encode_packet(
190+
&mut initial_key,
191+
&initial_header_key,
192+
largest_acknowledged_packet_number,
193+
None,
194+
s2n_codec::EncoderBuffer::new(&mut packet_buf),
195+
)
196+
.unwrap();
197+
198+
let packet_range = 0..encrypted_initial_packet.0.len();
199+
200+
Some(Self {
201+
path,
202+
packet: packet_buf,
203+
packet_range,
204+
version: packet.version,
205+
packet_number: packet_number.as_u64(),
206+
})
207+
}
208+
}
209+
210+
impl<Path: path::Handle> AsRef<[u8]> for Transmission<Path> {
211+
fn as_ref(&self) -> &[u8] {
212+
&self.packet[self.packet_range.clone()]
213+
}
214+
}
215+
216+
impl<Path: path::Handle> tx::Message for &Transmission<Path> {
217+
type Handle = Path;
218+
219+
#[inline]
220+
fn path_handle(&self) -> &Self::Handle {
221+
&self.path
222+
}
223+
224+
#[inline]
225+
fn ecn(&mut self) -> ExplicitCongestionNotification {
226+
Default::default()
227+
}
228+
229+
#[inline]
230+
fn delay(&mut self) -> time::Duration {
231+
Default::default()
232+
}
233+
234+
#[inline]
235+
fn ipv6_flow_label(&mut self) -> u32 {
236+
0
237+
}
238+
239+
#[inline]
240+
fn can_gso(&self, segment_len: usize, _segment_count: usize) -> bool {
241+
segment_len >= self.as_ref().len()
242+
}
243+
244+
#[inline]
245+
fn write_payload(
246+
&mut self,
247+
mut buffer: tx::PayloadBuffer,
248+
_gso_offset: usize,
249+
) -> Result<usize, tx::Error> {
250+
buffer.write(self.as_ref())
251+
}
252+
}

0 commit comments

Comments
 (0)