Skip to content

Commit a6576c0

Browse files
authored
feat(tcp): add dynamic buffer resizing and improve window handling (#2)
- Add`set_send_buffer_size`and`set_recv_buffer_size` methods for dynamic buffer resizing when`alloc`feature is enabled - Fix window calculation in`last_scaled_window` to prevent negative values and return zero when invalid - Clamp`remote_last_seq`in`dispatch` to prevent underflow and ensure valid sequence numbers - Add`resize`method to`RingBuffer` for owned storage, preserving data order and resetting read pointer - Add comprehensive unit tests for new buffer resizing functionality and window handling edge cases Signed-off-by: longjin <longjin@DragonOS.org>
1 parent 861e605 commit a6576c0

File tree

2 files changed

+150
-1
lines changed

2 files changed

+150
-1
lines changed

src/socket/tcp.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,16 @@ impl<'a> Socket<'a> {
558558
}
559559
}
560560

561+
#[cfg(feature = "alloc")]
562+
pub fn set_send_buffer_size(&mut self, size: usize) {
563+
self.tx_buffer.resize(size, 0);
564+
}
565+
566+
#[cfg(feature = "alloc")]
567+
pub fn set_recv_buffer_size(&mut self, size: usize) {
568+
self.rx_buffer.resize(size, 0);
569+
}
570+
561571
/// Enable or disable TCP Timestamp.
562572
pub fn set_tsval_generator(&mut self, generator: Option<TcpTimestampGenerator>) {
563573
self.tsval_generator = generator;
@@ -695,7 +705,11 @@ impl<'a> Socket<'a> {
695705
let next_ack = self.remote_seq_no + self.rx_buffer.len();
696706

697707
let last_win = (self.remote_last_win as usize) << self.remote_win_shift;
698-
let last_win_adjusted = last_ack + last_win - next_ack;
708+
let window_edge = last_ack + last_win;
709+
if next_ack > window_edge {
710+
return Some(0);
711+
}
712+
let last_win_adjusted = window_edge - next_ack;
699713

700714
Some(u16::try_from(last_win_adjusted >> self.remote_win_shift).unwrap_or(u16::MAX))
701715
}
@@ -2393,6 +2407,15 @@ impl<'a> Socket<'a> {
23932407
| State::Closing
23942408
| State::CloseWait
23952409
| State::LastAck => {
2410+
// Ensure remote_last_seq is at least local_seq_no.
2411+
// This can happen if we receive an ACK for data we haven't sent yet
2412+
// (which is invalid but shouldn't crash us), or if the remote side
2413+
// has acknowledged data that we were about to retransmit.
2414+
if self.remote_last_seq < self.local_seq_no {
2415+
self.remote_last_seq = self.local_seq_no;
2416+
repr.seq_number = self.remote_last_seq;
2417+
}
2418+
23962419
// Extract as much data as the remote side can receive in this packet
23972420
// from the transmit buffer.
23982421

@@ -6820,6 +6843,16 @@ mod test {
68206843
assert!(s.window_to_update());
68216844
}
68226845

6846+
#[test]
6847+
fn test_last_scaled_window_returns_zero_on_invalid_last_window() {
6848+
let mut s = socket_established();
6849+
s.remote_last_ack = Some(s.remote_seq_no);
6850+
s.remote_last_win = 0;
6851+
assert_eq!(s.rx_buffer.enqueue_slice(&[0u8; 4]), 4);
6852+
6853+
assert_eq!(s.last_scaled_window(), Some(0));
6854+
}
6855+
68236856
// =========================================================================================//
68246857
// Tests for timeouts.
68256858
// =========================================================================================//
@@ -7225,6 +7258,43 @@ mod test {
72257258
);
72267259
}
72277260

7261+
#[test]
7262+
fn test_dispatch_clamps_remote_last_seq_before_sending() {
7263+
let mut s = socket_established();
7264+
s.set_nagle_enabled(false);
7265+
7266+
let local_seq = s.local_seq_no;
7267+
s.remote_last_seq = local_seq - 5;
7268+
7269+
assert_eq!(s.send_slice(b"abc"), Ok(3));
7270+
recv!(s, time 0, Ok(TcpRepr {
7271+
control: TcpControl::Psh,
7272+
seq_number: local_seq,
7273+
ack_number: Some(REMOTE_SEQ + 1),
7274+
payload: &b"abc"[..],
7275+
..RECV_TEMPL
7276+
}), exact);
7277+
}
7278+
7279+
#[test]
7280+
#[cfg(feature = "alloc")]
7281+
fn test_set_buffer_size_updates_capacity() {
7282+
let mut s = socket_established();
7283+
s.set_send_buffer_size(128);
7284+
s.set_recv_buffer_size(256);
7285+
assert_eq!(s.tx_buffer.capacity(), 128);
7286+
assert_eq!(s.rx_buffer.capacity(), 256);
7287+
}
7288+
7289+
#[test]
7290+
#[cfg(feature = "alloc")]
7291+
fn test_set_send_buffer_size_does_not_shrink_below_length() {
7292+
let mut s = socket_established_with_buffer_sizes(16, 64);
7293+
assert_eq!(s.send_slice(b"abcdef"), Ok(6));
7294+
s.set_send_buffer_size(4);
7295+
assert_eq!(s.tx_buffer.capacity(), 16);
7296+
}
7297+
72287298
// =========================================================================================//
72297299
// Tests for graceful vs ungraceful rx close
72307300
// =========================================================================================//

src/storage/ring_buffer.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// these functions may have side effects, and it's implemented by [RFC 1940].
33
// [RFC 1940]: https://github.com/rust-lang/rust/issues/43302
44

5+
#[cfg(feature = "alloc")]
6+
use alloc::vec::Vec;
57
use core::cmp;
68
use managed::ManagedSlice;
79

@@ -51,6 +53,37 @@ impl<'a, T: 'a> RingBuffer<'a, T> {
5153
self.length = 0;
5254
}
5355

56+
#[cfg(feature = "alloc")]
57+
pub fn resize(&mut self, new_capacity: usize, default: T)
58+
where
59+
T: Clone,
60+
{
61+
if new_capacity < self.length {
62+
return;
63+
}
64+
65+
match &mut self.storage {
66+
ManagedSlice::Owned(vec) => {
67+
let old_capacity = vec.len();
68+
if new_capacity == old_capacity {
69+
return;
70+
}
71+
72+
let mut new_vec = Vec::with_capacity(new_capacity);
73+
new_vec.resize(new_capacity, default);
74+
75+
for i in 0..self.length {
76+
let old_idx = (self.read_at + i) % old_capacity;
77+
new_vec[i] = vec[old_idx].clone();
78+
}
79+
80+
*vec = new_vec;
81+
self.read_at = 0;
82+
}
83+
_ => {}
84+
}
85+
}
86+
5487
/// Return the maximum number of elements in the ring buffer.
5588
pub fn capacity(&self) -> usize {
5689
self.storage.len()
@@ -435,6 +468,52 @@ mod test {
435468
assert_eq!(ring.window(), 0);
436469
}
437470

471+
#[test]
472+
#[cfg(feature = "alloc")]
473+
fn test_buffer_resize_preserves_order_and_resets_read_at() {
474+
let mut ring = RingBuffer::new(vec![0u8; 4]);
475+
assert_eq!(ring.enqueue_slice(b"abcd"), 4);
476+
477+
ring.dequeue_many(2).copy_from_slice(b"..");
478+
assert_eq!(ring.enqueue_slice(b"ef"), 2);
479+
assert_eq!(ring.len(), 4);
480+
481+
ring.resize(8, 0);
482+
assert_eq!(ring.capacity(), 8);
483+
assert_eq!(ring.read_at, 0);
484+
assert_eq!(ring.len(), 4);
485+
486+
let mut data = vec![0u8; ring.len()];
487+
assert_eq!(ring.read_allocated(0, &mut data[..]), 4);
488+
assert_eq!(&data[..], b"cdef");
489+
}
490+
491+
#[test]
492+
#[cfg(feature = "alloc")]
493+
fn test_buffer_resize_does_not_shrink_below_length() {
494+
let mut ring = RingBuffer::new(vec![0u8; 4]);
495+
assert_eq!(ring.enqueue_slice(b"abc"), 3);
496+
497+
ring.resize(2, 0);
498+
assert_eq!(ring.capacity(), 4);
499+
assert_eq!(ring.len(), 3);
500+
501+
let mut data = vec![0u8; ring.len()];
502+
assert_eq!(ring.read_allocated(0, &mut data[..]), 3);
503+
assert_eq!(&data[..], b"abc");
504+
}
505+
506+
#[test]
507+
#[cfg(feature = "alloc")]
508+
fn test_buffer_resize_is_noop_for_borrowed_storage() {
509+
let mut storage = [0u8; 4];
510+
let mut ring = RingBuffer::new(&mut storage[..]);
511+
assert_eq!(ring.capacity(), 4);
512+
513+
ring.resize(8, 0);
514+
assert_eq!(ring.capacity(), 4);
515+
}
516+
438517
#[test]
439518
fn test_buffer_enqueue_dequeue_one_with() {
440519
let mut ring = RingBuffer::new(vec![0; 5]);

0 commit comments

Comments
 (0)