1
1
// SPDX-License-Identifier: GPL-2.0
2
2
// Copyright (c) 2024 Meta
3
3
4
+ #include <poll.h>
4
5
#include <test_progs.h>
5
6
#include "network_helpers.h"
6
7
#include "sock_iter_batch.skel.h"
@@ -153,8 +154,71 @@ static void check_n_were_seen_once(int *fds, int fds_len, int n,
153
154
ASSERT_EQ (seen_once , n , "seen_once" );
154
155
}
155
156
157
+ static int accept_from_one (struct pollfd * server_poll_fds ,
158
+ int server_poll_fds_len )
159
+ {
160
+ static const int poll_timeout_ms = 5000 ; /* 5s */
161
+ int ret ;
162
+ int i ;
163
+
164
+ ret = poll (server_poll_fds , server_poll_fds_len , poll_timeout_ms );
165
+ if (!ASSERT_EQ (ret , 1 , "poll" ))
166
+ return -1 ;
167
+
168
+ for (i = 0 ; i < server_poll_fds_len ; i ++ )
169
+ if (server_poll_fds [i ].revents & POLLIN )
170
+ return accept (server_poll_fds [i ].fd , NULL , NULL );
171
+
172
+ return -1 ;
173
+ }
174
+
175
+ static int * connect_to_server (int family , int sock_type , const char * addr ,
176
+ __u16 port , int nr_connects , int * server_fds ,
177
+ int server_fds_len )
178
+ {
179
+ struct pollfd * server_poll_fds = NULL ;
180
+ int * established_socks = NULL ;
181
+ int i ;
182
+
183
+ server_poll_fds = calloc (server_fds_len , sizeof (* server_poll_fds ));
184
+ if (!ASSERT_OK_PTR (server_poll_fds , "server_poll_fds" ))
185
+ return NULL ;
186
+
187
+ for (i = 0 ; i < server_fds_len ; i ++ ) {
188
+ server_poll_fds [i ].fd = server_fds [i ];
189
+ server_poll_fds [i ].events = POLLIN ;
190
+ }
191
+
192
+ i = 0 ;
193
+
194
+ established_socks = malloc (sizeof (* established_socks ) * nr_connects * 2 );
195
+ if (!ASSERT_OK_PTR (established_socks , "established_socks" ))
196
+ goto error ;
197
+
198
+ while (nr_connects -- ) {
199
+ established_socks [i ] = connect_to_addr_str (family , sock_type ,
200
+ addr , port , NULL );
201
+ if (!ASSERT_OK_FD (established_socks [i ], "connect_to_addr_str" ))
202
+ goto error ;
203
+ i ++ ;
204
+ established_socks [i ] = accept_from_one (server_poll_fds ,
205
+ server_fds_len );
206
+ if (!ASSERT_OK_FD (established_socks [i ], "accept_from_one" ))
207
+ goto error ;
208
+ i ++ ;
209
+ }
210
+
211
+ free (server_poll_fds );
212
+ return established_socks ;
213
+ error :
214
+ free_fds (established_socks , i );
215
+ free (server_poll_fds );
216
+ return NULL ;
217
+ }
218
+
156
219
static void remove_seen (int family , int sock_type , const char * addr , __u16 port ,
157
- int * socks , int socks_len , struct sock_count * counts ,
220
+ int * socks , int socks_len , int * established_socks ,
221
+ int established_socks_len , struct sock_count * counts ,
158
222
int counts_len , struct bpf_link * link , int iter_fd )
159
223
{
160
224
int close_idx ;
@@ -185,6 +249,7 @@ static void remove_seen(int family, int sock_type, const char *addr, __u16 port,
185
249
186
250
static void remove_unseen (int family , int sock_type , const char * addr ,
187
251
__u16 port , int * socks , int socks_len ,
252
+ int * established_socks , int established_socks_len ,
188
253
struct sock_count * counts , int counts_len ,
189
254
struct bpf_link * link , int iter_fd )
190
255
{
@@ -217,6 +282,7 @@ static void remove_unseen(int family, int sock_type, const char *addr,
217
282
218
283
static void remove_all (int family , int sock_type , const char * addr ,
219
284
__u16 port , int * socks , int socks_len ,
285
+ int * established_socks , int established_socks_len ,
220
286
struct sock_count * counts , int counts_len ,
221
287
struct bpf_link * link , int iter_fd )
222
288
{
@@ -244,7 +310,8 @@ static void remove_all(int family, int sock_type, const char *addr,
244
310
}
245
311
246
312
static void add_some (int family , int sock_type , const char * addr , __u16 port ,
247
- int * socks , int socks_len , struct sock_count * counts ,
313
+ int * socks , int socks_len , int * established_socks ,
314
+ int established_socks_len , struct sock_count * counts ,
248
315
int counts_len , struct bpf_link * link , int iter_fd )
249
316
{
250
317
int * new_socks = NULL ;
@@ -274,6 +341,7 @@ static void add_some(int family, int sock_type, const char *addr, __u16 port,
274
341
275
342
static void force_realloc (int family , int sock_type , const char * addr ,
276
343
__u16 port , int * socks , int socks_len ,
344
+ int * established_socks , int established_socks_len ,
277
345
struct sock_count * counts , int counts_len ,
278
346
struct bpf_link * link , int iter_fd )
279
347
{
@@ -302,10 +370,12 @@ static void force_realloc(int family, int sock_type, const char *addr,
302
370
303
371
struct test_case {
304
372
void (* test )(int family , int sock_type , const char * addr , __u16 port ,
305
- int * socks , int socks_len , struct sock_count * counts ,
373
+ int * socks , int socks_len , int * established_socks ,
374
+ int established_socks_len , struct sock_count * counts ,
306
375
int counts_len , struct bpf_link * link , int iter_fd );
307
376
const char * description ;
308
377
int ehash_buckets ;
378
+ int connections ;
309
379
int init_socks ;
310
380
int max_socks ;
311
381
int sock_type ;
@@ -416,6 +486,7 @@ static void do_resume_test(struct test_case *tc)
416
486
static const __u16 port = 10001 ;
417
487
struct nstoken * nstoken = NULL ;
418
488
struct bpf_link * link = NULL ;
489
+ int * established_fds = NULL ;
419
490
int err , iter_fd = -1 ;
420
491
const char * addr ;
421
492
int * fds = NULL ;
@@ -444,6 +515,14 @@ static void do_resume_test(struct test_case *tc)
444
515
tc -> init_socks );
445
516
if (!ASSERT_OK_PTR (fds , "start_reuseport_server" ))
446
517
goto done ;
518
+ if (tc -> connections ) {
519
+ established_fds = connect_to_server (tc -> family , tc -> sock_type ,
520
+ addr , port ,
521
+ tc -> connections , fds ,
522
+ tc -> init_socks );
523
+ if (!ASSERT_OK_PTR (established_fds , "connect_to_server" ))
524
+ goto done ;
525
+ }
447
526
skel -> rodata -> ports [0 ] = 0 ;
448
527
skel -> rodata -> ports [1 ] = 0 ;
449
528
skel -> rodata -> sf = tc -> family ;
@@ -465,13 +544,15 @@ static void do_resume_test(struct test_case *tc)
465
544
goto done ;
466
545
467
546
tc -> test (tc -> family , tc -> sock_type , addr , port , fds , tc -> init_socks ,
468
- counts , tc -> max_socks , link , iter_fd );
547
+ established_fds , tc -> connections * 2 , counts , tc -> max_socks ,
548
+ link , iter_fd );
469
549
done :
470
550
close_netns (nstoken );
471
551
SYS_NOFAIL ("ip netns del " TEST_CHILD_NS );
472
552
SYS_NOFAIL ("sysctl -w net.ipv4.tcp_child_ehash_entries=0" );
473
553
free (counts );
474
554
free_fds (fds , tc -> init_socks );
555
+ free_fds (established_fds , tc -> connections * 2 );
475
556
if (iter_fd >= 0 )
476
557
close (iter_fd );
477
558
bpf_link__destroy (link );
0 commit comments