Skip to content

Commit 8232270

Browse files
committed
doing some performance tests
1 parent f1cda10 commit 8232270

File tree

7 files changed

+150
-42
lines changed

7 files changed

+150
-42
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ protected CompletableFuture<Node<NodeReference>> fetchNodeInternal(@Nonnull fina
7979
if (valueBytes == null) {
8080
throw new IllegalStateException("cannot fetch node");
8181
}
82-
return nodeFromRaw(primaryKey, keyBytes, valueBytes);
82+
return nodeFromRaw(layer, primaryKey, keyBytes, valueBytes);
8383
});
8484
}
8585

8686
@Nonnull
87-
private Node<NodeReference> nodeFromRaw(final @Nonnull Tuple primaryKey, @Nonnull final byte[] keyBytes,
88-
@Nonnull final byte[] valueBytes) {
87+
private Node<NodeReference> nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey,
88+
@Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) {
8989
final Tuple nodeTuple = Tuple.fromBytes(valueBytes);
9090
final Node<NodeReference> node = nodeFromTuples(primaryKey, nodeTuple);
9191
final OnReadListener onReadListener = getOnReadListener();
92-
onReadListener.onNodeRead(node);
93-
onReadListener.onKeyValueRead(keyBytes, valueBytes);
92+
onReadListener.onNodeRead(layer, node);
93+
onReadListener.onKeyValueRead(layer, keyBytes, valueBytes);
9494
return node;
9595
}
9696

@@ -139,16 +139,17 @@ public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull f
139139
for (final NodeReference neighborReference : neighbors) {
140140
neighborItems.add(neighborReference.getPrimaryKey());
141141
}
142-
if (logger.isDebugEnabled()) {
143-
logger.debug("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(),
144-
node.getNeighbors().size(), neighborItems.size());
145-
}
146-
147142
nodeItems.add(Tuple.fromList(neighborItems));
143+
148144
final Tuple nodeTuple = Tuple.fromList(nodeItems);
149145

150146
transaction.set(key, nodeTuple.pack());
151147
getOnWriteListener().onNodeWritten(layer, node);
148+
149+
if (logger.isDebugEnabled()) {
150+
logger.debug("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(),
151+
node.getNeighbors().size(), neighborItems.size());
152+
}
152153
}
153154

154155
public Iterable<Node<NodeReference>> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer,
@@ -166,7 +167,7 @@ public Iterable<Node<NodeReference>> scanLayer(@Nonnull final ReadTransaction re
166167
final byte[] key = keyValue.getKey();
167168
final byte[] value = keyValue.getValue();
168169
final Tuple primaryKey = getDataSubspace().unpack(key).getNestedTuple(1);
169-
return nodeFromRaw(primaryKey, key, value);
170+
return nodeFromRaw(layer, primaryKey, key, value);
170171
});
171172
}
172173
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
import org.slf4j.LoggerFactory;
4242

4343
import javax.annotation.Nonnull;
44+
import java.util.ArrayList;
4445
import java.util.Collection;
46+
import java.util.Collections;
4547
import java.util.Comparator;
4648
import java.util.List;
4749
import java.util.Map;
@@ -74,7 +76,7 @@ public class HNSW {
7476
public static final int DEFAULT_M = 16;
7577
public static final int DEFAULT_M_MAX = DEFAULT_M;
7678
public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M;
77-
public static final int DEFAULT_EF_SEARCH = 64;
79+
public static final int DEFAULT_EF_SEARCH = 100;
7880
public static final int DEFAULT_EF_CONSTRUCTION = 200;
7981
public static final boolean DEFAULT_EXTEND_CANDIDATES = false;
8082
public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false;
@@ -403,6 +405,7 @@ public OnReadListener getOnReadListener() {
403405
@SuppressWarnings("checkstyle:MethodName") // method name introduced by paper
404406
@Nonnull
405407
public CompletableFuture<? extends List<? extends NodeReferenceAndNode<? extends NodeReference>>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction,
408+
final int k,
406409
final int efSearch,
407410
@Nonnull final Vector<Half> queryVector) {
408411
return StorageAdapter.fetchEntryNodeReference(readTransaction, getSubspace(), getOnReadListener())
@@ -450,7 +453,17 @@ public CompletableFuture<? extends List<? extends NodeReferenceAndNode<? extends
450453

451454
return searchLayer(storageAdapter, readTransaction,
452455
ImmutableList.of(nodeReference), 0, efSearch,
453-
Maps.newConcurrentMap(), queryVector);
456+
Maps.newConcurrentMap(), queryVector)
457+
.thenApply(searchResult -> {
458+
// reverse the original deque
459+
final int size = searchResult.size();
460+
final int start = Math.max(0, size - k);
461+
462+
final ArrayList<? extends NodeReferenceAndNode<?>> topKReversed =
463+
Lists.newArrayList(searchResult.subList(start, size));
464+
Collections.reverse(topKReversed);
465+
return topKReversed;
466+
});
454467
});
455468
}
456469

