Skip to content

Commit 29c6b40

Browse files
author
Valentin Obst
committed
net/tcp: add Rust implementation of BIC
Reimplement the Binary Increase Congestion (BIC) control algorithm in Rust. BIC is one of the smallest CCAs in the kernel and this mainly serves as a minimal example for a real-world algorithm. Signed-off-by: Valentin Obst <[email protected]>
1 parent 19a8136 commit 29c6b40

File tree

3 files changed

+326
-0
lines changed

3 files changed

+326
-0
lines changed

net/ipv4/Kconfig

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,15 @@ config TCP_CONG_BIC
509509
increase provides TCP friendliness.
510510
See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/
511511

512+
config TCP_CONG_BIC_RUST
513+
tristate "Binary Increase Congestion (BIC) control (Rust rewrite)"
514+
depends on RUST_TCP_ABSTRACTIONS
515+
help
516+
Rust rewrite of the original implementation of Binary Increase
517+
Congestion (BIC) control.
518+
519+
If unsure, say N.
520+
512521
config TCP_CONG_CUBIC
513522
tristate "CUBIC TCP"
514523
default y
@@ -704,6 +713,9 @@ choice
704713
config DEFAULT_BIC
705714
bool "Bic" if TCP_CONG_BIC=y
706715

716+
config DEFAULT_BIC_RUST
717+
bool "Bic (Rust)" if TCP_CONG_BIC_RUST=y
718+
707719
config DEFAULT_CUBIC
708720
bool "Cubic" if TCP_CONG_CUBIC=y
709721

@@ -745,6 +757,7 @@ config TCP_CONG_CUBIC
745757
config DEFAULT_TCP_CONG
746758
string
747759
default "bic" if DEFAULT_BIC
760+
default "bic_rust" if DEFAULT_BIC_RUST
748761
default "cubic" if DEFAULT_CUBIC
749762
default "htcp" if DEFAULT_HTCP
750763
default "hybla" if DEFAULT_HYBLA

net/ipv4/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ obj-$(CONFIG_INET_UDP_DIAG) += udp_diag.o
4646
obj-$(CONFIG_INET_RAW_DIAG) += raw_diag.o
4747
obj-$(CONFIG_TCP_CONG_BBR) += tcp_bbr.o
4848
obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o
49+
obj-$(CONFIG_TCP_CONG_BIC_RUST) += tcp_bic_rust.o
4950
obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o
5051
obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o
5152
obj-$(CONFIG_TCP_CONG_DCTCP) += tcp_dctcp.o

net/ipv4/tcp_bic_rust.rs

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
3+
//! Binary Increase Congestion control (BIC). Based on:
4+
//! Binary Increase Congestion Control (BIC) for Fast Long-Distance
5+
//! Networks - Lisong Xu, Khaled Harfoush, and Injong Rhee
6+
//! IEEE INFOCOM 2004, Hong Kong, China, 2004, pp. 2514-2524 vol.4
7+
//! doi: 10.1109/INFCOM.2004.1354672
8+
//! Link: https://doi.org/10.1109/INFCOM.2004.1354672
9+
//! Link: https://web.archive.org/web/20160417213452/http://netsrv.csc.ncsu.edu/export/bitcp.pdf
10+
11+
use core::cmp::{max, min};
12+
use core::num::NonZeroU32;
13+
use kernel::c_str;
14+
use kernel::net::tcp::cong::{self, module_cca};
15+
use kernel::prelude::*;
16+
use kernel::time;
17+
18+
/// Fixed-point shift for the `packets/ACK` ratio estimate (`delayed_ack` is
/// kept as `ratio << ACK_RATIO_SHIFT`, i.e. with 4 fractional bits).
const ACK_RATIO_SHIFT: u32 = 4;

