Skip to content

Commit 2bcd350

Browse files
committed
Merge branch 'wg-fixes'
Jason A. Donenfeld says: ==================== wireguard fixes for 5.8-rc3 This series contains two fixes, one cosmetic and one quite important: 1) Avoid the `if ((x = f()) == y)` pattern, from Frank Werner-Krippendorf. 2) Mitigate a potential memory leak by creating circular netns references, while also making the netns semantics a bit more robust. Patch (2) has a "Fixes:" line and should be backported to stable. ==================== Signed-off-by: David S. Miller <[email protected]>
2 parents f7fb92a + 900575a commit 2bcd350

File tree

6 files changed

+69
-48
lines changed

6 files changed

+69
-48
lines changed

drivers/net/wireguard/device.c

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
4545
if (dev_v6)
4646
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
4747

48+
mutex_lock(&wg->device_update_lock);
4849
ret = wg_socket_init(wg, wg->incoming_port);
4950
if (ret < 0)
50-
return ret;
51-
mutex_lock(&wg->device_update_lock);
51+
goto out;
5252
list_for_each_entry(peer, &wg->peer_list, peer_list) {
5353
wg_packet_send_staged_packets(peer);
5454
if (peer->persistent_keepalive_interval)
5555
wg_packet_send_keepalive(peer);
5656
}
57+
out:
5758
mutex_unlock(&wg->device_update_lock);
58-
return 0;
59+
return ret;
5960
}
6061

6162
#ifdef CONFIG_PM_SLEEP
@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
225226
list_del(&wg->device_list);
226227
rtnl_unlock();
227228
mutex_lock(&wg->device_update_lock);
229+
rcu_assign_pointer(wg->creating_net, NULL);
228230
wg->incoming_port = 0;
229231
wg_socket_reinit(wg, NULL, NULL);
230232
/* The final references are cleared in the below calls to destroy_workqueue. */
@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
240242
skb_queue_purge(&wg->incoming_handshakes);
241243
free_percpu(dev->tstats);
242244
free_percpu(wg->incoming_handshakes_worker);
243-
if (wg->have_creating_net_ref)
244-
put_net(wg->creating_net);
245245
kvfree(wg->index_hashtable);
246246
kvfree(wg->peer_hashtable);
247247
mutex_unlock(&wg->device_update_lock);
248248

249-
pr_debug("%s: Interface deleted\n", dev->name);
249+
pr_debug("%s: Interface destroyed\n", dev->name);
250250
free_netdev(dev);
251251
}
252252

@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
292292
struct wg_device *wg = netdev_priv(dev);
293293
int ret = -ENOMEM;
294294

295-
wg->creating_net = src_net;
295+
rcu_assign_pointer(wg->creating_net, src_net);
296296
init_rwsem(&wg->static_identity.lock);
297297
mutex_init(&wg->socket_update_lock);
298298
mutex_init(&wg->device_update_lock);
@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
393393
.newlink = wg_newlink,
394394
};
395395

396-
static int wg_netdevice_notification(struct notifier_block *nb,
397-
unsigned long action, void *data)
396+
static void wg_netns_pre_exit(struct net *net)
398397
{
399-
struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
400-
struct wg_device *wg = netdev_priv(dev);
401-
402-
ASSERT_RTNL();
403-
404-
if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
405-
return 0;
398+
struct wg_device *wg;
406399

407-
if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
408-
put_net(wg->creating_net);
409-
wg->have_creating_net_ref = false;
410-
} else if (dev_net(dev) != wg->creating_net &&
411-
!wg->have_creating_net_ref) {
412-
wg->have_creating_net_ref = true;
413-
get_net(wg->creating_net);
400+
rtnl_lock();
401+
list_for_each_entry(wg, &device_list, device_list) {
402+
if (rcu_access_pointer(wg->creating_net) == net) {
403+
pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
404+
netif_carrier_off(wg->dev);
405+
mutex_lock(&wg->device_update_lock);
406+
rcu_assign_pointer(wg->creating_net, NULL);
407+
wg_socket_reinit(wg, NULL, NULL);
408+
mutex_unlock(&wg->device_update_lock);
409+
}
414410
}
415-
return 0;
411+
rtnl_unlock();
416412
}
417413