@@ -579,14 +592,13 @@ private <N extends NodeReference> CompletableFuture<List<NodeReferenceAndNode<N>
579592
}).thenCompose(ignored ->
580593
fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache))
581594
.thenApply(searchResult -> {
582-
debug(l -> {
583-
l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch,
584-
searchResult.stream()
585-
.map(nodeReferenceAndNode ->
586-
"(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() +
587-
",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")")
588-
.collect(Collectors.joining(",")));
589-
});
595+
debug(l ->
596+
l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch,
597+
searchResult.stream()
598+
.map(nodeReferenceAndNode ->
599+
"(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() +
600+
",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")")
601+
.collect(Collectors.joining(","))));
590602
return searchResult;
591603
});
592604
}
@@ -1093,6 +1105,12 @@ private int insertionLayer(@Nonnull final Random random) {
10931105
return (int) Math.floor(-Math.log(u) * lambda);
10941106
}
10951107

1108+
private void info(@Nonnull final Consumer<Logger> loggerConsumer) {
1109+
if (logger.isInfoEnabled()) {
1110+
loggerConsumer.accept(logger);
1111+
}
1112+
}
1113+
10961114
private void debug(@Nonnull final Consumer<Logger> loggerConsumer) {
10971115
if (logger.isDebugEnabled()) {
10981116
loggerConsumer.accept(logger);

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,29 @@ protected CompletableFuture<Node<NodeReferenceWithVector>> fetchNodeInternal(@No
7070

7171
return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey),
7272
ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor())
73-
.thenApply(keyValues -> nodeFromRaw(primaryKey, keyValues));
73+
.thenApply(keyValues -> nodeFromRaw(layer, primaryKey, keyValues));
7474
}
7575

7676
@Nonnull
77-
private Node<NodeReferenceWithVector> nodeFromRaw(final @Nonnull Tuple primaryKey, final List<KeyValue> keyValues) {
77+
private Node<NodeReferenceWithVector> nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, final List<KeyValue> keyValues) {
7878
final OnReadListener onReadListener = getOnReadListener();
7979

8080
final ImmutableList.Builder<NodeReferenceWithVector> nodeReferencesWithVectorBuilder = ImmutableList.builder();
8181
for (final KeyValue keyValue : keyValues) {
82-
nodeReferencesWithVectorBuilder.add(neighborFromRaw(keyValue.getKey(), keyValue.getValue()));
82+
nodeReferencesWithVectorBuilder.add(neighborFromRaw(layer, keyValue.getKey(), keyValue.getValue()));
8383
}
8484

8585
final Node<NodeReferenceWithVector> node =
8686
getNodeFactory().create(primaryKey, null, nodeReferencesWithVectorBuilder.build());
87-
onReadListener.onNodeRead(node);
87+
onReadListener.onNodeRead(layer, node);
8888
return node;
8989
}
9090

