6
6
#include "allowedips.h"
7
7
#include "peer.h"
8
8
9
+ static struct kmem_cache * node_cache ;
10
+
9
11
static void swap_endian (u8 * dst , const u8 * src , u8 bits )
10
12
{
11
13
if (bits == 32 ) {
@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
28
30
node -> bitlen = bits ;
29
31
memcpy (node -> bits , src , bits / 8U );
30
32
}
31
- #define CHOOSE_NODE (parent , key ) \
32
- parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
33
+
34
+ static inline u8 choose (struct allowedips_node * node , const u8 * key )
35
+ {
36
+ return (key [node -> bit_at_a ] >> node -> bit_at_b ) & 1 ;
37
+ }
33
38
34
39
static void push_rcu (struct allowedips_node * * stack ,
35
40
struct allowedips_node __rcu * p , unsigned int * len )
@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
40
45
}
41
46
}
42
47
48
+ static void node_free_rcu (struct rcu_head * rcu )
49
+ {
50
+ kmem_cache_free (node_cache , container_of (rcu , struct allowedips_node , rcu ));
51
+ }
52
+
43
53
static void root_free_rcu (struct rcu_head * rcu )
44
54
{
45
55
struct allowedips_node * node , * stack [128 ] = {
@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
49
59
while (len > 0 && (node = stack [-- len ])) {
50
60
push_rcu (stack , node -> bit [0 ], & len );
51
61
push_rcu (stack , node -> bit [1 ], & len );
52
- kfree ( node );
62
+ kmem_cache_free ( node_cache , node );
53
63
}
54
64
}
55
65
@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
66
76
}
67
77
}
68
78
69
- static void walk_remove_by_peer (struct allowedips_node __rcu * * top ,
70
- struct wg_peer * peer , struct mutex * lock )
71
- {
72
- #define REF (p ) rcu_access_pointer(p)
73
- #define DEREF (p ) rcu_dereference_protected(*(p), lockdep_is_held(lock))
74
- #define PUSH (p ) ({ \
75
- WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
76
- stack[len++] = p; \
77
- })
78
-
79
- struct allowedips_node __rcu * * stack [128 ], * * nptr ;
80
- struct allowedips_node * node , * prev ;
81
- unsigned int len ;
82
-
83
- if (unlikely (!peer || !REF (* top )))
84
- return ;
85
-
86
- for (prev = NULL , len = 0 , PUSH (top ); len > 0 ; prev = node ) {
87
- nptr = stack [len - 1 ];
88
- node = DEREF (nptr );
89
- if (!node ) {
90
- -- len ;
91
- continue ;
92
- }
93
- if (!prev || REF (prev -> bit [0 ]) == node ||
94
- REF (prev -> bit [1 ]) == node ) {
95
- if (REF (node -> bit [0 ]))
96
- PUSH (& node -> bit [0 ]);
97
- else if (REF (node -> bit [1 ]))
98
- PUSH (& node -> bit [1 ]);
99
- } else if (REF (node -> bit [0 ]) == prev ) {
100
- if (REF (node -> bit [1 ]))
101
- PUSH (& node -> bit [1 ]);
102
- } else {
103
- if (rcu_dereference_protected (node -> peer ,
104
- lockdep_is_held (lock )) == peer ) {
105
- RCU_INIT_POINTER (node -> peer , NULL );
106
- list_del_init (& node -> peer_list );
107
- if (!node -> bit [0 ] || !node -> bit [1 ]) {
108
- rcu_assign_pointer (* nptr , DEREF (
109
- & node -> bit [!REF (node -> bit [0 ])]));
110
- kfree_rcu (node , rcu );
111
- node = DEREF (nptr );
112
- }
113
- }
114
- -- len ;
115
- }
116
- }
117
-
118
- #undef REF
119
- #undef DEREF
120
- #undef PUSH
121
- }
122
-
123
79
static unsigned int fls128 (u64 a , u64 b )
124
80
{
125
81
return a ? fls64 (a ) + 64U : fls64 (b );
@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
159
115
found = node ;
160
116
if (node -> cidr == bits )
161
117
break ;
162
- node = rcu_dereference_bh (CHOOSE_NODE (node , key ));
118
+ node = rcu_dereference_bh (node -> bit [ choose (node , key )] );
163
119
}
164
120
return found ;
165
121
}
@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
191
147
u8 cidr , u8 bits , struct allowedips_node * * rnode ,
192
148
struct mutex * lock )
193
149
{
194
- struct allowedips_node * node = rcu_dereference_protected (trie ,
195
- lockdep_is_held (lock ));
150
+ struct allowedips_node * node = rcu_dereference_protected (trie , lockdep_is_held (lock ));
196
151
struct allowedips_node * parent = NULL ;
197
152
bool exact = false;
198
153
@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
202
157
exact = true;
203
158
break ;
204
159
}
205
- node = rcu_dereference_protected (CHOOSE_NODE (parent , key ),
206
- lockdep_is_held (lock ));
160
+ node = rcu_dereference_protected (parent -> bit [choose (parent , key )], lockdep_is_held (lock ));
207
161
}
208
162
* rnode = parent ;
209
163
return exact ;
210
164
}
211
165
166
/* Publish @node into the slot @parent and record a packed back-pointer.
 *
 * @bit is 0 or 1 when @parent is a child slot of another node, or 2 when
 * @parent is a trie root pointer (callers pass 2 for the root case).  The
 * value is packed into the low bits of the slot's address and stored in
 * node->parent_bit_packed, so removal can later recover both the slot
 * (mask with ~3UL) and which child position the node occupies (bit 0).
 *
 * NOTE(review): the back-pointer is written before rcu_assign_pointer()
 * makes the node visible — do not reorder these two statements.
 */
