Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 5ba478e

Browse files
committed
Fix insert/probe for oa tables
1 parent a673a64 commit 5ba478e

File tree

2 files changed

+123
-21
lines changed

2 files changed

+123
-21
lines changed

src/include/codegen/util/oa_hash_table.h

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,26 @@ class OAHashTable {
106106
static void Destroy(OAHashTable &table);
107107

108108
/**
109-
* Insert a key-value pair into the hash-table.
109+
* Insert a key-value pair into the hash-table. Mostly used for testing.
110110
*
111-
* @tparam Key The datatype of the key
112-
* @tparam Value The datatype of the value
113111
* @param hash The hash value of the key
114112
* @param key The key to store in the table
115113
* @param value The value to store in the value
116114
*/
117115
template <typename Key, typename Value>
118-
void Insert(uint64_t hash, Key &key, Value &value);
116+
void Insert(uint64_t hash, const Key &key, const Value &value);
117+
118+
/**
119+
* Probe a key in the hash table. Doesn't deal with duplicate values. Mostly
120+
* for testing.
121+
*
122+
* @param hash
123+
* @param key
124+
* @param value
125+
* @return
126+
*/
127+
template <typename Key, typename Value>
128+
bool Probe(uint64_t hash, const Key &key, Value &value);
119129

120130
/**
121131
* Make room in the hash-table to store a new key-value pair. The provided
@@ -263,13 +273,13 @@ class OAHashTable {
263273
};
264274

265275
template <typename Key, typename Value>
266-
void OAHashTable::Insert(uint64_t hash, Key &key, Value &value) {
276+
void OAHashTable::Insert(uint64_t hash, const Key &key, const Value &value) {
267277
uint64_t bucket = hash & bucket_mask_;
268278

269279
uint64_t entry_int =
270280
reinterpret_cast<uint64_t>(buckets_) + bucket * entry_size_;
271281
while (true) {
272-
HashEntry *entry = reinterpret_cast<HashEntry *>(entry_int);
282+
auto *entry = reinterpret_cast<HashEntry *>(entry_int);
273283

274284
// If entry is free, dump key and value
275285
if (entry->IsFree()) {
@@ -292,10 +302,49 @@ void OAHashTable::Insert(uint64_t hash, Key &key, Value &value) {
292302
}
293303

294304
// Continue
295-
bucket = (bucket == num_buckets_) ? 0 : bucket + 1;
296-
entry_int = (bucket == num_buckets_) ? reinterpret_cast<uint64_t>(buckets_)
297-
: entry_int + entry_size_;
305+
bucket++;
306+
entry_int += entry_size_;
307+
308+
// Wrap
309+
if (bucket == num_buckets_) {
310+
bucket = 0;
311+
entry_int = reinterpret_cast<uint64_t>(buckets_);
312+
}
313+
}
314+
}
315+
316+
template <typename Key, typename Value>
317+
bool OAHashTable::Probe(uint64_t hash, const Key &key, Value &value) {
318+
uint64_t steps = 0;
319+
320+
uint64_t bucket = hash & bucket_mask_;
321+
322+
uint64_t entry_int =
323+
reinterpret_cast<uint64_t>(buckets_) + (bucket * entry_size_);
324+
325+
while (steps++ < num_entries_) {
326+
auto *entry = reinterpret_cast<HashEntry *>(entry_int);
327+
// check if entry is free
328+
if (!entry->IsFree() && entry->hash == hash) {
329+
// hashes match, check key
330+
auto *entry_key = reinterpret_cast<Key *>(entry->data);
331+
if (*entry_key == key) {
332+
value = *reinterpret_cast<Value *>(entry->data + sizeof(Key));
333+
return true;
334+
}
335+
}
336+
337+
// Continue
338+
bucket++;
339+
entry_int += entry_size_;
340+
341+
// Wrap
342+
if (bucket == num_buckets_) {
343+
bucket = 0;
344+
entry_int = reinterpret_cast<uint64_t>(buckets_);
345+
}
298346
}
347+
return false;
299348
}
300349

301350
} // namespace util

test/codegen/oa_hash_table_test.cpp

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,19 @@ TEST_F(OAHashTableTest, MicroBenchmark) {
136136
std::vector<Key> keys;
137137
Value v = {6, 5, 4, 3};
138138

139+
std::random_device r;
140+
std::default_random_engine e(r());
141+
std::uniform_int_distribution<uint32_t> gen;
142+
139143
// Create all keys
140144
uint32_t num_keys = 100000;
141145
for (uint32_t i = 0; i < num_keys; i++) {
142-
keys.emplace_back(1, static_cast<uint32_t>(rand()));
146+
keys.emplace_back(gen(e), gen(e));
143147
}
144148

145-
double avg_oaht = 0.0;
146-
double avg_map = 0.0;
147-
double avg_cuckoo = 0.0;
149+
double avg_oaht_insert = 0.0, avg_oaht_probe = 0.0;
150+
double avg_map_insert = 0.0, avg_map_probe = 0.0;
151+
double avg_cuckoo_insert = 0.0, avg_cuckoo_probe = 0.0;
148152

149153
// First, bench ours ...
150154
{
@@ -156,14 +160,29 @@ TEST_F(OAHashTableTest, MicroBenchmark) {
156160
Timer<std::ratio<1, 1000>> timer;
157161
timer.Start();
158162

159-
// Start
163+
// Start Insert
160164
for (uint32_t i = 0; i < num_keys; i++) {
161165
ht.Insert(Hash(keys[i]), keys[i], v);
162166
}
163-
// End
167+
// End Insert
164168

165169
timer.Stop();
166-
avg_oaht += timer.GetDuration();
170+
avg_oaht_insert += timer.GetDuration();
171+
172+
timer.Reset();
173+
timer.Start();
174+
175+
// Start Probe
176+
std::vector<Key> shuffled = keys;
177+
std::random_shuffle(shuffled.begin(), shuffled.end());
178+
for (uint32_t i = 0; i < num_keys; i++) {
179+
Value probe_val;
180+
EXPECT_TRUE(ht.Probe(Hash(shuffled[i]), shuffled[i], probe_val));
181+
}
182+
// End Probe
183+
184+
timer.Stop();
185+
avg_oaht_probe += timer.GetDuration();
167186
}
168187
}
169188

@@ -185,7 +204,21 @@ TEST_F(OAHashTableTest, MicroBenchmark) {
185204
}
186205

187206
timer.Stop();
188-
avg_map += timer.GetDuration();
207+
avg_map_insert += timer.GetDuration();
208+
209+
timer.Reset();
210+
timer.Start();
211+
212+
// Start Probe
213+
std::vector<Key> shuffled = keys;
214+
std::random_shuffle(shuffled.begin(), shuffled.end());
215+
for (uint32_t i = 0; i < num_keys; i++) {
216+
EXPECT_NE(ht.find(shuffled[i]), ht.end());
217+
}
218+
// End Probe
219+
220+
timer.Stop();
221+
avg_map_probe += timer.GetDuration();
189222
}
190223
}
191224

@@ -207,13 +240,33 @@ TEST_F(OAHashTableTest, MicroBenchmark) {
207240
}
208241

209242
timer.Stop();
210-
avg_cuckoo += timer.GetDuration();
243+
avg_cuckoo_insert += timer.GetDuration();
244+
245+
timer.Reset();
246+
timer.Start();
247+
248+
// Start Probe
249+
std::vector<Key> shuffled = keys;
250+
std::random_shuffle(shuffled.begin(), shuffled.end());
251+
for (uint32_t i = 0; i < num_keys; i++) {
252+
Value probe_val;
253+
EXPECT_TRUE(map.find(shuffled[i], probe_val));
254+
}
255+
// End Probe
256+
257+
timer.Stop();
258+
avg_cuckoo_probe += timer.GetDuration();
211259
}
212260
}
213261

214-
LOG_INFO("Avg OA_HT: %.2lf, Avg std::unordered_map: %.2lf, Avg cuckoo: %.2lf",
215-
avg_oaht / (double)num_runs, avg_map / (double)num_runs,
216-
avg_cuckoo / (double)num_runs);
262+
LOG_INFO("OA_HT insert: %.2lf, probe: %.2lf",
263+
avg_oaht_insert / (double)num_runs,
264+
avg_oaht_probe / (double)num_runs);
265+
LOG_INFO("std::unordered_map insert: %.2lf, probe: %.2lf",
266+
avg_map_insert / (double)num_runs, avg_map_probe / (double)num_runs);
267+
LOG_INFO("Cuckoo insert: %.2lf, probe: %.2lf",
268+
avg_cuckoo_insert / (double)num_runs,
269+
avg_cuckoo_probe / (double)num_runs);
217270
}
218271

219272
} // namespace test

0 commit comments

Comments
 (0)