9191
@Nonnull
92-
private NodeReferenceWithVector neighborFromRaw(final @Nonnull byte[] key, final byte[] value) {
92+
private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull byte[] key, final byte[] value) {
9393
final OnReadListener onReadListener = getOnReadListener();
9494

95-
onReadListener.onKeyValueRead(key, value);
95+
onReadListener.onKeyValueRead(layer, key, value);
9696
final Tuple neighborKeyTuple = getDataSubspace().unpack(key);
9797
final Tuple neighborValueTuple = Tuple.fromBytes(value);
9898

@@ -153,7 +153,7 @@ public Iterable<Node<NodeReferenceWithVector>> scanLayer(@Nonnull final ReadTran
153153
ImmutableList.Builder<NodeReferenceWithVector> neighborsBuilder = ImmutableList.builder();
154154
for (final KeyValue item: itemsIterable) {
155155
final NodeReferenceWithVector neighbor =
156-
neighborFromRaw(item.getKey(), item.getValue());
156+
neighborFromRaw(layer, item.getKey(), item.getValue());
157157
final Tuple primaryKeyFromNodeReference = neighbor.getPrimaryKey();
158158
if (nodePrimaryKey == null) {
159159
nodePrimaryKey = primaryKeyFromNodeReference;

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,16 @@ public interface OnReadListener {
3030
OnReadListener NOOP = new OnReadListener() {
3131
};
3232

33-
default void onSlotIndexEntryRead(@Nonnull final byte[] key) {
34-
// nothing
35-
}
36-
3733
default <N extends NodeReference> CompletableFuture<Node<N>> onAsyncRead(@Nonnull CompletableFuture<Node<N>> future) {
3834
return future;
3935
}
4036

41-
default void onNodeRead(@Nonnull Node<? extends NodeReference> node) {
37+
default void onNodeRead(int layer, @Nonnull Node<? extends NodeReference> node) {
4238
// nothing
4339
}
4440

45-
default void onKeyValueRead(@Nonnull byte[] key,
41+
default void onKeyValueRead(int layer,
42+
@Nonnull byte[] key,
4643
@Nonnull byte[] value) {
4744
// nothing
4845
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static CompletableFuture<EntryNodeReference> fetchEntryNodeReference(@Nonnull fi
107107
if (valueBytes == null) {
108108
return null; // not a single node in the index
109109
}
110-
onReadListener.onKeyValueRead(key, valueBytes);
110+
onReadListener.onKeyValueRead(-1, key, valueBytes);
111111

112112
final Tuple entryTuple = Tuple.fromBytes(valueBytes);
113113
final int lMax = (int)entryTuple.getLong(0);

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.apple.test.Tags;
3131
import com.christianheina.langx.half4j.Half;
3232
import com.google.common.collect.ImmutableList;
33+
import com.google.common.collect.Maps;
3334
import org.assertj.core.util.Lists;
3435
import org.junit.jupiter.api.Assertions;
3536
import org.junit.jupiter.api.BeforeEach;
@@ -47,7 +48,10 @@
4748
import java.io.IOException;
4849
import java.util.ArrayList;
4950
import java.util.Comparator;
51+
import java.util.List;
52+
import java.util.Map;
5053
import java.util.Random;
54+
import java.util.concurrent.TimeUnit;
5155
import java.util.concurrent.atomic.AtomicLong;
5256

5357
/**
@@ -151,14 +155,48 @@ public void testInliningSerialization() {
151155
@Test
152156
public void testBasicInsert() {
153157
final Random random = new Random(0);
154-
final AtomicLong nextNodeId = new AtomicLong(0L);
155-
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool());
158+
final AtomicLong nextNodeIdAtomic = new AtomicLong(0L);
156159

157-
db.run(tr -> {
158-
for (int i = 0; i < 10; i ++) {
159-
hnsw.insert(tr, createNextPrimaryKey(nextNodeId), createRandomVector(random, 728)).join();
160+
final TestOnReadListener onReadListener = new TestOnReadListener();
161+
162+
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
163+
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metric.COSINE_METRIC).setEfConstruction(34).setM(16).setMMax(16).setMMax0(32).build(),
164+
OnWriteListener.NOOP, onReadListener);
165+
166+
for (int i = 0; i < 10000;) {
167+
i += basicInsertBatch(hnsw, random, 100, nextNodeIdAtomic, onReadListener);
168+
}
169+
170+
onReadListener.reset();
171+
final long beginTs = System.nanoTime();
172+
final List<? extends NodeReferenceAndNode<?>> result =
173+
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 20, createRandomVector(random, 768)).join());
174+
final long endTs = System.nanoTime();
175+
176+
for (NodeReferenceAndNode<?> nodeReferenceAndNode : result) {
177+
final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance();
178+
logger.info("nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
179+
nodeReferenceWithDistance.getDistance());
180+
}
181+
System.out.println(onReadListener.getNodeCountByLayer());
182+
System.out.println(onReadListener.getBytesReadByLayer());
183+
184+
logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
185+
}
186+
187+
private int basicInsertBatch(@Nonnull final HNSW hnsw, @Nonnull final Random random, final int batchSize,
188+
@Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener) {
189+
return db.run(tr -> {
190+
onReadListener.reset();
191+
final long nextNodeId = nextNodeIdAtomic.get();
192+
final long beginTs = System.nanoTime();
193+
for (int i = 0; i < batchSize; i ++) {
194+
hnsw.insert(tr, createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, 768)).join();
160195
}
161-
return null;
196+
final long endTs = System.nanoTime();
197+
logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId,
198+
TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer());
199+
return batchSize;
162200
});
163201
}
164202

@@ -185,6 +223,18 @@ public void testBasicInsertAndScanLayer() throws Exception {
185223
}
186224
}
187225

226+
@Test
227+
public void testManyVectors() {
228+
final Random random = new Random();
229+
for (long l = 0L; l < 3000000; l ++) {
230+
final Vector.HalfVector randomVector = createRandomVector(random, 768);
231+
final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector);
232+
final Vector<Half> roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple);
233+
Vector.comparativeDistance(Metric.EuclideanMetric.EUCLIDEAN_METRIC, randomVector, roundTripVector);
234+
Assertions.assertEquals(randomVector, roundTripVector);
235+
}
236+
}
237+
188238
private boolean dumpLayer(final HNSW hnsw, final int layer) throws IOException {
189239
final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv";
190240
final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv";
@@ -282,4 +332,46 @@ private Vector.HalfVector createRandomVector(@Nonnull final Random random, final
282332
}
283333
return new Vector.HalfVector(components);
284334
}
335+
336+
private static class TestOnReadListener implements OnReadListener {
337+
final Map<Integer, Long> nodeCountByLayer;
338+
final Map<Integer, Long> sumMByLayer;
339+
final Map<Integer, Long> bytesReadByLayer;
340+
341+
public TestOnReadListener() {
342+
this.nodeCountByLayer = Maps.newConcurrentMap();
343+
this.sumMByLayer = Maps.newConcurrentMap();
344+
this.bytesReadByLayer = Maps.newConcurrentMap();
345+
}
346+
347+
public Map<Integer, Long> getNodeCountByLayer() {
348+
return nodeCountByLayer;
349+
}
350+
351+
public Map<Integer, Long> getBytesReadByLayer() {
352+
return bytesReadByLayer;
353+
}
354+
355+
public Map<Integer, Long> getSumMByLayer() {
356+
return sumMByLayer;
357+
}
358+
359+
public void reset() {
360+
nodeCountByLayer.clear();
361+
bytesReadByLayer.clear();
362+
sumMByLayer.clear();
363+
}
364+
365+
@Override
366+
public void onNodeRead(final int layer, @Nonnull final Node<? extends NodeReference> node) {
367+
nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L);
368+
sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + node.getNeighbors().size());
369+
}
370+
371+
@Override
372+
public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) {
373+
bytesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) +
374+
key.length + value.length);
375+
}
376+
}
285377
}

gradle/scripts/log4j-test.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ appender.console.name = STDOUT
2626
appender.console.layout.type = PatternLayout
2727
appender.console.layout.pattern = %d [%level] %logger{1.} - %m %X%n%ex{full}
2828

29-
rootLogger.level = debug
29+
rootLogger.level = info
3030
rootLogger.appenderRefs = stdout
3131
rootLogger.appenderRef.stdout.ref = STDOUT
3232

0 commit comments

Comments
 (0)