Skip to content

Commit 1170856

Browse files
committed
better test helpers
1 parent f0d4f06 commit 1170856

File tree

2 files changed

+211
-28
lines changed

2 files changed

+211
-28
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,15 @@
4141
import org.slf4j.LoggerFactory;
4242

4343
import javax.annotation.Nonnull;
44-
import java.util.ArrayList;
4544
import java.util.Collection;
46-
import java.util.Collections;
4745
import java.util.Comparator;
4846
import java.util.List;
4947
import java.util.Map;
5048
import java.util.Objects;
5149
import java.util.Queue;
5250
import java.util.Random;
5351
import java.util.Set;
52+
import java.util.TreeSet;
5453
import java.util.concurrent.CompletableFuture;
5554
import java.util.concurrent.Executor;
5655
import java.util.concurrent.PriorityBlockingQueue;
@@ -455,14 +454,23 @@ public CompletableFuture<? extends List<? extends NodeReferenceAndNode<? extends
455454
ImmutableList.of(nodeReference), 0, efSearch,
456455
Maps.newConcurrentMap(), queryVector)
457456
.thenApply(searchResult -> {
458-
// reverse the original deque
459-
final int size = searchResult.size();
460-
final int start = Math.max(0, size - k);
461-
462-
final ArrayList<? extends NodeReferenceAndNode<?>> topKReversed =
463-
Lists.newArrayList(searchResult.subList(start, size));
464-
Collections.reverse(topKReversed);
465-
return topKReversed;
457+
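// keep only the k entries with the smallest distance, ordered by ascending distance
// NOTE(review): the comparator looks at the distance only, so the TreeSet treats two entries with equal distance as duplicates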
458+
final TreeSet<NodeReferenceAndNode<? extends NodeReference>> sortedTopK =
459+
new TreeSet<>(
460+
Comparator.comparing(nodeReferenceAndNode ->
461+
nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance()));
462+
463+
for (final NodeReferenceAndNode<?> nodeReferenceAndNode : searchResult) {
464+
if (sortedTopK.size() < k || sortedTopK.last().getNodeReferenceWithDistance().getDistance() >
465+
nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance()) {
466+
sortedTopK.add(nodeReferenceAndNode);
467+
}
468+
469+
if (sortedTopK.size() > k) {
470+
sortedTopK.remove(sortedTopK.last());
471+
}
472+
}
473+
return ImmutableList.copyOf(sortedTopK);
466474
});
467475
});
468476
}
@@ -592,13 +600,12 @@ private <N extends NodeReference> CompletableFuture<List<NodeReferenceAndNode<N>
592600
}).thenCompose(ignored ->
593601
fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache))
594602
.thenApply(searchResult -> {
595-
debug(l ->
596-
l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch,
597-
searchResult.stream()
598-
.map(nodeReferenceAndNode ->
599-
"(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() +
600-
",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")")
601-
.collect(Collectors.joining(","))));
603+
debug(l -> l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch,
604+
searchResult.stream()
605+
.map(nodeReferenceAndNode ->
606+
"(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() +
607+
",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")")
608+
.collect(Collectors.joining(","))));
602609
return searchResult;
603610
});
604611
}

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java

Lines changed: 187 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import com.apple.foundationdb.Database;
2424
import com.apple.foundationdb.Transaction;
25+
import com.apple.foundationdb.async.AsyncUtil;
26+
import com.apple.foundationdb.async.hnsw.Vector.HalfVector;
2527
import com.apple.foundationdb.async.rtree.RTree;
2628
import com.apple.foundationdb.test.TestDatabaseExtension;
2729
import com.apple.foundationdb.test.TestExecutors;
@@ -36,23 +38,34 @@
3638
import org.junit.jupiter.api.BeforeEach;
3739
import org.junit.jupiter.api.Tag;
3840
import org.junit.jupiter.api.Test;
41+
import org.junit.jupiter.api.Timeout;
3942
import org.junit.jupiter.api.extension.RegisterExtension;
4043
import org.junit.jupiter.api.parallel.Execution;
4144
import org.junit.jupiter.api.parallel.ExecutionMode;
45+
import org.junit.jupiter.params.ParameterizedTest;
46+
import org.junit.jupiter.params.provider.ValueSource;
4247
import org.slf4j.Logger;
4348
import org.slf4j.LoggerFactory;
4449

4550
import javax.annotation.Nonnull;
51+
import java.io.BufferedReader;
4652
import java.io.BufferedWriter;
53+
import java.io.FileReader;
4754
import java.io.FileWriter;
4855
import java.io.IOException;
4956
import java.util.ArrayList;
5057
import java.util.Comparator;
5158
import java.util.List;
5259
import java.util.Map;
60+
import java.util.NavigableSet;
61+
import java.util.Objects;
5362
import java.util.Random;
63+
import java.util.concurrent.CompletableFuture;
64+
import java.util.concurrent.ConcurrentSkipListSet;
5465
import java.util.concurrent.TimeUnit;
5566
import java.util.concurrent.atomic.AtomicLong;
67+
import java.util.concurrent.atomic.AtomicReference;
68+
import java.util.function.Function;
5669

5770
/**
5871
* Tests testing insert/update/deletes of data into/in/from {@link HNSW}s.
@@ -159,18 +172,20 @@ public void testBasicInsert() {
159172

160173
final TestOnReadListener onReadListener = new TestOnReadListener();
161174

175+
final int dimensions = 128;
162176
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
163-
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metric.COSINE_METRIC).setEfConstruction(34).setM(16).setMMax(16).setMMax0(32).build(),
177+
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metric.EUCLIDEAN_METRIC).setM(32).setMMax(32).setMMax0(64).build(),
164178
OnWriteListener.NOOP, onReadListener);
165179

166-
for (int i = 0; i < 10000;) {
167-
i += basicInsertBatch(hnsw, random, 100, nextNodeIdAtomic, onReadListener);
180+
for (int i = 0; i < 1000;) {
181+
i += basicInsertBatch(100, nextNodeIdAtomic, onReadListener,
182+
tr -> hnsw.insert(tr, createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, dimensions)));
168183
}
169184

170185
onReadListener.reset();
171186
final long beginTs = System.nanoTime();
172187
final List<? extends NodeReferenceAndNode<?>> result =
173-
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 20, createRandomVector(random, 768)).join());
188+
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 100, createRandomVector(random, dimensions)).join());
174189
final long endTs = System.nanoTime();
175190

176191
for (NodeReferenceAndNode<?> nodeReferenceAndNode : result) {
@@ -184,14 +199,15 @@ public void testBasicInsert() {
184199
logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
185200
}
186201

187-
private int basicInsertBatch(@Nonnull final HNSW hnsw, @Nonnull final Random random, final int batchSize,
188-
@Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener) {
202+
private int basicInsertBatch(final int batchSize,
203+
@Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener,
204+
@Nonnull final Function<Transaction, CompletableFuture<Void>> insertFunction) {
189205
return db.run(tr -> {
190206
onReadListener.reset();
191207
final long nextNodeId = nextNodeIdAtomic.get();
192208
final long beginTs = System.nanoTime();
193209
for (int i = 0; i < batchSize; i ++) {
194-
hnsw.insert(tr, createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, 768)).join();
210+
insertFunction.apply(tr).join();
195211
}
196212
final long endTs = System.nanoTime();
197213
logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId,
@@ -200,6 +216,91 @@ private int basicInsertBatch(@Nonnull final HNSW hnsw, @Nonnull final Random ran
200216
});
201217
}
202218

219+
@Test
220+
@Timeout(value = 150, unit = TimeUnit.MINUTES)
221+
public void testSIFTInsert10k() throws Exception {
222+
final Metric metric = Metric.EUCLIDEAN_METRIC;
223+
final int k = 10;
224+
final AtomicLong nextNodeIdAtomic = new AtomicLong(0L);
225+
226+
final TestOnReadListener onReadListener = new TestOnReadListener();
227+
228+
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
229+
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(),
230+
OnWriteListener.NOOP, onReadListener);
231+
232+
final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv";
233+
final int dimensions = 128;
234+
235+
final AtomicReference<HalfVector> queryVectorAtomic = new AtomicReference<>();
236+
final NavigableSet<NodeReferenceWithDistance> trueResults = new ConcurrentSkipListSet<>(
237+
Comparator.comparing(NodeReferenceWithDistance::getDistance));
238+
239+
try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) {
240+
for (int i = 0; i < 10000;) {
241+
i += basicInsertBatch(100, nextNodeIdAtomic, onReadListener,
242+
tr -> {
243+
final String line;
244+
try {
245+
line = br.readLine();
246+
} catch (IOException e) {
247+
throw new RuntimeException(e);
248+
}
249+
250+
final String[] values = Objects.requireNonNull(line).split("\t");
251+
Assertions.assertEquals(dimensions, values.length);
252+
final Half[] halfs = new Half[dimensions];
253+
254+
for (int c = 0; c < values.length; c++) {
255+
final String value = values[c];
256+
halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value));
257+
}
258+
final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic);
259+
final HalfVector currentVector = new HalfVector(halfs);
260+
final HalfVector queryVector = queryVectorAtomic.get();
261+
if (queryVector == null) {
262+
queryVectorAtomic.set(currentVector);
263+
return AsyncUtil.DONE;
264+
} else {
265+
final double currentDistance =
266+
Vector.comparativeDistance(metric, currentVector, queryVector);
267+
if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) {
268+
trueResults.add(
269+
new NodeReferenceWithDistance(currentPrimaryKey, currentVector,
270+
Vector.comparativeDistance(metric, currentVector, queryVector)));
271+
}
272+
if (trueResults.size() > k) {
273+
trueResults.remove(trueResults.last());
274+
}
275+
return hnsw.insert(tr, currentPrimaryKey, currentVector);
276+
}
277+
});
278+
}
279+
}
280+
281+
onReadListener.reset();
282+
final long beginTs = System.nanoTime();
283+
final List<? extends NodeReferenceAndNode<?>> results =
284+
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join());
285+
final long endTs = System.nanoTime();
286+
287+
for (NodeReferenceAndNode<?> nodeReferenceAndNode : results) {
288+
final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance();
289+
logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
290+
nodeReferenceWithDistance.getDistance());
291+
}
292+
293+
for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) {
294+
logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
295+
nodeReferenceWithDistance.getDistance());
296+
}
297+
298+
System.out.println(onReadListener.getNodeCountByLayer());
299+
System.out.println(onReadListener.getBytesReadByLayer());
300+
301+
logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
302+
}
303+
203304
@Test
204305
public void testBasicInsertAndScanLayer() throws Exception {
205306
final Random random = new Random(0);
@@ -224,17 +325,92 @@ public void testBasicInsertAndScanLayer() throws Exception {
224325
}
225326

226327
/**
 * Round-trip test for the tuple encoding of vectors: creates 3,000,000 random 768-dimensional half vectors,
 * converts each one to a tuple with {@code StorageAdapter.tupleFromVector}, converts the tuple back with
 * {@code StorageAdapter.vectorFromTuple}, and asserts that the result equals the original vector.
 */
@Test
227-
public void testManyVectors() {
328+
public void testManyRandomVectors() {
228329
final Random random = new Random();
229330
for (long l = 0L; l < 3000000; l ++) {
230-
final Vector.HalfVector randomVector = createRandomVector(random, 768);
331+
final HalfVector randomVector = createRandomVector(random, 768);
231332
final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector);
232333
final Vector<Half> roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple);
233334
// the computed distance is discarded
// NOTE(review): presumably this only checks that the call succeeds for a round-tripped vector -- confirm
Vector.comparativeDistance(Metric.EuclideanMetric.EUCLIDEAN_METRIC, randomVector, roundTripVector);
234335
Assertions.assertEquals(randomVector, roundTripVector);
235336
}
236337
}
237338

339+
@Test
340+
@Timeout(value = 150, unit = TimeUnit.MINUTES)
341+
public void testSIFTVectors() throws Exception {
342+
final AtomicLong nextNodeIdAtomic = new AtomicLong(0L);
343+
344+
final TestOnReadListener onReadListener = new TestOnReadListener();
345+
346+
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
347+
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metric.EUCLIDEAN_METRIC).setM(32).setMMax(32).setMMax0(64).build(),
348+
OnWriteListener.NOOP, onReadListener);
349+
350+
351+
final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv";
352+
final int dimensions = 128;
353+
final var referenceVector = createRandomVector(new Random(0), dimensions);
354+
long count = 0L;
355+
double mean = 0.0d;
356+
double mean2 = 0.0d;
357+
358+
try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) {
359+
for (int i = 0; i < 100_000; i ++) {
360+
final String line;
361+
try {
362+
line = br.readLine();
363+
} catch (IOException e) {
364+
throw new RuntimeException(e);
365+
}
366+
367+
final String[] values = Objects.requireNonNull(line).split("\t");
368+
Assertions.assertEquals(dimensions, values.length);
369+
final Half[] halfs = new Half[dimensions];
370+
for (int c = 0; c < values.length; c++) {
371+
final String value = values[c];
372+
halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value));
373+
}
374+
final HalfVector newVector = new HalfVector(halfs);
375+
final double distance = Vector.comparativeDistance(Metric.EUCLIDEAN_METRIC, referenceVector, newVector);
376+
count++;
377+
final double delta = distance - mean;
378+
mean += delta / count;
379+
final double delta2 = distance - mean;
380+
mean2 += delta * delta2;
381+
}
382+
}
383+
final double sampleVariance = mean2 / (count - 1);
384+
final double standardDeviation = Math.sqrt(sampleVariance);
385+
logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation,
386+
standardDeviation / mean);
387+
}
388+
389+
390+
@ParameterizedTest
391+
@ValueSource(ints = {2, 3, 10, 100, 768})
392+
public void testManyVectorsStandardDeviation(final int dimensionality) {
393+
final Random random = new Random();
394+
final Metric metric = Metric.EuclideanMetric.EUCLIDEAN_METRIC;
395+
long count = 0L;
396+
double mean = 0.0d;
397+
double mean2 = 0.0d;
398+
for (long i = 0L; i < 100000; i ++) {
399+
final HalfVector vector1 = createRandomVector(random, dimensionality);
400+
final HalfVector vector2 = createRandomVector(random, dimensionality);
401+
final double distance = Vector.comparativeDistance(metric, vector1, vector2);
402+
count = i + 1;
403+
final double delta = distance - mean;
404+
mean += delta / count;
405+
final double delta2 = distance - mean;
406+
mean2 += delta * delta2;
407+
}
408+
final double sampleVariance = mean2 / (count - 1);
409+
final double standardDeviation = Math.sqrt(sampleVariance);
410+
logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation,
411+
standardDeviation / mean);
412+
}
413+
238414
private boolean dumpLayer(final HNSW hnsw, final int layer) throws IOException {
239415
final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv";
240416
final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv";
@@ -324,13 +500,13 @@ private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic
324500
}
325501

326502
/**
 * Creates a half-precision vector whose components are drawn from the given random source.
 * Each component is the result of {@code random.nextDouble()} converted by {@code HNSWHelpers.halfValueOf}.
 *
 * @param random source of randomness; two calls are needed per component pair, so the sequence of vectors is
 *        reproducible only if the caller passes a seeded instance
 * @param dimensionality number of components of the vector to create
 * @return a new vector with {@code dimensionality} components
 */
@Nonnull
327-
private Vector.HalfVector createRandomVector(@Nonnull final Random random, final int dimensionality) {
503+
private HalfVector createRandomVector(@Nonnull final Random random, final int dimensionality) {
328504
final Half[] components = new Half[dimensionality];
329505
for (int d = 0; d < dimensionality; d ++) {
330506
// don't ask
331507
components[d] = HNSWHelpers.halfValueOf(random.nextDouble());
332508
}
333-
return new Vector.HalfVector(components);
509+
return new HalfVector(components);
334510
}
335511

336512
private static class TestOnReadListener implements OnReadListener {

0 commit comments

Comments
 (0)