static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
{
	node->parent_bit_packed = (unsigned long)parent | bit;
	rcu_assign_pointer(*parent, node);
}
171
+
172
+ static inline void choose_and_connect_node (struct allowedips_node * parent , struct allowedips_node * node )
173
+ {
174
+ u8 bit = choose (parent , node -> bits );
175
+ connect_node (& parent -> bit [bit ], bit , node );
176
+ }
177
+
212
178
static int add (struct allowedips_node __rcu * * trie , u8 bits , const u8 * key ,
213
179
u8 cidr , struct wg_peer * peer , struct mutex * lock )
214
180
{
@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
218
184
return - EINVAL ;
219
185
220
186
if (!rcu_access_pointer (* trie )) {
221
- node = kzalloc ( sizeof ( * node ) , GFP_KERNEL );
187
+ node = kmem_cache_zalloc ( node_cache , GFP_KERNEL );
222
188
if (unlikely (!node ))
223
189
return - ENOMEM ;
224
190
RCU_INIT_POINTER (node -> peer , peer );
225
191
list_add_tail (& node -> peer_list , & peer -> allowedips_list );
226
192
copy_and_assign_cidr (node , key , cidr , bits );
227
- rcu_assign_pointer ( * trie , node );
193
+ connect_node ( trie , 2 , node );
228
194
return 0 ;
229
195
}
230
196
if (node_placement (* trie , key , cidr , bits , & node , lock )) {
@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
233
199
return 0 ;
234
200
}
235
201
236
- newnode = kzalloc ( sizeof ( * newnode ) , GFP_KERNEL );
202
+ newnode = kmem_cache_zalloc ( node_cache , GFP_KERNEL );
237
203
if (unlikely (!newnode ))
238
204
return - ENOMEM ;
239
205
RCU_INIT_POINTER (newnode -> peer , peer );
@@ -243,41 +209,40 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
243
209
if (!node ) {
244
210
down = rcu_dereference_protected (* trie , lockdep_is_held (lock ));
245
211
} else {
246
- down = rcu_dereference_protected ( CHOOSE_NODE ( node , key ),
247
- lockdep_is_held (lock ));
212
+ const u8 bit = choose ( node , key );
213
+ down = rcu_dereference_protected ( node -> bit [ bit ], lockdep_is_held (lock ));
248
214
if (!down ) {
249
- rcu_assign_pointer ( CHOOSE_NODE ( node , key ) , newnode );
215
+ connect_node ( & node -> bit [ bit ], bit , newnode );
250
216
return 0 ;
251
217
}
252
218
}
253
219
cidr = min (cidr , common_bits (down , key , bits ));
254
220
parent = node ;
255
221
256
222
if (newnode -> cidr == cidr ) {
257
- rcu_assign_pointer ( CHOOSE_NODE ( newnode , down -> bits ) , down );
223
+ choose_and_connect_node ( newnode , down );
258
224
if (!parent )
259
- rcu_assign_pointer ( * trie , newnode );
225
+ connect_node ( trie , 2 , newnode );
260
226
else
261
- rcu_assign_pointer (CHOOSE_NODE (parent , newnode -> bits ),
262
- newnode );
263
- } else {
264
- node = kzalloc (sizeof (* node ), GFP_KERNEL );
265
- if (unlikely (!node )) {
266
- list_del (& newnode -> peer_list );
267
- kfree (newnode );
268
- return - ENOMEM ;
269
- }
270
- INIT_LIST_HEAD (& node -> peer_list );
271
- copy_and_assign_cidr (node , newnode -> bits , cidr , bits );
227
+ choose_and_connect_node (parent , newnode );
228
+ return 0 ;
229
+ }
272
230
273
- rcu_assign_pointer (CHOOSE_NODE (node , down -> bits ), down );
274
- rcu_assign_pointer (CHOOSE_NODE (node , newnode -> bits ), newnode );
275
- if (!parent )
276
- rcu_assign_pointer (* trie , node );
277
- else
278
- rcu_assign_pointer (CHOOSE_NODE (parent , node -> bits ),
279
- node );
231
+ node = kmem_cache_zalloc (node_cache , GFP_KERNEL );
232
+ if (unlikely (!node )) {
233
+ list_del (& newnode -> peer_list );
234
+ kmem_cache_free (node_cache , newnode );
235
+ return - ENOMEM ;
280
236
}
237
+ INIT_LIST_HEAD (& node -> peer_list );
238
+ copy_and_assign_cidr (node , newnode -> bits , cidr , bits );
239
+
240
+ choose_and_connect_node (node , down );
241
+ choose_and_connect_node (node , newnode );
242
+ if (!parent )
243
+ connect_node (trie , 2 , node );
244
+ else
245
+ choose_and_connect_node (parent , node );
281
246
return 0 ;
282
247
}
283
248
@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
335
300
/* Remove every allowedips node owned by @peer from @table, walking the
 * peer's own list instead of the whole trie.  Caller must hold @lock;
 * nodes are freed via call_rcu() so concurrent RCU readers may keep
 * traversing them until a grace period elapses.
 */
