Skip to content

Commit 1b5a85d

Browse files
authored
Faster bucket search in ByteBufferHashTable (#18952)
Adds hash code comparison for large enough keys to ByteBufferHashTable#findBucket(). Also, changes key comparison to use long/int/byte instead of byte-only comparison (thus, the comparison is now closer to HashTableUtils#memoryEquals() used in MemoryOpenHashTable). These changes are aimed at speeding up bucket search in ByteBufferHashTable, especially in high-collision cases.
1 parent eed9d8b commit 1b5a85d

File tree

4 files changed

+324
-25
lines changed

4 files changed

+324
-25
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.druid.benchmark;
21+
22+
import org.apache.druid.java.util.common.ByteBufferUtils;
23+
import org.apache.druid.query.groupby.epinephelinae.ByteBufferHashTable;
24+
import org.apache.druid.query.groupby.epinephelinae.Groupers;
25+
import org.openjdk.jmh.annotations.Benchmark;
26+
import org.openjdk.jmh.annotations.BenchmarkMode;
27+
import org.openjdk.jmh.annotations.Fork;
28+
import org.openjdk.jmh.annotations.Level;
29+
import org.openjdk.jmh.annotations.Measurement;
30+
import org.openjdk.jmh.annotations.Mode;
31+
import org.openjdk.jmh.annotations.OperationsPerInvocation;
32+
import org.openjdk.jmh.annotations.OutputTimeUnit;
33+
import org.openjdk.jmh.annotations.Param;
34+
import org.openjdk.jmh.annotations.Scope;
35+
import org.openjdk.jmh.annotations.Setup;
36+
import org.openjdk.jmh.annotations.State;
37+
import org.openjdk.jmh.annotations.TearDown;
38+
import org.openjdk.jmh.annotations.Warmup;
39+
40+
import java.nio.ByteBuffer;
41+
import java.util.Random;
42+
import java.util.concurrent.ThreadLocalRandom;
43+
import java.util.concurrent.TimeUnit;
44+
45+
/**
46+
* Benchmark for ByteBufferHashTable.findBucket() method.
47+
* Measures lookup latency with various key sizes, simulating GROUP BY workloads
48+
* with a mix of existing key lookups and new key insertions.
49+
*/
50+
@BenchmarkMode(Mode.AverageTime)
51+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
52+
@OperationsPerInvocation(ByteBufferHashTableBenchmark.ITERATIONS)
53+
@Warmup(iterations = 3)
54+
@Measurement(iterations = 5)
55+
@Fork(1)
56+
@State(Scope.Benchmark)
57+
public class ByteBufferHashTableBenchmark
58+
{
59+
public static final int ITERATIONS = 10000;
60+
61+
private static final float MAX_LOAD_FACTOR = 0.7f;
62+
private static final int NUM_BUCKETS = 16384;
63+
// Percentage of lookups that will be for non-existent keys (simulates new group creation)
64+
private static final double NEW_KEY_RATIO = 0.1;
65+
// Size of aggregation values per bucket (e.g., 8 bytes for a long sum, 16 for sum+count, etc.)
66+
private static final int VALUE_SIZE = 16;
67+
68+
@Param({"8", "16", "32", "64", "128"})
69+
public int keySize;
70+
71+
private BenchmarkHashTable hashTable;
72+
private ByteBuffer[] lookupKeys;
73+
private int[] lookupKeyHashes;
74+
75+
// Direct buffers to be freed in tearDown
76+
private ByteBuffer tableBuffer;
77+
private ByteBuffer insertedKeysBuffer;
78+
private ByteBuffer lookupKeysBuffer;
79+
80+
@Setup(Level.Trial)
81+
public void setup()
82+
{
83+
int bucketSize = Integer.BYTES + keySize + VALUE_SIZE; // hash + key + aggregation values
84+
tableBuffer = ByteBuffer.allocateDirect(NUM_BUCKETS * bucketSize);
85+
86+
hashTable = new BenchmarkHashTable(
87+
MAX_LOAD_FACTOR,
88+
NUM_BUCKETS,
89+
bucketSize,
90+
tableBuffer,
91+
keySize
92+
);
93+
hashTable.reset();
94+
95+
Random random = new Random(42);
96+
int numEntries = (int) (NUM_BUCKETS * MAX_LOAD_FACTOR);
97+
98+
// Allocate direct buffer for inserted keys
99+
insertedKeysBuffer = ByteBuffer.allocateDirect(numEntries * keySize);
100+
int[] insertedKeyHashes = new int[numEntries];
101+
102+
// Insert entries into the hash table
103+
for (int i = 0; i < numEntries; i++) {
104+
byte[] keyBytes = new byte[keySize];
105+
random.nextBytes(keyBytes);
106+
107+
// Store key in direct buffer
108+
int keyOffset = i * keySize;
109+
insertedKeysBuffer.position(keyOffset);
110+
insertedKeysBuffer.put(keyBytes);
111+
112+
// Create a slice for this key
113+
insertedKeysBuffer.position(keyOffset);
114+
insertedKeysBuffer.limit(keyOffset + keySize);
115+
ByteBuffer keyBuffer = insertedKeysBuffer.slice();
116+
insertedKeysBuffer.clear(); // Reset limit for next iteration
117+
118+
int keyHash = Groupers.smear(keyBuffer.getInt(0)) & Groupers.USED_FLAG_MASK;
119+
insertedKeyHashes[i] = keyHash;
120+
121+
int bucket = hashTable.findBucket0(true, keyBuffer, keyHash);
122+
if (bucket >= 0) {
123+
// Reset position before initBucket since initializeNewBucketKey uses relative put
124+
keyBuffer.position(0);
125+
hashTable.initBucket(bucket, keyBuffer, keyHash);
126+
}
127+
}
128+
129+
// Prepare lookup keys - mix of existing keys and new keys (not in table)
130+
// Allocate a single direct buffer for all lookup keys
131+
lookupKeysBuffer = ByteBuffer.allocateDirect(ITERATIONS * keySize);
132+
lookupKeys = new ByteBuffer[ITERATIONS];
133+
lookupKeyHashes = new int[ITERATIONS];
134+
135+
int newKeyCount = (int) (ITERATIONS * NEW_KEY_RATIO);
136+
137+
// Generate new keys that don't exist in the table
138+
Random newKeyRandom = new Random(12345);
139+
for (int i = 0; i < newKeyCount; i++) {
140+
byte[] keyBytes = new byte[keySize];
141+
newKeyRandom.nextBytes(keyBytes);
142+
143+
int keyOffset = i * keySize;
144+
lookupKeysBuffer.position(keyOffset);
145+
lookupKeysBuffer.put(keyBytes);
146+
147+
// Create a slice for this key
148+
lookupKeysBuffer.position(keyOffset);
149+
lookupKeysBuffer.limit(keyOffset + keySize);
150+
lookupKeys[i] = lookupKeysBuffer.slice();
151+
lookupKeysBuffer.clear();
152+
153+
lookupKeyHashes[i] = Groupers.smear(lookupKeys[i].getInt(0)) & Groupers.USED_FLAG_MASK;
154+
}
155+
156+
// Fill the rest with existing keys (copy from insertedKeysBuffer)
157+
for (int i = newKeyCount; i < ITERATIONS; i++) {
158+
int idx = ThreadLocalRandom.current().nextInt(numEntries);
159+
160+
// Copy key data from inserted keys buffer
161+
int srcOffset = idx * keySize;
162+
int dstOffset = i * keySize;
163+
164+
for (int j = 0; j < keySize; j++) {
165+
lookupKeysBuffer.put(dstOffset + j, insertedKeysBuffer.get(srcOffset + j));
166+
}
167+
168+
// Create a slice for this key
169+
lookupKeysBuffer.position(dstOffset);
170+
lookupKeysBuffer.limit(dstOffset + keySize);
171+
lookupKeys[i] = lookupKeysBuffer.slice();
172+
lookupKeysBuffer.clear();
173+
174+
lookupKeyHashes[i] = insertedKeyHashes[idx];
175+
}
176+
177+
// Shuffle to mix new and existing keys
178+
for (int i = ITERATIONS - 1; i > 0; i--) {
179+
int j = ThreadLocalRandom.current().nextInt(i + 1);
180+
ByteBuffer tempKey = lookupKeys[i];
181+
lookupKeys[i] = lookupKeys[j];
182+
lookupKeys[j] = tempKey;
183+
int tempHash = lookupKeyHashes[i];
184+
lookupKeyHashes[i] = lookupKeyHashes[j];
185+
lookupKeyHashes[j] = tempHash;
186+
}
187+
}
188+
189+
@TearDown(Level.Trial)
190+
public void tearDown()
191+
{
192+
if (tableBuffer != null) {
193+
ByteBufferUtils.free(tableBuffer);
194+
tableBuffer = null;
195+
}
196+
if (insertedKeysBuffer != null) {
197+
ByteBufferUtils.free(insertedKeysBuffer);
198+
insertedKeysBuffer = null;
199+
}
200+
if (lookupKeysBuffer != null) {
201+
ByteBufferUtils.free(lookupKeysBuffer);
202+
lookupKeysBuffer = null;
203+
}
204+
}
205+
206+
@Benchmark
207+
public int findBucket()
208+
{
209+
int result = 0;
210+
for (int i = 0; i < ITERATIONS; i++) {
211+
// allowNewBucket=true simulates GROUP BY where new groups can be created
212+
result ^= hashTable.findBucket0(true, lookupKeys[i], lookupKeyHashes[i]);
213+
}
214+
return result;
215+
}
216+
217+
/**
218+
* Test harness that exposes protected findBucket() method for benchmarking.
219+
*/
220+
private static class BenchmarkHashTable extends ByteBufferHashTable
221+
{
222+
BenchmarkHashTable(
223+
float maxLoadFactor,
224+
int initialBuckets,
225+
int bucketSizeWithHash,
226+
ByteBuffer buffer,
227+
int keySize
228+
)
229+
{
230+
super(maxLoadFactor, initialBuckets, bucketSizeWithHash, buffer, keySize, Integer.MAX_VALUE, null);
231+
}
232+
233+
int findBucket0(boolean allowNewBucket, ByteBuffer keyBuffer, int keyHash)
234+
{
235+
return findBucket(allowNewBucket, maxBuckets, tableBuffer, keyBuffer, keyHash);
236+
}
237+
238+
void initBucket(int bucket, ByteBuffer keyBuffer, int keyHash)
239+
{
240+
initializeNewBucketKey(bucket, keyBuffer, keyHash);
241+
}
242+
}
243+
}

processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/ByteBufferHashTable.java

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public void reset()
142142

143143
// Clear used bits of new table
144144
for (int i = 0; i < maxBuckets; i++) {
145-
tableBuffer.put(i * bucketSizeWithHash, (byte) 0);
145+
tableBuffer.putInt(i * bucketSizeWithHash, 0);
146146
}
147147
}
148148

