36
36
import java .util .Locale ;
37
37
import java .util .Set ;
38
38
import org .apache .lucene .codecs .KnnVectorsFormat ;
39
+ import org .apache .lucene .codecs .KnnVectorsReader ;
39
40
import org .apache .lucene .codecs .lucene90 .Lucene90Codec ;
40
41
import org .apache .lucene .codecs .lucene90 .Lucene90HnswVectorsFormat ;
41
42
import org .apache .lucene .codecs .lucene90 .Lucene90HnswVectorsReader ;
43
+ import org .apache .lucene .codecs .perfield .PerFieldKnnVectorsFormat ;
42
44
import org .apache .lucene .document .Document ;
43
45
import org .apache .lucene .document .FieldType ;
44
46
import org .apache .lucene .document .KnnVectorField ;
@@ -74,8 +76,6 @@ public class KnnGraphTester {
74
76
75
77
private static final String KNN_FIELD = "knn" ;
76
78
private static final String ID_FIELD = "id" ;
77
- private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
78
- VectorSimilarityFunction .DOT_PRODUCT ;
79
79
80
80
private int numDocs ;
81
81
private int dim ;
@@ -90,6 +90,7 @@ public class KnnGraphTester {
90
90
private int reindexTimeMsec ;
91
91
private int beamWidth ;
92
92
private int maxConn ;
93
+ private VectorSimilarityFunction similarityFunction ;
93
94
94
95
@ SuppressForbidden (reason = "uses Random()" )
95
96
private KnnGraphTester () {
@@ -100,6 +101,7 @@ private KnnGraphTester() {
100
101
topK = 100 ;
101
102
warmCount = 1000 ;
102
103
fanout = topK ;
104
+ similarityFunction = VectorSimilarityFunction .DOT_PRODUCT ;
103
105
}
104
106
105
107
public static void main (String ... args ) throws Exception {
@@ -183,6 +185,14 @@ private void run(String... args) throws Exception {
183
185
case "-docs" :
184
186
docVectorsPath = Paths .get (args [++iarg ]);
185
187
break ;
188
+ case "-metric" :
189
+ String metric = args [++iarg ];
190
+ if (metric .equals ("euclidean" )) {
191
+ similarityFunction = VectorSimilarityFunction .EUCLIDEAN ;
192
+ } else if (metric .equals ("angular" ) == false ) {
193
+ throw new IllegalArgumentException ("-metric can be 'angular' or 'euclidean' only" );
194
+ }
195
+ break ;
186
196
case "-forceMerge" :
187
197
forceMerge = true ;
188
198
break ;
@@ -237,12 +247,13 @@ private String formatIndexPath(Path docsPath) {
237
247
private void printFanoutHist (Path indexPath ) throws IOException {
238
248
try (Directory dir = FSDirectory .open (indexPath );
239
249
DirectoryReader reader = DirectoryReader .open (dir )) {
240
- // int[] globalHist = new int[reader.maxDoc()];
241
250
for (LeafReaderContext context : reader .leaves ()) {
242
251
LeafReader leafReader = context .reader ();
252
+ KnnVectorsReader vectorsReader =
253
+ ((PerFieldKnnVectorsFormat .FieldsReader ) ((CodecReader ) leafReader ).getVectorReader ())
254
+ .getFieldReader (KNN_FIELD );
243
255
KnnGraphValues knnValues =
244
- ((Lucene90HnswVectorsReader ) ((CodecReader ) leafReader ).getVectorReader ())
245
- .getGraphValues (KNN_FIELD );
256
+ ((Lucene90HnswVectorsReader ) vectorsReader ).getGraphValues (KNN_FIELD );
246
257
System .out .printf ("Leaf %d has %d documents\n " , context .ord , leafReader .maxDoc ());
247
258
printGraphFanout (knnValues , leafReader .maxDoc ());
248
259
}
@@ -253,7 +264,7 @@ private void dumpGraph(Path docsPath) throws IOException {
253
264
try (BinaryFileVectors vectors = new BinaryFileVectors (docsPath )) {
254
265
RandomAccessVectorValues values = vectors .randomAccess ();
255
266
HnswGraphBuilder builder =
256
- new HnswGraphBuilder (vectors , SIMILARITY_FUNCTION , maxConn , beamWidth , 0 );
267
+ new HnswGraphBuilder (vectors , similarityFunction , maxConn , beamWidth , 0 );
257
268
// start at node 1
258
269
for (int i = 1 ; i < numDocs ; i ++) {
259
270
builder .addGraphNode (values .vectorValue (i ));
@@ -533,25 +544,21 @@ private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
533
544
for (int i = 0 ; i < numIters ; i ++) {
534
545
queries .get (query );
535
546
long totalBytes = (long ) numDocs * dim * Float .BYTES ;
536
- int
537
- blockSize =
538
- (int )
539
- Math .min (
540
- totalBytes ,
541
- (Integer .MAX_VALUE / (dim * Float .BYTES )) * (dim * Float .BYTES )),
542
- offset = 0 ;
547
+ final int maxBlockSize = (Integer .MAX_VALUE / (dim * Float .BYTES )) * (dim * Float .BYTES );
548
+ // long, not int: this is the byte position handed to in.map(); with blocks of up to
+ // maxBlockSize bytes, "offset += blockSize" exceeds Integer.MAX_VALUE on the second block
+ long offset = 0 ;
543
549
int j = 0 ;
544
550
// System.out.println("totalBytes=" + totalBytes);
545
551
while (j < numDocs ) {
552
+ int blockSize = (int ) Math .min (totalBytes - offset , maxBlockSize );
546
553
FloatBuffer vectors =
547
554
in .map (FileChannel .MapMode .READ_ONLY , offset , blockSize )
548
555
.order (ByteOrder .LITTLE_ENDIAN )
549
556
.asFloatBuffer ();
550
557
offset += blockSize ;
551
- NeighborQueue queue = new NeighborQueue (topK , SIMILARITY_FUNCTION .reversed );
558
+ NeighborQueue queue = new NeighborQueue (topK , similarityFunction .reversed );
552
559
for (; j < numDocs && vectors .hasRemaining (); j ++) {
553
560
vectors .get (vector );
554
- float d = SIMILARITY_FUNCTION .compare (query , vector );
561
+ float d = similarityFunction .compare (query , vector );
555
562
queue .insertWithOverflow (j , d );
556
563
}
557
564
result [i ] = new int [topK ];
@@ -583,22 +590,22 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
583
590
iwc .setRAMBufferSizeMB (1994d );
584
591
// iwc.setMaxBufferedDocs(10000);
585
592
586
- FieldType fieldType = KnnVectorField .createFieldType (dim , VectorSimilarityFunction . DOT_PRODUCT );
593
+ FieldType fieldType = KnnVectorField .createFieldType (dim , similarityFunction );
587
594
if (quiet == false ) {
588
595
iwc .setInfoStream (new PrintStreamInfoStream (System .out ));
589
596
System .out .println ("creating index in " + indexPath );
590
597
}
591
598
long start = System .nanoTime ();
592
599
long totalBytes = (long ) numDocs * dim * Float .BYTES , offset = 0 ;
600
+ final int maxBlockSize = (Integer .MAX_VALUE / (dim * Float .BYTES )) * (dim * Float .BYTES );
601
+
593
602
try (FSDirectory dir = FSDirectory .open (indexPath );
594
603
IndexWriter iw = new IndexWriter (dir , iwc )) {
595
- int blockSize =
596
- (int )
597
- Math .min (totalBytes , (Integer .MAX_VALUE / (dim * Float .BYTES )) * (dim * Float .BYTES ));
598
604
float [] vector = new float [dim ];
599
605
try (FileChannel in = FileChannel .open (docsPath )) {
600
606
int i = 0 ;
601
607
while (i < numDocs ) {
608
+ int blockSize = (int ) Math .min (totalBytes - offset , maxBlockSize );
602
609
FloatBuffer vectors =
603
610
in .map (FileChannel .MapMode .READ_ONLY , offset , blockSize )
604
611
.order (ByteOrder .LITTLE_ENDIAN )
0 commit comments