|
| 1 | +--- |
| 2 | +layout: post |
| 3 | +title: "More collision-resistant hash tables" |
| 4 | +date: 2024-10-14 13:20:44 +0000 |
| 5 | +categories: jekyll update |
| 6 | +katex: true |
| 7 | +--- |
| 8 | + |
| 9 | +A sequel to the [previous post](https://somehybrid.github.io/jekyll/update/2024/10/14/hash-collisions.html). If you haven't read it, read it now. |
| 10 | + |
| 11 | +The previous post covered integer collisions in CPython's hash tables - notably that the hash algorithm for integers in CPython was done by taking the integer modulo a constant $P$, which was defined as the Mersenne prime $2^{61} - 1$. The main issue regarding this is DoS attacks on hash tables - namely that if an attacker that could add elements to a hash table is able to inexpensively manufacture hash collisions, they could fill it with items with the same hash and slow it down a lot. In the previous post, this could be done just by taking numbers congruent modulo $P$. How would you be able to combat such attacks? |
| 12 | + |
| 13 | +## A first attempt at a better hash table |
| 14 | +One way (and the most intuitive to me) to prevent hash collisions is instead of taking integers modulo $P$, you could bitwise XOR the ints with a secret that such an attacker wouldn't know. |
| 15 | + |
| 16 | +An implementation is given below: |
| 17 | +{% highlight python %} |
| 18 | +import secrets |
| 19 | + |
| 20 | +class XorDict(dict): |
| 21 | + def __init__(self, *args, **kwargs): |
| 22 | + self.secret = secrets.randbits(64) |
| 23 | + super().__init__(*args, **kwargs) |
| 24 | + |
| 25 | + def __getitem__(self, key: Any): |
| 26 | + if isinstance(key, int): |
| 27 | + key ^= self.secret |
| 28 | + |
| 29 | + try: |
| 30 | + return super().__getitem__(key) |
| 31 | + except KeyError: |
| 32 | + raise KeyError(key) |
| 33 | + |
| 34 | + def __getitem__(self, key: Any, value: Any): |
| 35 | + if isinstance(key, int): |
| 36 | + key ^= self.secret |
| 37 | + |
| 38 | + return super().__setitem__(key) |
| 39 | +{% endhighlight %} |
| 40 | + |
| 41 | +### Constructing collisions |
| 42 | +We can see this uses CPython's integer hash again, meaning you could arrange this as $(x \oplus s) \mod {P}$, with $x$ and $s$ as the input and secret respectively, and $P$ being the modulus defined above. From there, we can reuse the collision with congruence, by doing something like: |
| 43 | +{% highlight python %} |
| 44 | +>>> def badhash(x: int): |
| 45 | +... return (x ^ 12) % 15 # selecting values of s and p as 12 and 15 as a simplified example |
| 46 | +... |
| 47 | +>>> badhash(0b10) |
| 48 | +14 |
| 49 | +>>> 15 << 4 + 2 |
| 50 | +242 |
| 51 | +>>> badhash(242) |
| 52 | +14 |
| 53 | +{% endhighlight %} |
| 54 | + |
| 55 | +This works in essentially the same way as the other collision with more steps. You can just take the modulus multiplied by the smallest power of 2 larger than the modulus and the secret, which boils down to $[qP + ((x \mod P) \oplus s)] \equiv [(x \mod P) \oplus s] \pmod P$. The attacker just needs to know the number of bits in $P$ and $s$, which would only take a maximum of $\lceil \log_2(max(P, s)) \rceil$ hits against the hash table if not already known. |
| 56 | + |
| 57 | +## A second attempt |
| 58 | +As mentioned in the previous post, a good hashing algorithm for hash tables in CPython already exists - [SipHash](https://en.wikipedia.org/wiki/SipHash). One way you could construct a hash table is by converting the integers into bytes objects and then hashing them. |
| 59 | + |
| 60 | +In code: |
| 61 | +{% highlight python %} |
| 62 | +from collections.abc import Iterable, Sequence |
| 63 | +from typing import override, overload |
| 64 | + |
| 65 | +class SipDict[K, V](dict[K | int, V]): # pyright: ignore[reportMissingTypeArgument] |
| 66 | + def _calculate_index(self, key: int): |
| 67 | + return hash(key.to_bytes((key.bit_length() + 7) // 8)) |
| 68 | + |
| 69 | + @overload |
| 70 | + def __setitem__(self, key: int, value: V) -> None: ... |
| 71 | + |
| 72 | + @overload |
| 73 | + def __setitem__(self, key: K, value: V) -> None: ... |
| 74 | + |
| 75 | + @override |
| 76 | + def __setitem__(self, key: K | int, value: V) -> None: |
| 77 | + index = key |
| 78 | + if isinstance(key, int): |
| 79 | + index = self._calculate_index(key) |
| 80 | + |
| 81 | + return super().__setitem__(index, value) |
| 82 | + |
| 83 | + @overload |
| 84 | + def __getitem__(self, key: int) -> V: ... |
| 85 | + |
| 86 | + @overload |
| 87 | + def __getitem__(self, key: K) -> V: ... |
| 88 | + |
| 89 | + @override |
| 90 | + def __getitem__(self, key: K | int) -> V: |
| 91 | + index = key |
| 92 | + if isinstance(key, int): |
| 93 | + index = self._calculate_index(key) |
| 94 | + |
| 95 | + try: |
| 96 | + return super().__getitem__(index) |
| 97 | + except KeyError: |
| 98 | + raise KeyError(key) |
| 99 | + |
| 100 | + @classmethod |
| 101 | + def from_pairs(cls, data: Iterable[tuple[K, V]]) -> SipDict[K, V]: |
| 102 | + out: SipDict[K, V] = SipDict() |
| 103 | + for item in data: |
| 104 | + out[item[0]] = item[1] |
| 105 | + return out |
| 106 | +{% endhighlight %} |
| 107 | + |
| 108 | +The attacker could trivially find 2 keys producing the same hash by noting that the algorithm finds the hash of the byte representation of an integer. If the dictionary wasn't constrained specifically to integers, the attacker could get the byte representation and create 2 collisions. However, I do not give 2 ~~shits~~ hash collisions, as it wouldn't impact performance much. |
| 109 | + |
| 110 | +## Benchmarks |
| 111 | + |
| 112 | +Functions used to benchmark: |
| 113 | +{% highlight python %} |
| 114 | +def compare_initialization(data: list[tuple[int, int]]): |
| 115 | + start = time.perf_counter() |
| 116 | + _ = SipDict.from_pairs(data) |
| 117 | + end = time.perf_counter() |
| 118 | + print("improved hash table:", end - start) |
| 119 | + |
| 120 | + start = time.perf_counter() |
| 121 | + _ = dict(data) |
| 122 | + end = time.perf_counter() |
| 123 | + print("normal hash table:", end - start) |
| 124 | + |
| 125 | +def compare_retrieval(d1: SipDict[int, int], d2: dict, item: int): |
| 126 | + start = time.perf_counter() |
| 127 | + _ = d1[item] |
| 128 | + end = time.perf_counter() |
| 129 | + print("improved hash table:", end - start) |
| 130 | + |
| 131 | + start = time.perf_counter() |
| 132 | + _ = d2[item] |
| 133 | + end = time.perf_counter() |
| 134 | + print("normal hash table:", end - start) |
| 135 | +{% endhighlight %} |
| 136 | + |
| 137 | +### Constructing hash tables |
| 138 | +Constructing hash table with collisions: |
| 139 | + |
| 140 | +Benchmark code: |
| 141 | +{% highlight python %} |
| 142 | +p = 2**61 - 1 |
| 143 | +data = [(i * p + 1, i) for i in range(1000)] |
| 144 | + |
| 145 | +compare_initialization(data) |
| 146 | +{% endhighlight %} |
| 147 | + |
| 148 | +Results: |
| 149 | +{% highlight none %} |
| 150 | +improved hash table: 0.0006913679972058162 |
| 151 | +normal hash table: 0.015330620997701772 |
| 152 | +{% endhighlight %} |
| 153 | + |
| 154 | +Constructing hash table without collisions: |
| 155 | + |
| 156 | +{% highlight python %} |
| 157 | +p = 2**61 - 1 |
| 158 | +data = [(random.randint(1, 2 ** 64), i) for i in range(1000)] |
| 159 | + |
| 160 | +compare_initialization(data) |
| 161 | +{% endhighlight %} |
| 162 | + |
| 163 | +Results: |
| 164 | +{% highlight none %} |
| 165 | +improved hash table: 0.0008386349945794791 |
| 166 | +normal hash table: 0.00017141499847639352 |
| 167 | +{% endhighlight %} |
| 168 | + |
| 169 | +Note that the new hash table is many times faster than the normal one when initializing with lots of collisions, however is slower than the normal one without collisions. One thing to note is that both hashes are applied, CPython's modular reduction and SipHash, making it slow down more. Another reason is that the interpreter has to repeatedly add more memory, whilst CPython creates a presized dict using `dict_new_presized`. |
| 170 | + |
| 171 | +### Retrieving from hash tables |
| 172 | +Retrieval with collisions: |
| 173 | + |
| 174 | +Benchmark code: |
| 175 | +{% highlight python %} |
| 176 | +data = [(p * i + 1, i) for i in range(1000)] |
| 177 | +d1 = SipDict.from_pairs(data) |
| 178 | +d2 = dict(data) |
| 179 | + |
| 180 | +compare_retrieval(d1, d2, 874 * p + 1) |
| 181 | +{% endhighlight %} |
| 182 | + |
| 183 | +Results: |
| 184 | +{% highlight none %} |
| 185 | +improved hash table: 7.757989806123078e-06 |
| 186 | +normal hash table: 1.6758000128902495e-05 |
| 187 | +{% endhighlight %} |
| 188 | + |
| 189 | +Retrieval without collisions: |
| 190 | +{% highlight python %} |
| 191 | +data = [(i * (2**53), i) for i in range(1000)] |
| 192 | +d1 = SipDict.from_pairs(data) |
| 193 | +d2 = dict(data) |
| 194 | + |
| 195 | +compare_retrieval(d1, d2, 874 * (2**53)) |
| 196 | +{% endhighlight %} |
| 197 | + |
| 198 | +Results: |
| 199 | +{% highlight python %} |
| 200 | +improved hash table: 4.636996891349554e-06 |
| 201 | +normal hash table: 7.43995769880712e-07 |
| 202 | +{% endhighlight %} |
| 203 | + |
| 204 | +## Conclusion |
| 205 | +Is the use-case contrived? Possibly. May it actually happen? Also possibly. The above implementation of a slightly (?) (citation needed) better hash table is much slower for initialization (although its probably because of the memory allocation costs) and retrieval (this could probably be fixed if I edited the CPython source code, but honestly, i don't care enough). Like before, if this for some reason is a concern at all, go ask a professional. |
| 206 | + |
| 207 | +## Disclaimer |
| 208 | +The info in this blog post has tried to be accurate, yet there may be some issues and mistakes. If you find any, mention me on Mastodon at somehybrid@hachyderm.io and I might have enough motivation to actually fix any of it. |
0 commit comments