Skip to content

Commit a9e90d9

Browse files
zx2c4 authored and davem330 committed
wireguard: noise: separate receive counter from send counter
In "wireguard: queueing: preserve flow hash across packet scrubbing", we were required to slightly increase the size of the receive replay counter to something still fairly small, but an increase nonetheless. It turns out that we can recoup some of the additional memory overhead by splitting up the prior union type into two distinct types. Before, we used the same "noise_counter" union for both sending and receiving, with sending just using a simple atomic64_t, while receiving used the full replay counter checker. This meant that most of the memory being allocated for the sending counter was being wasted. Since the old "noise_counter" type increased in size in the prior commit, now is a good time to split up that union type into a distinct "noise_replay_counter" for receiving and a boring atomic64_t for sending, each using neither more nor less memory than required.

Also, since sometimes the replay counter is accessed without necessitating additional accesses to the bitmap, we can reduce cache misses by hoisting the always-necessary lock above the bitmap in the struct layout. We also change a "noise_replay_counter" stack allocation to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack frame warning.

All in all, removing a bit of abstraction in this commit makes the code simpler and smaller, in addition to the motivating memory usage recuperation. For example, passing around raw "noise_symmetric_key" structs is something that really only makes sense within noise.c, in the one place where the sending and receiving keys can safely be thought of as the same type of object; subsequent to that, it's important that we uniformly access these through keypair->{sending,receiving}, where their distinct roles are always made explicit. So this patch allows us to draw that distinction clearly as well.

Fixes: e7096c1 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <[email protected]>
Signed-off-by: David S. Miller <[email protected]>
1 parent c78a0b4 commit a9e90d9

File tree

5 files changed

+48
-53
lines changed

5 files changed

+48
-53
lines changed

drivers/net/wireguard/noise.c

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer)
104104

105105
if (unlikely(!keypair))
106106
return NULL;
107+
spin_lock_init(&keypair->receiving_counter.lock);
107108
keypair->internal_id = atomic64_inc_return(&keypair_counter);
108109
keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
109110
keypair->entry.peer = peer;
@@ -358,25 +359,16 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
358359
memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
359360
}
360361

361-
static void symmetric_key_init(struct noise_symmetric_key *key)
362-
{
363-
spin_lock_init(&key->counter.receive.lock);
364-
atomic64_set(&key->counter.counter, 0);
365-
memset(key->counter.receive.backtrack, 0,
366-
sizeof(key->counter.receive.backtrack));
367-
key->birthdate = ktime_get_coarse_boottime_ns();
368-
key->is_valid = true;
369-
}
370-
371362
static void derive_keys(struct noise_symmetric_key *first_dst,
372363
struct noise_symmetric_key *second_dst,
373364
const u8 chaining_key[NOISE_HASH_LEN])
374365
{
366+
u64 birthdate = ktime_get_coarse_boottime_ns();
375367
kdf(first_dst->key, second_dst->key, NULL, NULL,
376368
NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
377369
chaining_key);
378-
symmetric_key_init(first_dst);
379-
symmetric_key_init(second_dst);
370+
first_dst->birthdate = second_dst->birthdate = birthdate;
371+
first_dst->is_valid = second_dst->is_valid = true;
380372
}
381373

382374
static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],

drivers/net/wireguard/noise.h

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,26 +15,24 @@
1515
#include <linux/mutex.h>
1616
#include <linux/kref.h>
1717

18-
union noise_counter {
19-
struct {
20-
u64 counter;
21-
unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
22-
spinlock_t lock;
23-
} receive;
24-
atomic64_t counter;
18+
struct noise_replay_counter {
19+
u64 counter;
20+
spinlock_t lock;
21+
unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
2522
};
2623

2724
struct noise_symmetric_key {
2825
u8 key[NOISE_SYMMETRIC_KEY_LEN];
29-
union noise_counter counter;
3026
u64 birthdate;
3127
bool is_valid;
3228
};
3329

3430
struct noise_keypair {
3531
struct index_hashtable_entry entry;
3632
struct noise_symmetric_key sending;
33+
atomic64_t sending_counter;
3734
struct noise_symmetric_key receiving;
35+
struct noise_replay_counter receiving_counter;
3836
__le32 remote_index;
3937
bool i_am_the_initiator;
4038
struct kref refcount;

drivers/net/wireguard/receive.c

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_peer *peer)
245245
}
246246
}
247247

248-
static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
248+
static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
249249
{
250250
struct scatterlist sg[MAX_SKB_FRAGS + 8];
251251
struct sk_buff *trailer;
252252
unsigned int offset;
253253
int num_frags;
254254

255-
if (unlikely(!key))
255+
if (unlikely(!keypair))
256256
return false;
257257

258-
if (unlikely(!READ_ONCE(key->is_valid) ||
259-
wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
260-
key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
261-
WRITE_ONCE(key->is_valid, false);
258+
if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
259+
wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
260+
keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
261+
WRITE_ONCE(keypair->receiving.is_valid, false);
262262
return false;
263263
}
264264

@@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
283283

284284
if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
285285
PACKET_CB(skb)->nonce,
286-
key->key))
286+
keypair->receiving.key))
287287
return false;
288288