// TODO: Convert to module parameters once they are available.
/// The initial value of ssthresh for new connections. Setting this to `None`
/// implies `i32::MAX`.
const INITIAL_SSTHRESH: Option<u32> = None;
/// If cwnd is larger than this threshold, BIC engages; otherwise normal TCP
/// increase/decrease will be performed.
const LOW_WINDOW: u32 = 14;
/// In binary search, go to point: `cwnd + (W_max - cwnd) / BICTCP_B`.
// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
// SAFETY: The argument is the non-zero literal 4; a zero here would be
// rejected during compile-time const evaluation.
const BICTCP_B: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4) };
/// The maximum increment, i.e., `S_max`. This is used during additive increase.
/// After crossing `W_max`, slow start is performed until passing
/// `MAX_INCREMENT * (BICTCP_B - 1)`.
// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
// SAFETY: The argument is the non-zero literal 16; a zero here would be
// rejected during compile-time const evaluation.
const MAX_INCREMENT: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(16) };
/// The number of RTT it takes to get from `W_max - BICTCP_B` to `W_max` (and
/// from `W_max` to `W_max + BICTCP_B`). This is not part of the original paper
/// and results in a slow additive increase across `W_max`.
const SMOOTH_PART: u32 = 20;
/// Whether to use fast convergence. This is a heuristic to increase the
/// release of bandwidth by existing flows to speed up the convergence to a
/// steady state when a new flow joins the link.
const FAST_CONVERGENCE: bool = true;
/// Factor for multiplicative decrease. In fast retransmit we have:
/// `cwnd = cwnd * BETA/BETA_SCALE`
/// and if fast convergence is active:
/// `W_max = cwnd * (1 + BETA/BETA_SCALE)/2`
/// instead of `W_max = cwnd`.
const BETA: u32 = 819;
/// Used to calculate beta in [0, 1] with integer arithmetics.
// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
// SAFETY: The argument is the non-zero literal 1024; a zero here would be
// rejected during compile-time const evaluation.
const BETA_SCALE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1024) };
/// The minimum amount of time that has to pass between two updates of the cwnd.
const MIN_UPDATE_INTERVAL: time::Msecs32 = time::MSEC_PER_SEC / 32;
57+
58+
// Registers this congestion control algorithm with the kernel and generates
// the kernel-module boilerplate (init/exit, license/author metadata).
module_cca! {
    type: Bic,
    name: "tcp_bic_rust",
    author: "Rust for Linux Contributors",
    description: "Binary Increase Congestion control (BIC) algorithm, Rust implementation",
    license: "GPL v2",
}
65+
66+
/// Marker type for the algorithm; it carries no data of its own. All
/// per-connection state lives in [`BicState`] (see `type Data` in the
/// [`cong::Algorithm`] implementation).
struct Bic {}
67+
68+
#[vtable]
impl cong::Algorithm for Bic {
    /// Per-connection private state stored in the socket's CA area.
    type Data = BicState;

    /// Name under which the algorithm is registered (e.g. for
    /// `net.ipv4.tcp_congestion_control`).
    const NAME: &'static CStr = c_str!("bic_rust");

