Skip to content

Commit ded5104

Browse files
author
Valentin Obst
committed
rust/net: add sock, tcp_sock and icsk wrappers
Signed-off-by: Valentin Obst <[email protected]>
1 parent 5c84c6c commit ded5104

File tree

6 files changed

+323
-0
lines changed

6 files changed

+323
-0
lines changed

net/ipv4/Kconfig

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,22 @@ config INET_DIAG_DESTROY
466466
had been disconnected.
467467
If unsure, say N.
468468

469+
config RUST_SOCK_ABSTRACTIONS
470+
bool "INET: Rust sock abstractions"
471+
depends on RUST
472+
help
473+
Adds Rust abstractions for working with `struct sock`s.
474+
475+
If unsure, say N.
476+
477+
config RUST_TCP_ABSTRACTIONS
478+
bool "TCP: Rust abstractions"
479+
depends on RUST_SOCK_ABSTRACTIONS
480+
help
481+
Adds support for writing Rust kernel modules that integrate with TCP.
482+
483+
If unsure, say N.
484+
469485
menuconfig TCP_CONG_ADVANCED
470486
bool "TCP: advanced congestion control"
471487
help

rust/bindings/bindings_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <linux/slab.h>
1818
#include <linux/wait.h>
1919
#include <linux/workqueue.h>
20+
#include <net/tcp.h>
2021

2122
/* `bindgen` gets confused at certain things. */
2223
const size_t RUST_CONST_HELPER_ARCH_SLAB_MINALIGN = ARCH_SLAB_MINALIGN;

rust/helpers.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <linux/spinlock.h>
3232
#include <linux/wait.h>
3333
#include <linux/workqueue.h>
34+
#include <net/tcp.h>
3435

3536
__noreturn void rust_helper_BUG(void)
3637
{
@@ -157,6 +158,36 @@ void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func,
157158
}
158159
EXPORT_SYMBOL_GPL(rust_helper_init_work_with_key);
159160