289289
/* Another ugly situation of pushing and pulling the header so as to
@@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
298298
}
299299

300300
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
301-
static bool counter_validate(union noise_counter *counter, u64 their_counter)
301+
static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
302302
{
303303
unsigned long index, index_current, top, i;
304304
bool ret = false;
305305

306-
spin_lock_bh(&counter->receive.lock);
306+
spin_lock_bh(&counter->lock);
307307

308-
if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
308+
if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
309309
their_counter >= REJECT_AFTER_MESSAGES))
310310
goto out;
311311

312312
++their_counter;
313313

314314
if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
315-
counter->receive.counter))
315+
counter->counter))
316316
goto out;
317317

318318
index = their_counter >> ilog2(BITS_PER_LONG);
319319

320-
if (likely(their_counter > counter->receive.counter)) {
321-
index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
320+
if (likely(their_counter > counter->counter)) {
321+
index_current = counter->counter >> ilog2(BITS_PER_LONG);
322322
top = min_t(unsigned long, index - index_current,
323323
COUNTER_BITS_TOTAL / BITS_PER_LONG);
324324
for (i = 1; i <= top; ++i)
325-
counter->receive.backtrack[(i + index_current) &
325+
counter->backtrack[(i + index_current) &
326326
((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
327-
counter->receive.counter = their_counter;
327+
counter->counter = their_counter;
328328
}
329329

330330
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
331331
ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
332-
&counter->receive.backtrack[index]);
332+
&counter->backtrack[index]);
333333

334334
out:
335-
spin_unlock_bh(&counter->receive.lock);
335+
spin_unlock_bh(&counter->lock);
336336
return ret;
337337
}
338338

@@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
472472
if (unlikely(state != PACKET_STATE_CRYPTED))
473473
goto next;
474474

475-
if (unlikely(!counter_validate(&keypair->receiving.counter,
475+
if (unlikely(!counter_validate(&keypair->receiving_counter,
476476
PACKET_CB(skb)->nonce))) {
477477
net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
478478
peer->device->dev->name,
479479
PACKET_CB(skb)->nonce,
480-
keypair->receiving.counter.receive.counter);
480+
keypair->receiving_counter.counter);
481481
goto next;
482482
}
483483

@@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct work_struct *work)
511511
struct sk_buff *skb;
512512

513513
while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
514-
enum packet_state state = likely(decrypt_packet(skb,
515-
&PACKET_CB(skb)->keypair->receiving)) ?
514+
enum packet_state state =
515+
likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
516516
PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
517517
wg_queue_enqueue_per_peer_napi(skb, state);
518518
if (need_resched())

drivers/net/wireguard/selftest/counter.c

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,18 +6,24 @@
66
#ifdef DEBUG
77
bool __init wg_packet_counter_selftest(void)
88
{
9+
struct noise_replay_counter *counter;
910
unsigned int test_num = 0, i;
10-
union noise_counter counter;
1111
bool success = true;
1212

13-
#define T_INIT do { \
14-
memset(&counter, 0, sizeof(union noise_counter)); \
15-
spin_lock_init(&counter.receive.lock); \
13+
counter = kmalloc(sizeof(*counter), GFP_KERNEL);
14+
if (unlikely(!counter)) {
15+
pr_err("nonce counter self-test malloc: FAIL\n");
16+
return false;
17+
}
18+
19+
#define T_INIT do { \
20+
memset(counter, 0, sizeof(*counter)); \
21+
spin_lock_init(&counter->lock); \
1622
} while (0)
1723
#define T_LIM (COUNTER_WINDOW_SIZE + 1)
1824
#define T(n, v) do { \
1925
++test_num; \
20-
if (counter_validate(&counter, n) != (v)) { \
26+
if (counter_validate(counter, n) != (v)) { \
2127
pr_err("nonce counter self-test %u: FAIL\n", \
2228
test_num); \
2329
success = false; \
@@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(void)
99105

100106
if (success)
101107
pr_info("nonce counter self-tests: pass\n");
108+
kfree(counter);
102109
return success;
103110
}
104111
#endif

drivers/net/wireguard/send.c

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_peer *peer)
129129
rcu_read_lock_bh();
130130
keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
131131
send = keypair && READ_ONCE(keypair->sending.is_valid) &&
132-
(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
132+
(atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
133133
(keypair->i_am_the_initiator &&
134134
wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
135135
rcu_read_unlock_bh();
@@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(struct wg_peer *peer)
349349

350350
void wg_packet_send_staged_packets(struct wg_peer *peer)
351351
{
352-
struct noise_symmetric_key *key;
353352
struct noise_keypair *keypair;
354353
struct sk_buff_head packets;
355354
struct sk_buff *skb;
@@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
369368
rcu_read_unlock_bh();
370369
if (unlikely(!keypair))
371370
goto out_nokey;
372-
key = &keypair->sending;
373-
if (unlikely(!READ_ONCE(key->is_valid)))
371+
if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
374372
goto out_nokey;
375-
if (unlikely(wg_birthdate_has_expired(key->birthdate,
373+
if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
376374
REJECT_AFTER_TIME)))
377375
goto out_invalid;
378376

@@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
387385
*/
388386
PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
389387
PACKET_CB(skb)->nonce =
390-
atomic64_inc_return(&key->counter.counter) - 1;
388+
atomic64_inc_return(&keypair->sending_counter) - 1;
391389
if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
392390
goto out_invalid;
393391
}
@@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
399397
return;
400398

401399
out_invalid:
402-
WRITE_ONCE(key->is_valid, false);
400+
WRITE_ONCE(keypair->sending.is_valid, false);
403401
out_nokey:
404402
wg_noise_keypair_put(keypair, false);
405403

0 commit comments

Comments
 (0)