@@ -181,7 +181,7 @@ public void adjustTableWhenFull()
181181

182182
// Clear used bits of new table
183183
for (int i = 0; i < newBuckets; i++) {
184-
newTableBuffer.put(i * bucketSizeWithHash, (byte) 0);
184+
newTableBuffer.putInt(i * bucketSizeWithHash, 0);
185185
}
186186

187187
// Loop over old buckets and copy to new table
@@ -202,7 +202,7 @@ public void adjustTableWhenFull()
202202
keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize);
203203
keyBuffer.position(entryBuffer.position() + HASH_SIZE);
204204

205-
final int keyHash = entryBuffer.getInt(entryBuffer.position()) & 0x7fffffff;
205+
final int keyHash = entryBuffer.getInt(entryBuffer.position()) & Groupers.USED_FLAG_MASK;
206206
final int newBucket = findBucket(true, newBuckets, newTableBuffer, keyBuffer, keyHash);
207207

208208
if (newBucket < 0) {
@@ -300,35 +300,81 @@ protected int findBucket(
300300
final int startBucket = keyHash % buckets;
301301
int bucket = startBucket;
302302

303-
outer:
303+
// Pre-compute hash with used flag for comparison.
304+
final int keyHashWithUsedFlag = Groupers.getUsedFlag(keyHash);
305+
final int keyBufferPosition = keyBuffer.position();
306+
304307
while (true) {
305308
final int bucketOffset = bucket * bucketSizeWithHash;
309+
final int storedHashWithUsedFlag = targetTableBuffer.getInt(bucketOffset);
306310

307-
if ((targetTableBuffer.get(bucketOffset) & 0x80) == 0) {
311+
if ((storedHashWithUsedFlag & Groupers.USED_FLAG_BIT) == 0) {
308312
// Found unused bucket before finding our key
309313
return allowNewBucket ? bucket : -1;
310314
}
311315

312-
for (int i = bucketOffset + HASH_SIZE, j = keyBuffer.position(); j < keyBuffer.position() + keySize; i++, j++) {
313-
if (targetTableBuffer.get(i) != keyBuffer.get(j)) {
314-
bucket += 1;
315-
if (bucket == buckets) {
316-
bucket = 0;
317-
}
316+
if (storedHashWithUsedFlag == keyHashWithUsedFlag &&
317+
keysEqual(targetTableBuffer, bucketOffset + HASH_SIZE, keyBuffer, keyBufferPosition, keySize)) {
318+
// Found our key in a used bucket
319+
return bucket;
320+
}
318321

319-
if (bucket == startBucket) {
320-
// Came back around to the start without finding a free slot, that was a long trip!
321-
// Should never happen unless buckets == regrowthThreshold.
322-
return -1;
323-
}
322+
// Move to next bucket (linear probing)
323+
bucket += 1;
324+
if (bucket == buckets) {
325+
bucket = 0;
326+
}
324327

325-
continue outer;
326-
}
328+
if (bucket == startBucket) {
329+
// Came back around to the start without finding a free slot, that was a long trip!
330+
// Should never happen unless buckets == regrowthThreshold.
331+
return -1;
327332
}
333+
}
334+
}
328335

329-
// Found our key in a used bucket
330-
return bucket;
336+
/**
337+
* Compare keys using long/int comparisons for better performance than byte-by-byte.
338+
*/
339+
private static boolean keysEqual(
340+
final ByteBuffer tableBuffer,
341+
int tableOffset,
342+
final ByteBuffer keyBuffer,
343+
int keyOffset,
344+
int length
345+
)
346+
{
347+
// Compare 8 bytes at a time
348+
while (length >= Long.BYTES) {
349+
if (tableBuffer.getLong(tableOffset) != keyBuffer.getLong(keyOffset)) {
350+
return false;
351+
}
352+
tableOffset += Long.BYTES;
353+
keyOffset += Long.BYTES;
354+
length -= Long.BYTES;
355+
}
356+
357+
// Compare 4 bytes if remaining
358+
if (length >= Integer.BYTES) {
359+
if (tableBuffer.getInt(tableOffset) != keyBuffer.getInt(keyOffset)) {
360+
return false;
361+
}
362+
tableOffset += Integer.BYTES;
363+
keyOffset += Integer.BYTES;
364+
length -= Integer.BYTES;
331365
}
366+
367+
// Compare remaining 1-3 bytes
368+
while (length > 0) {
369+
if (tableBuffer.get(tableOffset) != keyBuffer.get(keyOffset)) {
370+
return false;
371+
}
372+
tableOffset++;
373+
keyOffset++;
374+
length--;
375+
}
376+
377+
return true;
332378
}
333379

334380
protected boolean canAllowNewBucket()

processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/Groupers.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,17 @@ private Groupers()
4848
+ "for details."
4949
);
5050

51-
private static final int USED_FLAG_MASK = 0x7fffffff;
51+
/**
52+
* Mask to clear the used flag bit from a hash value. Used when reading a hash from a bucket
53+
* to recover the original hash without the used flag.
54+
*/
55+
public static final int USED_FLAG_MASK = 0x7fffffff;
56+
57+
/**
58+
* Bit that indicates a bucket is used in the hash table. This is the sign bit (highest bit)
59+
* of the hash value stored in each bucket.
60+
*/
61+
public static final int USED_FLAG_BIT = 0x80000000;
5262

5363
private static final int C1 = 0xcc9e2d51;
5464
private static final int C2 = 0x1b873593;
@@ -93,7 +103,7 @@ public static int hashObject(final Object obj)
93103

94104
static int getUsedFlag(int keyHash)
95105
{
96-
return keyHash | 0x80000000;
106+
return keyHash | USED_FLAG_BIT;
97107
}
98108

99109
public static ByteBuffer getSlice(ByteBuffer buffer, int sliceSize, int i)

0 commit comments

Comments
 (0)