44
55use std:: ffi:: c_int;
66use std:: ffi:: c_void;
7+ use std:: ffi:: CStr ;
78use std:: io;
89use std:: io:: Read as _;
910use std:: io:: Write as _;
@@ -12,6 +13,7 @@ use std::net::TcpStream;
1213use std:: os:: fd:: AsFd as _;
1314use std:: os:: fd:: AsRawFd as _;
1415use std:: os:: fd:: BorrowedFd ;
16+ use std:: ptr:: copy_nonoverlapping;
1517use std:: thread;
1618
1719use clap:: Parser ;
@@ -22,13 +24,18 @@ use libc::TCP_CONGESTION;
2224
2325use libbpf_rs:: skel:: OpenSkel ;
2426use libbpf_rs:: skel:: SkelBuilder ;
27+ use libbpf_rs:: AsRawLibbpf as _;
2528use libbpf_rs:: ErrorExt as _;
29+ use libbpf_rs:: ErrorKind ;
2630use libbpf_rs:: Result ;
2731
2832use crate :: tcp_ca:: TcpCaSkelBuilder ;
2933
3034mod tcp_ca {
31- include ! ( concat!( env!( "CARGO_MANIFEST_DIR" ) , "/src/bpf/tcp_ca.skel.rs" ) ) ;
35+ include ! ( concat!(
36+ env!( "CARGO_MANIFEST_DIR" ) ,
37+ "/src/bpf/tcp_ca.skel.rs"
38+ ) ) ;
3239}
3340
3441const TCP_CA_UPDATE : & [ u8 ] = b"tcp_ca_update\0 " ;
@@ -58,24 +65,29 @@ fn set_sock_opt(
5865
5966/// Set the `tcp_ca_update` congestion algorithm on the socket represented by
6067/// the provided file descriptor.
61- fn set_tcp_ca ( fd : BorrowedFd < ' _ > ) -> Result < ( ) > {
68+ fn set_tcp_ca ( fd : BorrowedFd < ' _ > , tcp_ca : & CStr ) -> Result < ( ) > {
6269 let ( ) = set_sock_opt (
6370 fd,
6471 IPPROTO_TCP ,
6572 TCP_CONGESTION ,
66- TCP_CA_UPDATE . as_ptr ( ) . cast ( ) ,
67- ( TCP_CA_UPDATE . len ( ) - 1 ) as _ ,
73+ tcp_ca . as_ptr ( ) . cast ( ) ,
74+ tcp_ca . to_bytes ( ) . len ( ) as _ ,
6875 )
69- . context ( "failed to set TCP_CONGESTION" ) ?;
76+ . with_context ( || {
77+ format ! (
78+ "failed to set TCP_CONGESTION algorithm `{}`" ,
79+ tcp_ca. to_str( ) . unwrap( )
80+ )
81+ } ) ?;
7082 Ok ( ( ) )
7183}
7284
7385/// Send and receive a bunch of data over TCP sockets using the `tcp_ca_update`
7486/// congestion algorithm.
75- fn send_recv ( ) -> Result < ( ) > {
87+ fn send_recv ( tcp_ca : & CStr ) -> Result < ( ) > {
7688 let num_bytes = 8 * 1024 * 1024 ;
7789 let listener = TcpListener :: bind ( "[::1]:0" ) ?;
78- let ( ) = set_tcp_ca ( listener. as_fd ( ) ) ?;
90+ let ( ) = set_tcp_ca ( listener. as_fd ( ) , tcp_ca ) ?;
7991 let addr = listener. local_addr ( ) ?;
8092
8193 let send_handle = thread:: spawn ( move || {
@@ -86,38 +98,96 @@ fn send_recv() -> Result<()> {
8698
8799 let mut received = Vec :: new ( ) ;
88100 let mut stream = TcpStream :: connect ( addr) ?;
89- let ( ) = set_tcp_ca ( stream. as_fd ( ) ) ?;
101+ let ( ) = set_tcp_ca ( stream. as_fd ( ) , tcp_ca ) ?;
90102 let _count = stream. read_to_end ( & mut received) ?;
91103 let ( ) = send_handle. join ( ) . unwrap ( ) ;
92104
93105 assert_eq ! ( received. len( ) , num_bytes) ;
94106 Ok ( ( ) )
95107}
96108
97- fn main ( ) -> Result < ( ) > {
98- let args = Args :: parse ( ) ;
99-
109+ fn test ( name_to_register : Option < & CStr > , name_to_use : & CStr , verbose : bool ) -> Result < ( ) > {
100110 let mut skel_builder = TcpCaSkelBuilder :: default ( ) ;
101- if args . verbose {
111+ if verbose {
102112 skel_builder. obj_builder . debug ( true ) ;
103113 }
104114
105- let open_skel = skel_builder. open ( ) ?;
115+ let mut open_skel = skel_builder. open ( ) ?;
116+
117+ if let Some ( name) = name_to_register {
118+ // Here we illustrate the possibility of updating `struct_ops` data before
119+ // load. That can be used to communicate data to the kernel, e.g., for
120+ // initialization purposes.
121+ let ca_update = open_skel. struct_ops . ca_update_mut ( ) ;
122+ if name. to_bytes_with_nul ( ) . len ( ) > ca_update. name . len ( ) {
123+ panic ! (
124+ "TCP CA name `{}` exceeds maximum length {}" ,
125+ name. to_str( ) . unwrap( ) ,
126+ ca_update. name. len( )
127+ ) ;
128+ }
129+ let len = name. to_bytes_with_nul ( ) . len ( ) ;
130+ let ( ) = unsafe { copy_nonoverlapping ( name. as_ptr ( ) , ca_update. name . as_mut_ptr ( ) , len) } ;
131+ let ( ) = ca_update. name [ len..] . fill ( 0 ) ;
132+ }
133+
134+ let ca_update_cong_control2 = open_skel
135+ . progs ( )
136+ . ca_update_cong_control2 ( )
137+ . as_libbpf_object ( )
138+ . as_ptr ( ) ;
139+ let ca_update = open_skel. struct_ops . ca_update_mut ( ) ;
140+ ca_update. cong_control = ca_update_cong_control2;
141+
106142 let mut skel = open_skel. load ( ) ?;
107143 let mut maps = skel. maps_mut ( ) ;
108144 let map = maps. ca_update ( ) ;
109145 let _link = map. attach_struct_ops ( ) ?;
110146
111- println ! ( "Registered `tcp_ca_update` congestion algorithm; using it for loopback based data exchange..." ) ;
147+ println ! (
148+ "Registered `{}` congestion algorithm; using `{}` for loopback based data exchange..." ,
149+ name_to_register. unwrap_or( name_to_use) . to_str( ) . unwrap( ) ,
150+ name_to_use. to_str( ) . unwrap( )
151+ ) ;
152+
153+ // NB: At this point `/proc/sys/net/ipv4/tcp_available_congestion_control`
154+ // would list the registered congestion algorithm.
112155
113156 assert_eq ! ( skel. bss( ) . ca_cnt, 0 ) ;
157+ assert ! ( !skel. bss( ) . cong_control) ;
114158
115159 // Use our registered TCP congestion algorithm while sending a bunch of data
116160 // over the loopback device.
117- let ( ) = send_recv ( ) ?;
161+ let ( ) = send_recv ( name_to_use ) ?;
118162 println ! ( "Done." ) ;
119163
120- let saved_ca1_cnt = skel. bss ( ) . ca_cnt ;
121- assert_ne ! ( saved_ca1_cnt, 0 ) ;
164+ let saved_ca_cnt = skel. bss ( ) . ca_cnt ;
165+ assert_ne ! ( saved_ca_cnt, 0 ) ;
166+ // With `ca_update_cong_control2` active, we should have seen the
167+ // `cong_control` value changed as well.
168+ assert ! ( skel. bss( ) . cong_control) ;
169+ Ok ( ( ) )
170+ }
171+
172+ fn main ( ) -> Result < ( ) > {
173+ let args = Args :: parse ( ) ;
174+
175+ let tcp_ca = CStr :: from_bytes_until_nul ( TCP_CA_UPDATE ) . unwrap ( ) ;
176+ let ( ) = test ( None , tcp_ca, args. verbose ) ?;
177+
178+ // Use a different name under which the algorithm is registered; just for
179+ // illustration purposes of how to change `struct_ops` related data before
180+ // load/attachment.
181+ let new_ca = CStr :: from_bytes_until_nul ( b"anotherca\0 " ) . unwrap ( ) ;
182+ let ( ) = test ( Some ( new_ca) , new_ca, args. verbose ) ?;
183+
184+ // Just to be sure we are not bullshitting with the above, use a different
185+ // congestion algorithm than what we register. This is expected to fail,
186+ // because said algorithm to use cannot be found.
187+ let to_register = CStr :: from_bytes_until_nul ( b"holycowca\0 " ) . unwrap ( ) ;
188+ let err = test ( Some ( to_register) , tcp_ca, args. verbose ) . unwrap_err ( ) ;
189+ assert_eq ! ( err. kind( ) , ErrorKind :: NotFound ) ;
190+ println ! ( "Expected failure: {err:#}" ) ;
191+
122192 Ok ( ( ) )
123193}
0 commit comments