161+
bool rust_helper_tcp_in_slow_start(const struct tcp_sock *tp)
162+
{
163+
return tcp_in_slow_start(tp);
164+
}
165+
EXPORT_SYMBOL_GPL(rust_helper_tcp_in_slow_start);
166+
167+
bool rust_helper_tcp_is_cwnd_limited(const struct sock *sk)
168+
{
169+
return tcp_is_cwnd_limited(sk);
170+
}
171+
EXPORT_SYMBOL_GPL(rust_helper_tcp_is_cwnd_limited);
172+
173+
struct tcp_sock *rust_helper_tcp_sk(struct sock *sk)
174+
{
175+
return tcp_sk(sk);
176+
}
177+
EXPORT_SYMBOL_GPL(rust_helper_tcp_sk);
178+
179+
u32 rust_helper_tcp_snd_cwnd(const struct tcp_sock *tp)
180+
{
181+
return tcp_snd_cwnd(tp);
182+
}
183+
EXPORT_SYMBOL_GPL(rust_helper_tcp_snd_cwnd);
184+
185+
struct inet_connection_sock *rust_helper_inet_csk(const struct sock *sk)
186+
{
187+
return inet_csk(sk);
188+
}
189+
EXPORT_SYMBOL_GPL(rust_helper_inet_csk);
190+
160191
/*
161192
* `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can
162193
* use it in contexts where Rust expects a `usize` like slice (array) indices.

rust/kernel/net.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@
44
55
#[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)]
66
pub mod phy;
7+
#[cfg(CONFIG_RUST_SOCK_ABSTRACTIONS)]
8+
pub mod sock;
9+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
10+
pub mod tcp;

rust/kernel/net/sock.rs

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// SPDX-License-Identifier: GPL-2.0-only
2+
3+
//! Representation of a C `struct sock`.
4+
//!
5+
//! C header: [`include/net/sock.h`](srctree/include/net/sock.h)
6+
7+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
8+
use crate::net::tcp::{self, InetConnectionSock, TcpSock};
9+
use crate::types::Opaque;
10+
use core::convert::TryFrom;
11+
use core::ptr::addr_of;
12+
13+
/// Representation of a C `struct sock`.
14+
///
15+
/// Not intended to be used directly by modules. Abstractions should provide a
16+
/// safe interface to only those operations that are OK to use for the module.
17+
///
18+
/// # Invariants
19+
///
20+
/// Referencing a `sock` using this struct asserts that you are in
21+
/// a context where all safe methods defined on this struct are indeed safe to
22+
/// call.
23+
#[repr(transparent)]
24+
pub(crate) struct Sock {
25+
sk: Opaque<bindings::sock>,
26+
}
27+
28+
impl Sock {
29+
/// Returns a raw pointer to the wrapped `struct sock`.
30+
///
31+
/// It is up to the caller to use it correctly.
32+
#[inline]
33+
pub(crate) fn raw_sk_mut(&mut self) -> *mut bindings::sock {
34+
self.sk.get()
35+
}
36+
37+
/// Returns the sockets pacing rate in bytes per second.
38+
#[inline]
39+
pub(crate) fn sk_pacing_rate(&self) -> u64 {
40+
// NOTE: C uses READ_ONCE for this field, thus `read_volatile`.
41+
// SAFETY: The struct invariant ensures that we may access
42+
// this field without additional synchronization. It is a C unsigned
43+
// long so we can always convert it to a u64 without loss.
44+
unsafe { addr_of!((*self.sk.get()).sk_pacing_rate).read_volatile() as u64 }
45+
}
46+
47+
/// Returns the sockets pacing status.
48+
#[inline]
49+
pub(crate) fn sk_pacing_status(&self) -> Result<Pacing, ()> {
50+
// SAFETY: The struct invariant ensures that we may access
51+
// this field without additional synchronization.
52+
unsafe { Pacing::try_from(*addr_of!((*self.sk.get()).sk_pacing_status)) }
53+
}
54+
55+
/// Returns the sockets maximum GSO segment size to build.
56+
#[inline]
57+
pub(crate) fn sk_gso_max_size(&self) -> u32 {
58+
// SAFETY: The struct invariant ensures that we may access
59+
// this field without additional synchronization. It is an unsigned int
60+
// and we are guaranteed that this will always fit into a u32.
61+
unsafe { *addr_of!((*self.sk.get()).sk_gso_max_size) as u32 }
62+
}
63+
64+
/// Returns the [`TcpSock`] that is containing the `Sock`.
65+
///
66+
/// # Safety
67+
///
68+
/// `sk` must be valid for `tcp_sk`.
69+
#[inline]
70+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
71+
pub(crate) unsafe fn tcp_sk<'a>(&'a self) -> &'a TcpSock {
72+
// SAFETY:
73+
// - Downcasting via `tcp_sk` is OK by the functions precondition.
74+
// - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
75+
unsafe { &*(bindings::tcp_sk(self.sk.get()) as *const TcpSock) }
76+
}
77+
78+
/// Returns the [`TcpSock`] that is containing the `Sock`.
79+
///
80+
/// # Safety
81+
///
82+
/// `sk` must be valid for `tcp_sk`.
83+
#[inline]
84+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
85+
pub(crate) unsafe fn tcp_sk_mut<'a>(&'a mut self) -> &'a mut TcpSock {
86+
// SAFETY:
87+
// - Downcasting via `tcp_sk` is OK by the functions precondition.
88+
// - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
89+
unsafe { &mut *(bindings::tcp_sk(self.sk.get()) as *mut TcpSock) }
90+
}
91+
92+
/// Returns the [`InetConnectionSock`] view of this socket.
93+
///
94+
/// # Safety
95+
///
96+
/// `sk` must be valid for `inet_csk`.
97+
#[inline]
98+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
99+
pub(crate) unsafe fn inet_csk<'a>(&'a self) -> &'a InetConnectionSock {
100+
// SAFETY:
101+
// - Calling `inet_csk` is OK by the functions precondition.
102+
// - The cast is OK since `InetConnectionSock` is transparent to
103+
// `struct inet_connection_sock`.
104+
unsafe { &*(bindings::inet_csk(self.sk.get()) as *const InetConnectionSock) }
105+
}
106+
107+
/// Tests if the connection's sending rate is limited by the cwnd.
108+
///
109+
/// # Safety
110+
///
111+
/// `sk` must be valid for `tcp_is_cwnd_limited`.
112+
#[inline]
113+
#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
114+
pub(crate) unsafe fn tcp_is_cwnd_limited(&self) -> bool {
115+
// SAFETY: Calling `tcp_is_cwnd_limited` is OK by the functions
116+
// precondition.
117+
unsafe { bindings::tcp_is_cwnd_limited(self.sk.get()) }
118+
}
119+
}
120+
121+
/// The socket's pacing status.
122+
#[repr(u32)]
123+
#[allow(missing_docs)]
124+
pub enum Pacing {
125+
r#None = bindings::sk_pacing_SK_PACING_NONE,
126+
Needed = bindings::sk_pacing_SK_PACING_NEEDED,
127+
Fq = bindings::sk_pacing_SK_PACING_FQ,
128+
}
129+
130+
// TODO: Replace with automatically generated code by bindgen when it becomes
131+
// possible.
132+
impl TryFrom<u32> for Pacing {
133+
type Error = ();
134+
135+
fn try_from(val: u32) -> Result<Self, Self::Error> {
136+
match val {
137+
x if x == Pacing::r#None as u32 => Ok(Pacing::r#None),
138+
x if x == Pacing::Needed as u32 => Ok(Pacing::Needed),
139+
x if x == Pacing::Fq as u32 => Ok(Pacing::Fq),
140+
_ => Err(()),
141+
}
142+
}
143+
}

rust/kernel/net/tcp.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// SPDX-License-Identifier: GPL-2.0-only
2+
3+
//! Transmission Control Protocol (TCP).
4+
5+
use crate::time;
6+
use crate::types::Opaque;
7+
use core::{num, ptr};
8+
9+
/// Representation of a `struct inet_connection_sock`.
10+
///
11+
/// # Invariants
12+
///
13+
/// Referencing a `inet_connection_sock` using this struct asserts that you are
14+
/// in a context where all safe methods defined on this struct are indeed safe
15+
/// to call.
16+
///
17+
/// C header: [`include/net/inet_connection_sock.h`](srctree/include/net/inet_connection_sock.h)
18+
#[repr(transparent)]
19+
pub struct InetConnectionSock {
20+
icsk: Opaque<bindings::inet_connection_sock>,
21+
}
22+
23+
/// Representation of a `struct tcp_sock`.
24+
///
25+
/// # Invariants
26+
///
27+
/// Referencing a `tcp_sock` using this struct asserts that you are in
28+
/// a context where all safe methods defined on this struct are indeed safe to
29+
/// call.
30+
///
31+
/// C header: [`include/linux/tcp.h`](srctree/include/linux/tcp.h)
32+
#[repr(transparent)]
33+
pub struct TcpSock {
34+
tp: Opaque<bindings::tcp_sock>,
35+
}
36+
37+
impl TcpSock {
38+
/// Returns true iff `snd_cwnd < snd_ssthresh`.
39+
#[inline]
40+
pub fn in_slow_start(&self) -> bool {
41+
// SAFETY: The struct invariant ensures that we may call this function
42+
// without additional synchronization.
43+
unsafe { bindings::tcp_in_slow_start(self.tp.get()) }
44+
}
45+
46+
/// Performs the standard slow start increment of cwnd.
47+
///
48+
/// If this causes the socket to exit slow start, any leftover ACKs are
49+
/// returned.
50+
#[inline]
51+
pub fn slow_start(&mut self, acked: u32) -> u32 {
52+
// SAFETY: The struct invariant ensures that we may call this function
53+
// without additional synchronization.
54+
unsafe { bindings::tcp_slow_start(self.tp.get(), acked) }
55+
}
56+
57+
/// Performs the standard increase of cwnd during congestion avoidance.
58+
///
59+
/// The increase per ACK is upper bounded by `1 / w`.
60+
#[inline]
61+
pub fn cong_avoid_ai(&mut self, w: num::NonZeroU32, acked: u32) {
62+
// SAFETY: The struct invariant ensures that we may call this function
63+
// without additional synchronization.
64+
unsafe { bindings::tcp_cong_avoid_ai(self.tp.get(), w.get(), acked) };
65+
}
66+
67+
/// Returns the connection's current cwnd.
68+
#[inline]
69+
pub fn snd_cwnd(&self) -> u32 {
70+
// SAFETY: The struct invariant ensures that we may call this function
71+
// without additional synchronization.
72+
unsafe { bindings::tcp_snd_cwnd(self.tp.get()) }
73+
}
74+
75+
/// Returns the connection's current ssthresh.
76+
#[inline]
77+
pub fn snd_ssthresh(&self) -> u32 {
78+
// SAFETY: The struct invariant ensures that we may access
79+
// this field without additional synchronization.
80+
unsafe { *ptr::addr_of!((*self.tp.get()).snd_ssthresh) }
81+
}
82+
83+
/// Returns the sequence number of the next byte that will be sent.
84+
#[inline]
85+
pub fn snd_nxt(&self) -> u32 {
86+
// SAFETY: The struct invariant ensures that we may access
87+
// this field without additional synchronization.
88+
unsafe { *ptr::addr_of!((*self.tp.get()).snd_nxt) }
89+
}
90+
91+
/// Returns the sequence number of the first unacknowledged byte.
92+
#[inline]
93+
pub fn snd_una(&self) -> u32 {
94+
// SAFETY: The struct invariant ensures that we may access
95+
// this field without additional synchronization.
96+
unsafe { *ptr::addr_of!((*self.tp.get()).snd_una) }
97+
}
98+
99+
/// Returns the time when the last packet was received or sent.
100+
#[inline]
101+
pub fn tcp_mstamp(&self) -> time::Usecs {
102+
// SAFETY: The struct invariant ensures that we may access
103+
// this field without additional synchronization.
104+
unsafe { *ptr::addr_of!((*self.tp.get()).tcp_mstamp) }
105+
}
106+
107+
/// Sets the connection's ssthresh.
108+
#[inline]
109+
pub fn set_snd_ssthresh(&mut self, new: u32) {
110+
// SAFETY: The struct invariant ensures that we may access
111+
// this field without additional synchronization.
112+
unsafe { *ptr::addr_of_mut!((*self.tp.get()).snd_ssthresh) = new };
113+
}
114+
115+
/// Returns the timestamp of the last send data packet in 32bit Jiffies.
116+
#[inline]
117+
pub fn lsndtime(&self) -> time::Jiffies32 {
118+
// SAFETY: The struct invariant ensures that we may access
119+
// this field without additional synchronization.
120+
unsafe { *ptr::addr_of!((*self.tp.get()).lsndtime) as time::Jiffies32 }
121+
}
122+
}
123+
124+
/// Tests if `sqn_1` comes after `sqn_2`.
125+
#[inline]
126+
pub fn after(sqn_1: u32, sqn_2: u32) -> bool {
127+
(sqn_2.wrapping_sub(sqn_1) as i32) < 0
128+
}

0 commit comments

Comments
 (0)