22
22
23
23
import com .apple .foundationdb .Database ;
24
24
import com .apple .foundationdb .Transaction ;
25
+ import com .apple .foundationdb .async .AsyncUtil ;
26
+ import com .apple .foundationdb .async .hnsw .Vector .HalfVector ;
25
27
import com .apple .foundationdb .async .rtree .RTree ;
26
28
import com .apple .foundationdb .test .TestDatabaseExtension ;
27
29
import com .apple .foundationdb .test .TestExecutors ;
36
38
import org .junit .jupiter .api .BeforeEach ;
37
39
import org .junit .jupiter .api .Tag ;
38
40
import org .junit .jupiter .api .Test ;
41
+ import org .junit .jupiter .api .Timeout ;
39
42
import org .junit .jupiter .api .extension .RegisterExtension ;
40
43
import org .junit .jupiter .api .parallel .Execution ;
41
44
import org .junit .jupiter .api .parallel .ExecutionMode ;
45
+ import org .junit .jupiter .params .ParameterizedTest ;
46
+ import org .junit .jupiter .params .provider .ValueSource ;
42
47
import org .slf4j .Logger ;
43
48
import org .slf4j .LoggerFactory ;
44
49
45
50
import javax .annotation .Nonnull ;
51
+ import java .io .BufferedReader ;
46
52
import java .io .BufferedWriter ;
53
+ import java .io .FileReader ;
47
54
import java .io .FileWriter ;
48
55
import java .io .IOException ;
49
56
import java .util .ArrayList ;
50
57
import java .util .Comparator ;
51
58
import java .util .List ;
52
59
import java .util .Map ;
60
+ import java .util .NavigableSet ;
61
+ import java .util .Objects ;
53
62
import java .util .Random ;
63
+ import java .util .concurrent .CompletableFuture ;
64
+ import java .util .concurrent .ConcurrentSkipListSet ;
54
65
import java .util .concurrent .TimeUnit ;
55
66
import java .util .concurrent .atomic .AtomicLong ;
67
+ import java .util .concurrent .atomic .AtomicReference ;
68
+ import java .util .function .Function ;
56
69
57
70
/**
58
71
* Tests testing insert/update/deletes of data into/in/from {@link RTree}s.
@@ -159,18 +172,20 @@ public void testBasicInsert() {
159
172
160
173
final TestOnReadListener onReadListener = new TestOnReadListener ();
161
174
175
+ final int dimensions = 128 ;
162
176
final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
163
- HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .COSINE_METRIC ). setEfConstruction ( 34 ). setM (16 ).setMMax (16 ).setMMax0 (32 ).build (),
177
+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .EUCLIDEAN_METRIC ). setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
164
178
OnWriteListener .NOOP , onReadListener );
165
179
166
- for (int i = 0 ; i < 10000 ;) {
167
- i += basicInsertBatch (hnsw , random , 100 , nextNodeIdAtomic , onReadListener );
180
+ for (int i = 0 ; i < 1000 ;) {
181
+ i += basicInsertBatch (100 , nextNodeIdAtomic , onReadListener ,
182
+ tr -> hnsw .insert (tr , createNextPrimaryKey (nextNodeIdAtomic ), createRandomVector (random , dimensions )));
168
183
}
169
184
170
185
onReadListener .reset ();
171
186
final long beginTs = System .nanoTime ();
172
187
final List <? extends NodeReferenceAndNode <?>> result =
173
- db .run (tr -> hnsw .kNearestNeighborsSearch (tr , 10 , 20 , createRandomVector (random , 768 )).join ());
188
+ db .run (tr -> hnsw .kNearestNeighborsSearch (tr , 10 , 100 , createRandomVector (random , dimensions )).join ());
174
189
final long endTs = System .nanoTime ();
175
190
176
191
for (NodeReferenceAndNode <?> nodeReferenceAndNode : result ) {
@@ -184,14 +199,15 @@ public void testBasicInsert() {
184
199
logger .info ("search transaction took elapsedTime={}ms" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
185
200
}
186
201
187
- private int basicInsertBatch (@ Nonnull final HNSW hnsw , @ Nonnull final Random random , final int batchSize ,
188
- @ Nonnull final AtomicLong nextNodeIdAtomic , @ Nonnull final TestOnReadListener onReadListener ) {
202
+ private int basicInsertBatch (final int batchSize ,
203
+ @ Nonnull final AtomicLong nextNodeIdAtomic , @ Nonnull final TestOnReadListener onReadListener ,
204
+ @ Nonnull final Function <Transaction , CompletableFuture <Void >> insertFunction ) {
189
205
return db .run (tr -> {
190
206
onReadListener .reset ();
191
207
final long nextNodeId = nextNodeIdAtomic .get ();
192
208
final long beginTs = System .nanoTime ();
193
209
for (int i = 0 ; i < batchSize ; i ++) {
194
- hnsw . insert (tr , createNextPrimaryKey ( nextNodeIdAtomic ), createRandomVector ( random , 768 ) ).join ();
210
+ insertFunction . apply (tr ).join ();
195
211
}
196
212
final long endTs = System .nanoTime ();
197
213
logger .info ("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}" , batchSize , nextNodeId ,
@@ -200,6 +216,91 @@ private int basicInsertBatch(@Nonnull final HNSW hnsw, @Nonnull final Random ran
200
216
});
201
217
}
202
218
219
+ @ Test
220
+ @ Timeout (value = 150 , unit = TimeUnit .MINUTES )
221
+ public void testSIFTInsert10k () throws Exception {
222
+ final Metric metric = Metric .EUCLIDEAN_METRIC ;
223
+ final int k = 10 ;
224
+ final AtomicLong nextNodeIdAtomic = new AtomicLong (0L );
225
+
226
+ final TestOnReadListener onReadListener = new TestOnReadListener ();
227
+
228
+ final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
229
+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (metric ).setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
230
+ OnWriteListener .NOOP , onReadListener );
231
+
232
+ final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv" ;
233
+ final int dimensions = 128 ;
234
+
235
+ final AtomicReference <HalfVector > queryVectorAtomic = new AtomicReference <>();
236
+ final NavigableSet <NodeReferenceWithDistance > trueResults = new ConcurrentSkipListSet <>(
237
+ Comparator .comparing (NodeReferenceWithDistance ::getDistance ));
238
+
239
+ try (BufferedReader br = new BufferedReader (new FileReader (tsvFile ))) {
240
+ for (int i = 0 ; i < 10000 ;) {
241
+ i += basicInsertBatch (100 , nextNodeIdAtomic , onReadListener ,
242
+ tr -> {
243
+ final String line ;
244
+ try {
245
+ line = br .readLine ();
246
+ } catch (IOException e ) {
247
+ throw new RuntimeException (e );
248
+ }
249
+
250
+ final String [] values = Objects .requireNonNull (line ).split ("\t " );
251
+ Assertions .assertEquals (dimensions , values .length );
252
+ final Half [] halfs = new Half [dimensions ];
253
+
254
+ for (int c = 0 ; c < values .length ; c ++) {
255
+ final String value = values [c ];
256
+ halfs [c ] = HNSWHelpers .halfValueOf (Double .parseDouble (value ));
257
+ }
258
+ final Tuple currentPrimaryKey = createNextPrimaryKey (nextNodeIdAtomic );
259
+ final HalfVector currentVector = new HalfVector (halfs );
260
+ final HalfVector queryVector = queryVectorAtomic .get ();
261
+ if (queryVector == null ) {
262
+ queryVectorAtomic .set (currentVector );
263
+ return AsyncUtil .DONE ;
264
+ } else {
265
+ final double currentDistance =
266
+ Vector .comparativeDistance (metric , currentVector , queryVector );
267
+ if (trueResults .size () < k || trueResults .last ().getDistance () > currentDistance ) {
268
+ trueResults .add (
269
+ new NodeReferenceWithDistance (currentPrimaryKey , currentVector ,
270
+ Vector .comparativeDistance (metric , currentVector , queryVector )));
271
+ }
272
+ if (trueResults .size () > k ) {
273
+ trueResults .remove (trueResults .last ());
274
+ }
275
+ return hnsw .insert (tr , currentPrimaryKey , currentVector );
276
+ }
277
+ });
278
+ }
279
+ }
280
+
281
+ onReadListener .reset ();
282
+ final long beginTs = System .nanoTime ();
283
+ final List <? extends NodeReferenceAndNode <?>> results =
284
+ db .run (tr -> hnsw .kNearestNeighborsSearch (tr , k , 100 , queryVectorAtomic .get ()).join ());
285
+ final long endTs = System .nanoTime ();
286
+
287
+ for (NodeReferenceAndNode <?> nodeReferenceAndNode : results ) {
288
+ final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode .getNodeReferenceWithDistance ();
289
+ logger .info ("retrieved result nodeId = {} at distance= {}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
290
+ nodeReferenceWithDistance .getDistance ());
291
+ }
292
+
293
+ for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults ) {
294
+ logger .info ("true result nodeId ={} at distance={}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
295
+ nodeReferenceWithDistance .getDistance ());
296
+ }
297
+
298
+ System .out .println (onReadListener .getNodeCountByLayer ());
299
+ System .out .println (onReadListener .getBytesReadByLayer ());
300
+
301
+ logger .info ("search transaction took elapsedTime={}ms" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
302
+ }
303
+
203
304
@ Test
204
305
public void testBasicInsertAndScanLayer () throws Exception {
205
306
final Random random = new Random (0 );
@@ -224,17 +325,92 @@ public void testBasicInsertAndScanLayer() throws Exception {
224
325
}
225
326
226
327
@ Test
227
- public void testManyVectors () {
328
+ public void testManyRandomVectors () {
228
329
final Random random = new Random ();
229
330
for (long l = 0L ; l < 3000000 ; l ++) {
230
- final Vector . HalfVector randomVector = createRandomVector (random , 768 );
331
+ final HalfVector randomVector = createRandomVector (random , 768 );
231
332
final Tuple vectorTuple = StorageAdapter .tupleFromVector (randomVector );
232
333
final Vector <Half > roundTripVector = StorageAdapter .vectorFromTuple (vectorTuple );
233
334
Vector .comparativeDistance (Metric .EuclideanMetric .EUCLIDEAN_METRIC , randomVector , roundTripVector );
234
335
Assertions .assertEquals (randomVector , roundTripVector );
235
336
}
236
337
}
237
338
339
+ @ Test
340
+ @ Timeout (value = 150 , unit = TimeUnit .MINUTES )
341
+ public void testSIFTVectors () throws Exception {
342
+ final AtomicLong nextNodeIdAtomic = new AtomicLong (0L );
343
+
344
+ final TestOnReadListener onReadListener = new TestOnReadListener ();
345
+
346
+ final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
347
+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .EUCLIDEAN_METRIC ).setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
348
+ OnWriteListener .NOOP , onReadListener );
349
+
350
+
351
+ final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv" ;
352
+ final int dimensions = 128 ;
353
+ final var referenceVector = createRandomVector (new Random (0 ), dimensions );
354
+ long count = 0L ;
355
+ double mean = 0.0d ;
356
+ double mean2 = 0.0d ;
357
+
358
+ try (BufferedReader br = new BufferedReader (new FileReader (tsvFile ))) {
359
+ for (int i = 0 ; i < 100_000 ; i ++) {
360
+ final String line ;
361
+ try {
362
+ line = br .readLine ();
363
+ } catch (IOException e ) {
364
+ throw new RuntimeException (e );
365
+ }
366
+
367
+ final String [] values = Objects .requireNonNull (line ).split ("\t " );
368
+ Assertions .assertEquals (dimensions , values .length );
369
+ final Half [] halfs = new Half [dimensions ];
370
+ for (int c = 0 ; c < values .length ; c ++) {
371
+ final String value = values [c ];
372
+ halfs [c ] = HNSWHelpers .halfValueOf (Double .parseDouble (value ));
373
+ }
374
+ final HalfVector newVector = new HalfVector (halfs );
375
+ final double distance = Vector .comparativeDistance (Metric .EUCLIDEAN_METRIC , referenceVector , newVector );
376
+ count ++;
377
+ final double delta = distance - mean ;
378
+ mean += delta / count ;
379
+ final double delta2 = distance - mean ;
380
+ mean2 += delta * delta2 ;
381
+ }
382
+ }
383
+ final double sampleVariance = mean2 / (count - 1 );
384
+ final double standardDeviation = Math .sqrt (sampleVariance );
385
+ logger .info ("mean={}, sample_variance={}, stddeviation={}, cv={}" , mean , sampleVariance , standardDeviation ,
386
+ standardDeviation / mean );
387
+ }
388
+
389
+
390
+ @ ParameterizedTest
391
+ @ ValueSource (ints = {2 , 3 , 10 , 100 , 768 })
392
+ public void testManyVectorsStandardDeviation (final int dimensionality ) {
393
+ final Random random = new Random ();
394
+ final Metric metric = Metric .EuclideanMetric .EUCLIDEAN_METRIC ;
395
+ long count = 0L ;
396
+ double mean = 0.0d ;
397
+ double mean2 = 0.0d ;
398
+ for (long i = 0L ; i < 100000 ; i ++) {
399
+ final HalfVector vector1 = createRandomVector (random , dimensionality );
400
+ final HalfVector vector2 = createRandomVector (random , dimensionality );
401
+ final double distance = Vector .comparativeDistance (metric , vector1 , vector2 );
402
+ count = i + 1 ;
403
+ final double delta = distance - mean ;
404
+ mean += delta / count ;
405
+ final double delta2 = distance - mean ;
406
+ mean2 += delta * delta2 ;
407
+ }
408
+ final double sampleVariance = mean2 / (count - 1 );
409
+ final double standardDeviation = Math .sqrt (sampleVariance );
410
+ logger .info ("mean={}, sample_variance={}, stddeviation={}, cv={}" , mean , sampleVariance , standardDeviation ,
411
+ standardDeviation / mean );
412
+ }
413
+
238
414
private boolean dumpLayer (final HNSW hnsw , final int layer ) throws IOException {
239
415
final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv" ;
240
416
final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv" ;
@@ -324,13 +500,13 @@ private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic
324
500
}
325
501
326
502
@ Nonnull
327
- private Vector . HalfVector createRandomVector (@ Nonnull final Random random , final int dimensionality ) {
503
+ private HalfVector createRandomVector (@ Nonnull final Random random , final int dimensionality ) {
328
504
final Half [] components = new Half [dimensionality ];
329
505
for (int d = 0 ; d < dimensionality ; d ++) {
330
506
// don't ask
331
507
components [d ] = HNSWHelpers .halfValueOf (random .nextDouble ());
332
508
}
333
- return new Vector . HalfVector (components );
509
+ return new HalfVector (components );
334
510
}
335
511
336
512
private static class TestOnReadListener implements OnReadListener {
0 commit comments