    /// Updates the delayed-ACK ratio estimate from an ACK sample. Only done
    /// while the connection is in the `Open` state.
    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
        if let Ok(cong::State::Open) = sk.inet_csk().ca_state() {
            let ca = sk.inet_csk_ca_mut();

            // Track delayed acknowledgment ratio using sliding window:
            // ratio = (15*ratio + sample) / 16
            // (wrapping ops mirror the C implementation's unsigned arithmetic)
            ca.delayed_ack = ca.delayed_ack.wrapping_add(
                sample
                    .pkts_acked()
                    .wrapping_sub(ca.delayed_ack >> ACK_RATIO_SHIFT),
            );
        }
    }

    /// Computes the new slow-start threshold on loss (the log message shows
    /// this fires when entering fast retransmit). Also records `W_max` for
    /// the next epoch, applying fast convergence if enabled.
    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
        let cwnd = sk.tcp_sk().snd_cwnd();
        let ca = sk.inet_csk_ca_mut();

        pr_info!(
            // TODO: remove
            "Enter fast retransmit: time {}, start {}",
            time::ktime_get_boot_fast_ns(),
            ca.start_time
        );

        // Epoch has ended.
        ca.epoch_start = 0;
        // Fast convergence: when cwnd never reached the previous W_max, a new
        // flow is likely competing, so release bandwidth by remembering a
        // reduced W_max: cwnd * (1 + BETA/BETA_SCALE) / 2.
        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
            (cwnd * (BETA_SCALE.get() + BETA)) / (2 * BETA_SCALE.get())
        } else {
            cwnd
        };

        if cwnd <= LOW_WINDOW {
            // Act like normal TCP: halve the window, floor of 2.
            max(cwnd >> 1, 2)
        } else {
            // Multiplicative decrease by BETA/BETA_SCALE, floor of 2.
            max((cwnd * BETA) / BETA_SCALE, 2)
        }
    }

    /// Main congestion-avoidance hook: performs slow start while below
    /// ssthresh, otherwise grows cwnd according to BIC via
    /// [`BicState::update`] and `cong_avoid_ai`.
    fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) {
        if !sk.tcp_is_cwnd_limited() {
            return;
        }

        let tp = sk.tcp_sk_mut();

        if tp.in_slow_start() {
            // `slow_start` consumes ACKed packets; any remainder is applied
            // to congestion avoidance below.
            acked = tp.slow_start(acked);
            if acked == 0 {
                pr_info!(
                    // TODO: remove
                    "New cwnd {}, time {}, ssthresh {}, start {}, ss 1",
                    sk.tcp_sk().snd_cwnd(),
                    time::ktime_get_boot_fast_ns(),
                    sk.tcp_sk().snd_ssthresh(),
                    sk.inet_csk_ca().start_time
                );
                return;
            }
        }

        let cwnd = tp.snd_cwnd();
        let cnt = sk.inet_csk_ca_mut().update(cwnd);
        sk.tcp_sk_mut().cong_avoid_ai(cnt, acked);

        pr_info!(
            // TODO: remove
            "New cwnd {}, time {}, ssthresh {}, start {}, ss 0",
            sk.tcp_sk().snd_cwnd(),
            time::ktime_get_boot_fast_ns(),
            sk.tcp_sk().snd_ssthresh(),
            sk.inet_csk_ca().start_time
        );
    }

    /// Resets the private state when the connection enters the `Loss` state
    /// (retransmission timeout); other state transitions are ignored.
    fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) {
        if matches!(new_state, cong::State::Loss) {
            pr_info!(
                // TODO: remove
                "Retransmission timeout fired: time {}, start {}",
                time::ktime_get_boot_fast_ns(),
                sk.inet_csk_ca().start_time
            );
            sk.inet_csk_ca_mut().reset()
        }
    }

    /// Undoes a cwnd reduction; delegates to the standard Reno behaviour.
    fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 {
        pr_info!(
            // TODO: remove
            "Undo cwnd reduction: time {}, start {}",
            time::ktime_get_boot_fast_ns(),
            sk.inet_csk_ca().start_time
        );

        cong::reno::undo_cwnd(sk)
    }

    /// Initialises a new connection: optionally pins the initial ssthresh
    /// (see [`INITIAL_SSTHRESH`]); `BicState` itself is created via `Default`.
    fn init(sk: &mut cong::Sock<'_, Self>) {
        if let Some(ssthresh) = INITIAL_SSTHRESH {
            sk.tcp_sk_mut().set_snd_ssthresh(ssthresh);
        }

        // TODO: remove
        pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time);
    }

    // TODO: remove — debug-only lifetime logging, no algorithmic effect.
    fn release(sk: &mut cong::Sock<'_, Self>) {
        pr_info!(
            "Socket destroyed: start {}, end {}",
            sk.inet_csk_ca().start_time,
            time::ktime_get_boot_fast_ns()
        );
    }
}
192+
193+
/// Internal state of each instance of the algorithm.
struct BicState {
    /// During congestion avoidance, cwnd is increased at most every `cnt`
    /// acknowledged packets, i.e., the average increase per acknowledged packet
    /// is proportional to `1 / cnt`.
    // NOTE: The C impl initialises this to zero. It then ensures that zero is
    // never passed to `cong_avoid_ai`, which could divide by it. Make it
    // explicit in the types that zero is not a valid value.
    cnt: NonZeroU32,
    /// Last maximum `snd_cwnd`, i.e, `W_max`.
    last_max_cwnd: u32,
    /// The last `snd_cwnd`.
    last_cwnd: u32,
    /// Time when `last_cwnd` was updated.
    last_time: time::Msecs32,
    /// Records the beginning of an epoch.
    epoch_start: time::Msecs32,
    /// Estimates the ratio of `packets/ACK << 4`. This allows us to adjust cwnd
    /// per packet when a receiver is sending a single ACK for multiple received
    /// packets.
    delayed_ack: u32,
    /// Time when algorithm was initialised.
    // TODO: remove
    start_time: time::Nsecs,
}
218+
219+
impl Default for BicState {
220+
fn default() -> Self {
221+
Self {
222+
// NOTE: Initialising this to 1 deviates from the C code. It does
223+
// not change the behaviour of the algorithm.
224+
cnt: NonZeroU32::MIN,
225+
last_max_cwnd: 0,
226+
last_cwnd: 0,
227+
last_time: 0,
228+
epoch_start: 0,
229+
delayed_ack: 2 << ACK_RATIO_SHIFT,
230+
// TODO: remove
231+
start_time: time::ktime_get_boot_fast_ns(),
232+
}
233+
}
234+
}
235+
236+
impl BicState {
    /// Compute congestion window to use. Returns the new `cnt`.
    ///
    /// This governs the behavior of the algorithm during congestion avoidance:
    /// the larger the returned `cnt`, the slower cwnd grows (cwnd is increased
    /// once per `cnt` ACKed packets).
    fn update(&mut self, cwnd: u32) -> NonZeroU32 {
        let now = time::ktime_get_boot_fast_ms32();

        // Do nothing if we are invoked too frequently.
        // (`wrapping_sub` keeps the comparison correct across 32-bit
        // millisecond-counter wraparound.)
        if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= MIN_UPDATE_INTERVAL {
            return self.cnt;
        }

        self.last_cwnd = cwnd;
        self.last_time = now;

        // Record the beginning of an epoch.
        if self.epoch_start == 0 {
            self.epoch_start = now;
        }

        // Start off like normal TCP.
        if cwnd <= LOW_WINDOW {
            // `unwrap_or` guards against cwnd == 0, which would otherwise
            // produce an invalid (zero) divisor for `cong_avoid_ai`.
            self.cnt = NonZeroU32::new(cwnd).unwrap_or(NonZeroU32::MIN);
            return self.cnt;
        }

        let mut new_cnt = if cwnd < self.last_max_cwnd {
            // binary increase: aim for the midpoint-style probe below W_max
            let dist: u32 = (self.last_max_cwnd - cwnd) / BICTCP_B;

            if dist > MAX_INCREMENT.get() {
                // additive increase: cap the per-RTT growth at S_max
                cwnd / MAX_INCREMENT
            } else if dist <= 1 {
                // careful additive increase: cross W_max over SMOOTH_PART RTTs
                (cwnd * SMOOTH_PART) / BICTCP_B
            } else {
                // binary search
                cwnd / dist
            }
        } else {
            if cwnd < self.last_max_cwnd + BICTCP_B.get() {
                // careful additive increase
                (cwnd * SMOOTH_PART) / BICTCP_B
            } else if cwnd < self.last_max_cwnd + MAX_INCREMENT.get() * (BICTCP_B.get() - 1) {
                // slow start: ramp up probing beyond W_max
                // (divisor is > 0 here since cwnd >= last_max_cwnd + BICTCP_B)
                (cwnd * (BICTCP_B.get() - 1)) / (cwnd - self.last_max_cwnd)
            } else {
                // linear increase
                cwnd / MAX_INCREMENT
            }
        };

        // If in initial slow start or link utilization is very low.
        if self.last_max_cwnd == 0 {
            new_cnt = min(new_cnt, 20);
        }

        // Account for estimated packets/ACK to ensure that we increase per
        // packet. (`delayed_ack` is never 0 in practice — it starts at
        // 2 << ACK_RATIO_SHIFT; presumably the sliding-window update keeps it
        // positive, mirroring the C code. TODO confirm.)
        new_cnt = (new_cnt << ACK_RATIO_SHIFT) / self.delayed_ack;

        // Clamp to at least 1 so the NonZeroU32 invariant holds.
        self.cnt = NonZeroU32::new(new_cnt).unwrap_or(NonZeroU32::MIN);

        self.cnt
    }

    /// Resets the state to its defaults while preserving `start_time`, which
    /// only exists for debug logging of the connection's lifetime.
    fn reset(&mut self) {
        // TODO: remove
        let tmp = self.start_time;

        *self = Self::default();

        // TODO: remove
        self.start_time = tmp;
    }
}

0 commit comments

Comments
 (0)