Commit 8e33d5d
Author: Alexei Starovoitov (committed)

Merge branch 'bpf-fix-backward-progress-bug-in-bpf_iter_udp'

Martin KaFai Lau says:

====================
bpf: Fix backward progress bug in bpf_iter_udp

From: Martin KaFai Lau <[email protected]>

This patch set fixes an issue in bpf_iter_udp that makes backward progress
and prevents the user space process from finishing. There is a test at the
end to reproduce the bug. Please see individual patches for details.

v3:
- Fixed the iter_fd check and local_port check in the patch 3 selftest.
  (Yonghong)
- Moved jhash2 to test_jhash.h in the patch 3. (Yonghong)
- Added explanation in the bucket selection in the patch 3. (Yonghong)

v2:
- Added patch 1 to fix another bug that goes back to the previous bucket
- Simplify the fix in patch 2 to always reset iter->offset to 0
- Add a test case to close all udp_sk in a bucket while in the middle of
  the iteration.
====================

Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Alexei Starovoitov <[email protected]>
2 parents 894d750 + dbd7db7 commit 8e33d5d
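
For context, user space consumes a socket iterator by read()-ing the fd returned by bpf_iter_create() until read() returns 0; EOF is what tells the reader the iteration is done. The bug fixed here made the kernel side step back to an already-visited bucket, so EOF never arrived and the reading process never finished. Below is a minimal sketch of such a reader; it is not part of the commit, the skeleton and program names (sock_iter_batch, iter_udp_soreuse) are borrowed from the selftest added in this series, and setup/error handling is trimmed:

/* Sketch of a user-space iterator reader, not the selftest itself. */
#include <stdio.h>
#include <unistd.h>
#include <bpf/libbpf.h>
#include "sock_iter_batch.skel.h"

int main(void)
{
	struct sock_iter_batch *skel = sock_iter_batch__open_and_load();
	struct bpf_link *link;
	char buf[64];
	int iter_fd, n;

	if (!skel)
		return 1;
	link = bpf_program__attach_iter(skel->progs.iter_udp_soreuse, NULL);
	if (!link)
		return 1;
	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (iter_fd < 0)
		return 1;

	/* Pre-fix, a bucket revisited by the kernel could keep this loop
	 * from ever seeing the n == 0 that ends the iteration.
	 */
	while ((n = read(iter_fd, buf, sizeof(buf))) > 0)
		printf("read %d bytes of iterator output\n", n);

	close(iter_fd);
	bpf_link__destroy(link);
	sock_iter_batch__destroy(skel);
	return 0;
}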

File tree: 5 files changed (+270 / -12 lines)

net/ipv4/udp.c

Lines changed: 10 additions & 12 deletions

@@ -3137,16 +3137,18 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
 	struct bpf_udp_iter_state *iter = seq->private;
 	struct udp_iter_state *state = &iter->state;
 	struct net *net = seq_file_net(seq);
+	int resume_bucket, resume_offset;
 	struct udp_table *udptable;
 	unsigned int batch_sks = 0;
 	bool resized = false;
 	struct sock *sk;
 
+	resume_bucket = state->bucket;
+	resume_offset = iter->offset;
+
 	/* The current batch is done, so advance the bucket. */
-	if (iter->st_bucket_done) {
+	if (iter->st_bucket_done)
 		state->bucket++;
-		iter->offset = 0;
-	}
 
 	udptable = udp_get_table_seq(seq, net);
 
@@ -3166,19 +3168,19 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
 	for (; state->bucket <= udptable->mask; state->bucket++) {
 		struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];
 
-		if (hlist_empty(&hslot2->head)) {
-			iter->offset = 0;
+		if (hlist_empty(&hslot2->head))
 			continue;
-		}
 
+		iter->offset = 0;
 		spin_lock_bh(&hslot2->lock);
 		udp_portaddr_for_each_entry(sk, &hslot2->head) {
 			if (seq_sk_match(seq, sk)) {
 				/* Resume from the last iterated socket at the
 				 * offset in the bucket before iterator was stopped.
 				 */
-				if (iter->offset) {
-					--iter->offset;
+				if (state->bucket == resume_bucket &&
+				    iter->offset < resume_offset) {
+					++iter->offset;
 					continue;
 				}
 				if (iter->end_sk < iter->max_sk) {
@@ -3192,9 +3194,6 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
 
 		if (iter->end_sk)
 			break;
-
-		/* Reset the current bucket's offset before moving to the next bucket. */
-		iter->offset = 0;
 	}
 
 	/* All done: no batch made. */
@@ -3213,7 +3212,6 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
 		/* After allocating a larger batch, retry one more time to grab
 		 * the whole bucket.
 		 */
-		state->bucket--;
 		goto again;
 	}
 done:
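
To see why recording the resume point up front fixes the hang, it helps to model what happens when every socket already reported from a bucket is closed between two read() calls. The sketch below is not kernel code: the names resume_bucket and resume_offset and the skip condition mirror the diff above, while the bucket layout and socket ids are invented for illustration.

#include <stdio.h>

#define NR_BUCKETS 2
#define BUCKET_SZ  4

int main(void)
{
	/* Two hash buckets of socket ids; -1 marks a closed socket. */
	int buckets[NR_BUCKETS][BUCKET_SZ] = { { 1, 2, 3, 4 }, { 5, 6, 7, 8 } };
	/* The previous batch stopped after 3 sockets of bucket 0. */
	int resume_bucket = 0, resume_offset = 3;
	int bucket, i, offset;

	/* Every socket of bucket 0 is closed before the next read(). */
	for (i = 0; i < BUCKET_SZ; i++)
		buckets[0][i] = -1;

	/* New logic: skip at most resume_offset sockets, and only while
	 * still in resume_bucket. The emptied bucket simply yields
	 * nothing and iteration moves forward, so it always terminates.
	 */
	for (bucket = resume_bucket; bucket < NR_BUCKETS; bucket++) {
		offset = 0;
		for (i = 0; i < BUCKET_SZ; i++) {
			if (buckets[bucket][i] < 0)
				continue;	/* closed, gone from the list */
			if (bucket == resume_bucket && offset < resume_offset) {
				offset++;	/* already reported last time */
				continue;
			}
			printf("visit socket %d (bucket %d)\n",
			       buckets[bucket][i], bucket);
		}
	}
	return 0;
}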

tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c

Lines changed: 135 additions & 0 deletions

// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2024 Meta

#include <test_progs.h>
#include "network_helpers.h"
#include "sock_iter_batch.skel.h"

#define TEST_NS "sock_iter_batch_netns"

static const int nr_soreuse = 4;

static void do_test(int sock_type, bool onebyone)
{
	int err, i, nread, to_read, total_read, iter_fd = -1;
	int first_idx, second_idx, indices[nr_soreuse];
	struct bpf_link *link = NULL;
	struct sock_iter_batch *skel;
	int *fds[2] = {};

	skel = sock_iter_batch__open();
	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
		return;

	/* Prepare 2 buckets of sockets in the kernel hashtable */
	for (i = 0; i < ARRAY_SIZE(fds); i++) {
		int local_port;

		fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
						nr_soreuse);
		if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
			goto done;
		local_port = get_socket_local_port(*fds[i]);
		if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
			goto done;
		skel->rodata->ports[i] = ntohs(local_port);
	}

	err = sock_iter_batch__load(skel);
	if (!ASSERT_OK(err, "sock_iter_batch__load"))
		goto done;

	link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
					skel->progs.iter_tcp_soreuse :
					skel->progs.iter_udp_soreuse,
					NULL);
	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
		goto done;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
		goto done;

	/* Test reading a bucket (either from fds[0] or fds[1]).
	 * Only read "nr_soreuse - 1" number of sockets
	 * from a bucket and leave one socket out from
	 * that bucket on purpose.
	 */
	to_read = (nr_soreuse - 1) * sizeof(*indices);
	total_read = 0;
	first_idx = -1;
	do {
		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
		if (nread <= 0 || nread % sizeof(*indices))
			break;
		total_read += nread;

		if (first_idx == -1)
			first_idx = indices[0];
		for (i = 0; i < nread / sizeof(*indices); i++)
			ASSERT_EQ(indices[i], first_idx, "first_idx");
	} while (total_read < to_read);
	ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread");
	ASSERT_EQ(total_read, to_read, "total_read");

	free_fds(fds[first_idx], nr_soreuse);
	fds[first_idx] = NULL;

	/* Read the "whole" second bucket */
	to_read = nr_soreuse * sizeof(*indices);
	total_read = 0;
	second_idx = !first_idx;
	do {
		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
		if (nread <= 0 || nread % sizeof(*indices))
			break;
		total_read += nread;

		for (i = 0; i < nread / sizeof(*indices); i++)
			ASSERT_EQ(indices[i], second_idx, "second_idx");
	} while (total_read <= to_read);
	ASSERT_EQ(nread, 0, "nread");
	/* Both so_reuseport ports should be in different buckets, so
	 * total_read must equal to the expected to_read.
	 *
	 * For a very unlikely case, both ports collide at the same bucket,
	 * the bucket offset (i.e. 3) will be skipped and it cannot
	 * expect the to_read number of bytes.
	 */
	if (skel->bss->bucket[0] != skel->bss->bucket[1])
		ASSERT_EQ(total_read, to_read, "total_read");

done:
	for (i = 0; i < ARRAY_SIZE(fds); i++)
		free_fds(fds[i], nr_soreuse);
	if (iter_fd >= 0)
		close(iter_fd);
	bpf_link__destroy(link);
	sock_iter_batch__destroy(skel);
}

void test_sock_iter_batch(void)
{
	struct nstoken *nstoken = NULL;

	SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
	SYS(done, "ip netns add %s", TEST_NS);
	SYS(done, "ip -net %s link set dev lo up", TEST_NS);

	nstoken = open_netns(TEST_NS);
	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
		goto done;

	if (test__start_subtest("tcp")) {
		do_test(SOCK_STREAM, true);
		do_test(SOCK_STREAM, false);
	}
	if (test__start_subtest("udp")) {
		do_test(SOCK_DGRAM, true);
		do_test(SOCK_DGRAM, false);
	}
	close_netns(nstoken);

done:
	SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
}

tools/testing/selftests/bpf/progs/bpf_tracing_net.h

Lines changed: 3 additions & 0 deletions

@@ -72,6 +72,8 @@
 #define inet_rcv_saddr sk.__sk_common.skc_rcv_saddr
 #define inet_dport sk.__sk_common.skc_dport
 
+#define udp_portaddr_hash inet.sk.__sk_common.skc_u16hashes[1]
+
 #define ir_loc_addr req.__req_common.skc_rcv_saddr
 #define ir_num req.__req_common.skc_num
 #define ir_rmt_addr req.__req_common.skc_daddr
@@ -85,6 +87,7 @@
 #define sk_rmem_alloc sk_backlog.rmem_alloc
 #define sk_refcnt __sk_common.skc_refcnt
 #define sk_state __sk_common.skc_state
+#define sk_net __sk_common.skc_net
 #define sk_v6_daddr __sk_common.skc_v6_daddr
 #define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr
 #define sk_flags __sk_common.skc_flags
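
These two defines follow the header's existing pattern: a BPF program written against vmlinux.h keeps using the familiar kernel field name, and the preprocessor rewrites the access into its real location inside struct sock_common. A toy user-space model of the member-redirect trick is below; the struct layout here is invented for illustration, only the shape of the define matches the real one.

#include <stdio.h>

/* Invented miniature of struct sock / sock_common, just to show the
 * member-redirect pattern used by bpf_tracing_net.h.
 */
struct sock_common { int skc_net; };
struct sock { struct sock_common __sk_common; };

#define sk_net __sk_common.skc_net	/* same shape as the real define */

int main(void)
{
	struct sock s = { .__sk_common = { .skc_net = 42 } };

	/* s.sk_net expands to s.__sk_common.skc_net */
	printf("%d\n", s.sk_net);
	return 0;
}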

tools/testing/selftests/bpf/progs/sock_iter_batch.c

Lines changed: 91 additions & 0 deletions

// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2024 Meta

#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_endian.h>
#include "bpf_tracing_net.h"
#include "bpf_kfuncs.h"

#define ATTR __always_inline
#include "test_jhash.h"

static bool ipv6_addr_loopback(const struct in6_addr *a)
{
	return (a->s6_addr32[0] | a->s6_addr32[1] |
		a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0;
}

volatile const __u16 ports[2];
unsigned int bucket[2];

SEC("iter/tcp")
int iter_tcp_soreuse(struct bpf_iter__tcp *ctx)
{
	struct sock *sk = (struct sock *)ctx->sk_common;
	struct inet_hashinfo *hinfo;
	unsigned int hash;
	struct net *net;
	int idx;

	if (!sk)
		return 0;

	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
	if (sk->sk_family != AF_INET6 ||
	    sk->sk_state != TCP_LISTEN ||
	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
		return 0;

	if (sk->sk_num == ports[0])
		idx = 0;
	else if (sk->sk_num == ports[1])
		idx = 1;
	else
		return 0;

	/* bucket selection as in inet_lhash2_bucket_sk() */
	net = sk->sk_net.net;
	hash = jhash2(sk->sk_v6_rcv_saddr.s6_addr32, 4, net->hash_mix);
	hash ^= sk->sk_num;
	hinfo = net->ipv4.tcp_death_row.hashinfo;
	bucket[idx] = hash & hinfo->lhash2_mask;
	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));

	return 0;
}

#define udp_sk(ptr) container_of(ptr, struct udp_sock, inet.sk)

SEC("iter/udp")
int iter_udp_soreuse(struct bpf_iter__udp *ctx)
{
	struct sock *sk = (struct sock *)ctx->udp_sk;
	struct udp_table *udptable;
	int idx;

	if (!sk)
		return 0;

	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
	if (sk->sk_family != AF_INET6 ||
	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
		return 0;

	if (sk->sk_num == ports[0])
		idx = 0;
	else if (sk->sk_num == ports[1])
		idx = 1;
	else
		return 0;

	/* bucket selection as in udp_hashslot2() */
	udptable = sk->sk_net.net->ipv4.udp_table;
	bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask;
	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));

	return 0;
}

char _license[] SEC("license") = "GPL";

tools/testing/selftests/bpf/progs/test_jhash.h

Lines changed: 31 additions & 0 deletions

@@ -69,3 +69,34 @@ u32 jhash(const void *key, u32 length, u32 initval)
 
 	return c;
 }
+
+static __always_inline u32 jhash2(const u32 *k, u32 length, u32 initval)
+{
+	u32 a, b, c;
+
+	/* Set up the internal state */
+	a = b = c = JHASH_INITVAL + (length<<2) + initval;
+
+	/* Handle most of the key */
+	while (length > 3) {
+		a += k[0];
+		b += k[1];
+		c += k[2];
+		__jhash_mix(a, b, c);
+		length -= 3;
+		k += 3;
+	}
+
+	/* Handle the last 3 u32's */
+	switch (length) {
+	case 3: c += k[2];
+	case 2: b += k[1];
+	case 1: a += k[0];
+		__jhash_final(a, b, c);
+		break;
+	case 0: /* Nothing left to add */
+		break;
+	}
+
+	return c;
+}
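
Since this jhash2 copy is plain C, it can also be sanity-checked from user space. The sketch below assumes test_jhash.h is reachable on the include path and that ATTR is defined before inclusion (as the BPF progs in this series do); the initval and mask are made-up stand-ins for the kernel's net->hash_mix and lhash2_mask.

/* Hypothetical standalone driver for jhash2; not part of the commit. */
#include <stdio.h>

#define ATTR static	/* test_jhash.h expects the includer to define ATTR */
#include "test_jhash.h"

int main(void)
{
	u32 addr[4] = { 0, 0, 0, 1 };	/* ::1 as network-order words */
	/* Same shape as inet_lhash2_bucket_sk(): hash the address with a
	 * seed, fold in the port, then mask down to a bucket index.
	 */
	u32 hash = jhash2(addr, 4, 0xcafebabe) ^ 8080;

	printf("bucket = %u\n", hash & 127);	/* 127 stands in for lhash2_mask */
	return 0;
}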