418-
static struct notifier_block netdevice_notifier = {
419-
.notifier_call = wg_netdevice_notification
414+
static struct pernet_operations pernet_ops = {
415+
.pre_exit = wg_netns_pre_exit
420416
};
421417

422418
int __init wg_device_init(void)
@@ -429,18 +425,18 @@ int __init wg_device_init(void)
429425
return ret;
430426
#endif
431427

432-
ret = register_netdevice_notifier(&netdevice_notifier);
428+
ret = register_pernet_device(&pernet_ops);
433429
if (ret)
434430
goto error_pm;
435431

436432
ret = rtnl_link_register(&link_ops);
437433
if (ret)
438-
goto error_netdevice;
434+
goto error_pernet;
439435

440436
return 0;
441437

442-
error_netdevice:
443-
unregister_netdevice_notifier(&netdevice_notifier);
438+
error_pernet:
439+
unregister_pernet_device(&pernet_ops);
444440
error_pm:
445441
#ifdef CONFIG_PM_SLEEP
446442
unregister_pm_notifier(&pm_notifier);
@@ -451,7 +447,7 @@ int __init wg_device_init(void)
451447
void wg_device_uninit(void)
452448
{
453449
rtnl_link_unregister(&link_ops);
454-
unregister_netdevice_notifier(&netdevice_notifier);
450+
unregister_pernet_device(&pernet_ops);
455451
#ifdef CONFIG_PM_SLEEP
456452
unregister_pm_notifier(&pm_notifier);
457453
#endif

drivers/net/wireguard/device.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct wg_device {
4040
struct net_device *dev;
4141
struct crypt_queue encrypt_queue, decrypt_queue;
4242
struct sock __rcu *sock4, *sock6;
43-
struct net *creating_net;
43+
struct net __rcu *creating_net;
4444
struct noise_static_identity static_identity;
4545
struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
4646
struct workqueue_struct *packet_crypt_wq;
@@ -56,7 +56,6 @@ struct wg_device {
5656
unsigned int num_peers, device_update_gen;
5757
u32 fwmark;
5858
u16 incoming_port;
59-
bool have_creating_net_ref;
6059
};
6160

6261
int wg_device_init(void);

drivers/net/wireguard/netlink.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
511511
if (flags & ~__WGDEVICE_F_ALL)
512512
goto out;
513513

514-
ret = -EPERM;
515-
if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
516-
info->attrs[WGDEVICE_A_FWMARK]) &&
517-
!ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
518-
goto out;
514+
if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
515+
struct net *net;
516+
rcu_read_lock();
517+
net = rcu_dereference(wg->creating_net);
518+
ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
519+
rcu_read_unlock();
520+
if (ret)
521+
goto out;
522+
}
519523

520524
++wg->device_update_gen;
521525

drivers/net/wireguard/noise.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,8 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
617617
memcpy(handshake->hash, hash, NOISE_HASH_LEN);
618618
memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
619619
handshake->remote_index = src->sender_index;
620-
if ((s64)(handshake->last_initiation_consumption -
621-
(initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
620+
initiation_consumption = ktime_get_coarse_boottime_ns();
621+
if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
622622
handshake->last_initiation_consumption = initiation_consumption;
623623
handshake->state = HANDSHAKE_CONSUMED_INITIATION;
624624
up_write(&handshake->lock);

drivers/net/wireguard/socket.c

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)
347347

348348
int wg_socket_init(struct wg_device *wg, u16 port)
349349
{
350+
struct net *net;
350351
int ret;
351352
struct udp_tunnel_sock_cfg cfg = {
352353
.sk_user_data = wg,
@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
371372
};
372373
#endif
373374

375+
rcu_read_lock();
376+
net = rcu_dereference(wg->creating_net);
377+
net = net ? maybe_get_net(net) : NULL;
378+
rcu_read_unlock();
379+
if (unlikely(!net))
380+
return -ENONET;
381+
374382
#if IS_ENABLED(CONFIG_IPV6)
375383
retry:
376384
#endif
377385

378-
ret = udp_sock_create(wg->creating_net, &port4, &new4);
386+
ret = udp_sock_create(net, &port4, &new4);
379387
if (ret < 0) {
380388
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
381-
return ret;
389+
goto out;
382390
}
383391
set_sock_opts(new4);
384-
setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
392+
setup_udp_tunnel_sock(net, new4, &cfg);
385393

386394
#if IS_ENABLED(CONFIG_IPV6)
387395
if (ipv6_mod_enabled()) {
388396
port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
389-
ret = udp_sock_create(wg->creating_net, &port6, &new6);
397+
ret = udp_sock_create(net, &port6, &new6);
390398
if (ret < 0) {
391399
udp_tunnel_sock_release(new4);
392400
if (ret == -EADDRINUSE && !port && retries++ < 100)
393401
goto retry;
394402
pr_err("%s: Could not create IPv6 socket\n",
395403
wg->dev->name);
396-
return ret;
404+
goto out;
397405
}
398406
set_sock_opts(new6);
399-
setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
407+
setup_udp_tunnel_sock(net, new6, &cfg);
400408
}
401409
#endif
402410

403411
wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
404-
return 0;
412+
ret = 0;
413+
out:
414+
put_net(net);
415+
return ret;
405416
}
406417

407418
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,

tools/testing/selftests/wireguard/netns.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,20 @@ ip0 link set wg0 up
587587
kill $ncat_pid
588588
ip0 link del wg0
589589

590+
# Ensure there aren't circular reference loops
591+
ip1 link add wg1 type wireguard
592+
ip2 link add wg2 type wireguard
593+
ip1 link set wg1 netns $netns2
594+
ip2 link set wg2 netns $netns1
595+
pp ip netns delete $netns1
596+
pp ip netns delete $netns2
597+
pp ip netns add $netns1
598+
pp ip netns add $netns2
599+
600+
sleep 2 # Wait for cleanup and grace periods
590601
declare -A objects
591602
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
592-
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
603+
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
593604
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
594605
done < /dev/kmsg
595606
alldeleted=1

0 commit comments

Comments
 (0)