1818#include "mono_time.h"
1919#include "network.h"
2020#include "ping.h"
21+ #include "shared_key_cache.h"
2122#include "state.h"
2223#include "util.h"
2324
4647// TODO(sudden6): find out why we need multiple callbacks and if we really need 32
4748#define DHT_FRIEND_MAX_LOCKS 32
4849
50+ /* Settings for the shared key cache */
51+ #define MAX_KEYS_PER_SLOT 4
52+ #define KEYS_TIMEOUT 600
53+
4954typedef struct DHT_Friend_Callback {
5055 dht_ip_cb * ip_callback ;
5156 void * data ;
@@ -107,8 +112,8 @@ struct DHT {
107112 uint32_t loaded_num_nodes ;
108113 unsigned int loaded_nodes_index ;
109114
110- Shared_Keys shared_keys_recv ;
111- Shared_Keys shared_keys_sent ;
115+ Shared_Key_Cache * shared_keys_recv ;
116+ Shared_Key_Cache * shared_keys_sent ;
112117
113118 struct Ping * ping ;
114119 Ping_Array * dht_ping_array ;
@@ -255,74 +260,22 @@ unsigned int bit_by_bit_cmp(const uint8_t *pk1, const uint8_t *pk2)
255260 return i * 8 + j ;
256261}
257262
258- /**
259- * Shared key generations are costly, it is therefore smart to store commonly used
260- * ones so that they can be re-used later without being computed again.
261- *
262- * If a shared key is already in shared_keys, copy it to shared_key.
263- * Otherwise generate it into shared_key and copy it to shared_keys
264- */
265- void get_shared_key (const Mono_Time * mono_time , Shared_Keys * shared_keys , uint8_t * shared_key ,
266- const uint8_t * secret_key , const uint8_t * public_key )
267- {
268- uint32_t num = -1 ;
269- uint32_t curr = 0 ;
270-
271- for (uint32_t i = 0 ; i < MAX_KEYS_PER_SLOT ; ++ i ) {
272- const int index = public_key [30 ] * MAX_KEYS_PER_SLOT + i ;
273- Shared_Key * const key = & shared_keys -> keys [index ];
274-
275- if (key -> stored ) {
276- if (pk_equal (public_key , key -> public_key )) {
277- memcpy (shared_key , key -> shared_key , CRYPTO_SHARED_KEY_SIZE );
278- ++ key -> times_requested ;
279- key -> time_last_requested = mono_time_get (mono_time );
280- return ;
281- }
282-
283- if (num != 0 ) {
284- if (mono_time_is_timeout (mono_time , key -> time_last_requested , KEYS_TIMEOUT )) {
285- num = 0 ;
286- curr = index ;
287- } else if (num > key -> times_requested ) {
288- num = key -> times_requested ;
289- curr = index ;
290- }
291- }
292- } else if (num != 0 ) {
293- num = 0 ;
294- curr = index ;
295- }
296- }
297-
298- encrypt_precompute (public_key , secret_key , shared_key );
299-
300- if (num != UINT32_MAX ) {
301- Shared_Key * const key = & shared_keys -> keys [curr ];
302- key -> stored = true;
303- key -> times_requested = 1 ;
304- memcpy (key -> public_key , public_key , CRYPTO_PUBLIC_KEY_SIZE );
305- memcpy (key -> shared_key , shared_key , CRYPTO_SHARED_KEY_SIZE );
306- key -> time_last_requested = mono_time_get (mono_time );
307- }
308- }
309-
310263/**
311264 * Copy shared_key to encrypt/decrypt DHT packet from public_key into shared_key
312265 * for packets that we receive.
313266 */
314- void dht_get_shared_key_recv (DHT * dht , uint8_t * shared_key , const uint8_t * public_key )
267+ const uint8_t * dht_get_shared_key_recv (DHT * dht , const uint8_t * public_key )
315268{
316- get_shared_key (dht -> mono_time , & dht -> shared_keys_recv , shared_key , dht -> self_secret_key , public_key );
269+ return shared_key_cache_lookup (dht -> shared_keys_recv , public_key );
317270}
318271
319272/**
320273 * Copy shared_key to encrypt/decrypt DHT packet from public_key into shared_key
321274 * for packets that we send.
322275 */
323- void dht_get_shared_key_sent (DHT * dht , uint8_t * shared_key , const uint8_t * public_key )
276+ const uint8_t * dht_get_shared_key_sent (DHT * dht , const uint8_t * public_key )
324277{
325- get_shared_key (dht -> mono_time , & dht -> shared_keys_sent , shared_key , dht -> self_secret_key , public_key );
278+ return shared_key_cache_lookup (dht -> shared_keys_sent , public_key );
326279}
327280
328281#define CRYPTO_SIZE (1 + CRYPTO_PUBLIC_KEY_SIZE * 2 + CRYPTO_NONCE_SIZE)
@@ -1063,8 +1016,7 @@ static bool send_announce_ping(DHT *dht, const uint8_t *public_key, const IP_Por
10631016 public_key , CRYPTO_PUBLIC_KEY_SIZE );
10641017 memcpy (plain + CRYPTO_PUBLIC_KEY_SIZE , & ping_id , sizeof (ping_id ));
10651018
1066- uint8_t shared_key [CRYPTO_SHARED_KEY_SIZE ];
1067- dht_get_shared_key_sent (dht , shared_key , public_key );
1019+ const uint8_t * shared_key = dht_get_shared_key_sent (dht , public_key );
10681020
10691021 uint8_t request [1 + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE + sizeof (plain ) + CRYPTO_MAC_SIZE ];
10701022
@@ -1092,8 +1044,7 @@ static int handle_data_search_response(void *object, const IP_Port *source,
10921044
10931045 VLA (uint8_t , plain , plain_len );
10941046 const uint8_t * public_key = packet + 1 ;
1095- uint8_t shared_key [CRYPTO_SHARED_KEY_SIZE ];
1096- dht_get_shared_key_recv (dht , shared_key , public_key );
1047+ const uint8_t * shared_key = dht_get_shared_key_recv (dht , public_key );
10971048
10981049 if (decrypt_data_symmetric (shared_key ,
10991050 packet + 1 + CRYPTO_PUBLIC_KEY_SIZE ,
@@ -1522,15 +1473,12 @@ bool dht_getnodes(DHT *dht, const IP_Port *ip_port, const uint8_t *public_key, c
15221473 memcpy (plain , client_id , CRYPTO_PUBLIC_KEY_SIZE );
15231474 memcpy (plain + CRYPTO_PUBLIC_KEY_SIZE , & ping_id , sizeof (ping_id ));
15241475
1525- uint8_t shared_key [CRYPTO_SHARED_KEY_SIZE ];
1526- dht_get_shared_key_sent (dht , shared_key , public_key );
1476+ const uint8_t * shared_key = dht_get_shared_key_sent (dht , public_key );
15271477
15281478 const int len = dht_create_packet (dht -> rng ,
15291479 dht -> self_public_key , shared_key , NET_PACKET_GET_NODES ,
15301480 plain , sizeof (plain ), data , sizeof (data ));
15311481
1532- crypto_memzero (shared_key , sizeof (shared_key ));
1533-
15341482 if (len != sizeof (data )) {
15351483 LOGGER_ERROR (dht -> log , "getnodes packet encryption failed" );
15361484 return false;
@@ -1605,9 +1553,7 @@ static int handle_getnodes(void *object, const IP_Port *source, const uint8_t *p
16051553 }
16061554
16071555 uint8_t plain [CRYPTO_NODE_SIZE ];
1608- uint8_t shared_key [CRYPTO_SHARED_KEY_SIZE ];
1609-
1610- dht_get_shared_key_recv (dht , shared_key , packet + 1 );
1556+ const uint8_t * shared_key = dht_get_shared_key_recv (dht , packet + 1 );
16111557 const int len = decrypt_data_symmetric (
16121558 shared_key ,
16131559 packet + 1 + CRYPTO_PUBLIC_KEY_SIZE ,
@@ -1616,16 +1562,13 @@ static int handle_getnodes(void *object, const IP_Port *source, const uint8_t *p
16161562 plain );
16171563
16181564 if (len != CRYPTO_NODE_SIZE ) {
1619- crypto_memzero (shared_key , sizeof (shared_key ));
16201565 return 1 ;
16211566 }
16221567
16231568 sendnodes_ipv6 (dht , source , packet + 1 , plain , plain + CRYPTO_PUBLIC_KEY_SIZE , sizeof (uint64_t ), shared_key );
16241569
16251570 ping_add (dht -> ping , packet + 1 , source );
16261571
1627- crypto_memzero (shared_key , sizeof (shared_key ));
1628-
16291572 return 0 ;
16301573}
16311574
@@ -1670,17 +1613,14 @@ static bool handle_sendnodes_core(void *object, const IP_Port *source, const uin
16701613 }
16711614
16721615 VLA (uint8_t , plain , 1 + data_size + sizeof (uint64_t ));
1673- uint8_t shared_key [CRYPTO_SHARED_KEY_SIZE ];
1674- dht_get_shared_key_sent (dht , shared_key , packet + 1 );
1616+ const uint8_t * shared_key = dht_get_shared_key_sent (dht , packet + 1 );
16751617 const int len = decrypt_data_symmetric (
16761618 shared_key ,
16771619 packet + 1 + CRYPTO_PUBLIC_KEY_SIZE ,
16781620 packet + 1 + CRYPTO_PUBLIC_KEY_SIZE + CRYPTO_NONCE_SIZE ,
16791621 1 + data_size + sizeof (uint64_t ) + CRYPTO_MAC_SIZE ,
16801622 plain );
16811623
1682- crypto_memzero (shared_key , sizeof (shared_key ));
1683-
16841624 if ((unsigned int )len != SIZEOF_VLA (plain )) {
16851625 return false;
16861626 }
@@ -2796,6 +2736,15 @@ DHT *new_dht(const Logger *log, const Random *rng, const Network *ns, Mono_Time
27962736
27972737 crypto_new_keypair (rng , dht -> self_public_key , dht -> self_secret_key );
27982738
2739+ dht -> shared_keys_recv = shared_key_cache_new (mono_time , dht -> self_secret_key , KEYS_TIMEOUT , MAX_KEYS_PER_SLOT );
2740+ dht -> shared_keys_sent = shared_key_cache_new (mono_time , dht -> self_secret_key , KEYS_TIMEOUT , MAX_KEYS_PER_SLOT );
2741+
2742+ if (dht -> shared_keys_recv == nullptr || dht -> shared_keys_sent == nullptr ) {
2743+ kill_dht (dht );
2744+ return nullptr ;
2745+ }
2746+
2747+
27992748 dht -> dht_ping_array = ping_array_new (DHT_PING_ARRAY_SIZE , PING_TIMEOUT );
28002749
28012750 if (dht -> dht_ping_array == nullptr ) {
@@ -2858,12 +2807,12 @@ void kill_dht(DHT *dht)
28582807 networking_registerhandler (dht -> net , NET_PACKET_LAN_DISCOVERY , nullptr , nullptr );
28592808 cryptopacket_registerhandler (dht , CRYPTO_PACKET_NAT_PING , nullptr , nullptr );
28602809
2810+ shared_key_cache_free (dht -> shared_keys_recv );
2811+ shared_key_cache_free (dht -> shared_keys_sent );
28612812 ping_array_kill (dht -> dht_ping_array );
28622813 ping_kill (dht -> ping );
28632814 free (dht -> friends_list );
28642815 free (dht -> loaded_nodes_list );
2865- crypto_memzero (& dht -> shared_keys_recv , sizeof (dht -> shared_keys_recv ));
2866- crypto_memzero (& dht -> shared_keys_sent , sizeof (dht -> shared_keys_sent ));
28672816 crypto_memzero (dht -> self_secret_key , sizeof (dht -> self_secret_key ));
28682817 free (dht );
28692818}
0 commit comments