void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
	bool free_parent;

	if (list_empty(&peer->allowedips_list))
		return;
	++table->seq;
	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
		/* Detach the node from the peer before editing the trie. */
		list_del_init(&node->peer_list);
		RCU_INIT_POINTER(node->peer, NULL);
		/* A node with both children stays as an interior node. */
		if (node->bit[0] && node->bit[1])
			continue;
		/* At most one child: splice it into this node's slot. */
		child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
						  lockdep_is_held(lock));
		if (child)
			child->parent_bit_packed = node->parent_bit_packed;
		/* parent_bit_packed: slot address in the high bits, slot
		 * index (bit 0) and root flag (value 2) in the low two. */
		parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
		*parent_bit = child;
		parent = (void *)parent_bit -
			 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
		/* The parent can be collapsed as well when this node is now
		 * childless, the parent is a real node rather than the root
		 * slot ((packed & 3) <= 1), and it carries no peer itself. */
		free_parent = !rcu_access_pointer(node->bit[0]) &&
			      !rcu_access_pointer(node->bit[1]) &&
			      (node->parent_bit_packed & 3) <= 1 &&
			      !rcu_access_pointer(parent->peer);
		if (free_parent)
			child = rcu_dereference_protected(
					parent->bit[!(node->parent_bit_packed & 1)],
					lockdep_is_held(lock));
		call_rcu(&node->rcu, node_free_rcu);
		if (!free_parent)
			continue;
		/* Splice the parent's surviving child into the grandparent. */
		if (child)
			child->parent_bit_packed = parent->parent_bit_packed;
		*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
		call_rcu(&parent->rcu, node_free_rcu);
	}
}
342
339
343
340
int wg_allowedips_read_node (struct allowedips_node * node , u8 ip [16 ], u8 * cidr )
@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
374
371
return NULL ;
375
372
}
376
373
374
+ int __init wg_allowedips_slab_init (void )
375
+ {
376
+ node_cache = KMEM_CACHE (allowedips_node , 0 );
377
+ return node_cache ? 0 : - ENOMEM ;
378
+ }
379
+
380
/* Destroy the node slab cache.  rcu_barrier() first waits for every
 * pending node_free_rcu() callback, so no object is still in flight
 * when kmem_cache_destroy() runs — do not reorder these calls.
 */
void wg_allowedips_slab_uninit(void)
{
	rcu_barrier();
	kmem_cache_destroy(node_cache);
}
385
+
377
386
#include "selftest/allowedips.c"
0 commit comments