Skip to content

Commit 90ce59e

Browse files
Small edits for KnnGraphTester (#575)
1. Correct the remaining size for input files larger than Integer.MAX_VALUE: currently, every iteration tries to map the next blockSize bytes even if fewer than blockSize bytes are left in the file. 2. Fix the java.lang.ClassCastException thrown when retrieving KnnGraphValues for stats printing. 3. Add an option for the Euclidean metric.
1 parent f8c7619 commit 90ce59e

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636
import java.util.Locale;
3737
import java.util.Set;
3838
import org.apache.lucene.codecs.KnnVectorsFormat;
39+
import org.apache.lucene.codecs.KnnVectorsReader;
3940
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
4041
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
4142
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
43+
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
4244
import org.apache.lucene.document.Document;
4345
import org.apache.lucene.document.FieldType;
4446
import org.apache.lucene.document.KnnVectorField;
@@ -74,8 +76,6 @@ public class KnnGraphTester {
7476

7577
private static final String KNN_FIELD = "knn";
7678
private static final String ID_FIELD = "id";
77-
private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
78-
VectorSimilarityFunction.DOT_PRODUCT;
7979

8080
private int numDocs;
8181
private int dim;
@@ -90,6 +90,7 @@ public class KnnGraphTester {
9090
private int reindexTimeMsec;
9191
private int beamWidth;
9292
private int maxConn;
93+
private VectorSimilarityFunction similarityFunction;
9394

9495
@SuppressForbidden(reason = "uses Random()")
9596
private KnnGraphTester() {
@@ -100,6 +101,7 @@ private KnnGraphTester() {
100101
topK = 100;
101102
warmCount = 1000;
102103
fanout = topK;
104+
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
103105
}
104106

105107
public static void main(String... args) throws Exception {
@@ -183,6 +185,14 @@ private void run(String... args) throws Exception {
183185
case "-docs":
184186
docVectorsPath = Paths.get(args[++iarg]);
185187
break;
188+
case "-metric":
189+
String metric = args[++iarg];
190+
if (metric.equals("euclidean")) {
191+
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
192+
} else if (metric.equals("angular") == false) {
193+
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
194+
}
195+
break;
186196
case "-forceMerge":
187197
forceMerge = true;
188198
break;
@@ -237,12 +247,13 @@ private String formatIndexPath(Path docsPath) {
237247
private void printFanoutHist(Path indexPath) throws IOException {
238248
try (Directory dir = FSDirectory.open(indexPath);
239249
DirectoryReader reader = DirectoryReader.open(dir)) {
240-
// int[] globalHist = new int[reader.maxDoc()];
241250
for (LeafReaderContext context : reader.leaves()) {
242251
LeafReader leafReader = context.reader();
252+
KnnVectorsReader vectorsReader =
253+
((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
254+
.getFieldReader(KNN_FIELD);
243255
KnnGraphValues knnValues =
244-
((Lucene90HnswVectorsReader) ((CodecReader) leafReader).getVectorReader())
245-
.getGraphValues(KNN_FIELD);
256+
((Lucene90HnswVectorsReader) vectorsReader).getGraphValues(KNN_FIELD);
246257
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
247258
printGraphFanout(knnValues, leafReader.maxDoc());
248259
}
@@ -253,7 +264,7 @@ private void dumpGraph(Path docsPath) throws IOException {
253264
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
254265
RandomAccessVectorValues values = vectors.randomAccess();
255266
HnswGraphBuilder builder =
256-
new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
267+
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
257268
// start at node 1
258269
for (int i = 1; i < numDocs; i++) {
259270
builder.addGraphNode(values.vectorValue(i));
@@ -533,25 +544,21 @@ private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
533544
for (int i = 0; i < numIters; i++) {
534545
queries.get(query);
535546
long totalBytes = (long) numDocs * dim * Float.BYTES;
536-
int
537-
blockSize =
538-
(int)
539-
Math.min(
540-
totalBytes,
541-
(Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES)),
542-
offset = 0;
547+
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
548+
int offset = 0;
543549
int j = 0;
544550
// System.out.println("totalBytes=" + totalBytes);
545551
while (j < numDocs) {
552+
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
546553
FloatBuffer vectors =
547554
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
548555
.order(ByteOrder.LITTLE_ENDIAN)
549556
.asFloatBuffer();
550557
offset += blockSize;
551-
NeighborQueue queue = new NeighborQueue(topK, SIMILARITY_FUNCTION.reversed);
558+
NeighborQueue queue = new NeighborQueue(topK, similarityFunction.reversed);
552559
for (; j < numDocs && vectors.hasRemaining(); j++) {
553560
vectors.get(vector);
554-
float d = SIMILARITY_FUNCTION.compare(query, vector);
561+
float d = similarityFunction.compare(query, vector);
555562
queue.insertWithOverflow(j, d);
556563
}
557564
result[i] = new int[topK];
@@ -583,22 +590,22 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
583590
iwc.setRAMBufferSizeMB(1994d);
584591
// iwc.setMaxBufferedDocs(10000);
585592

586-
FieldType fieldType = KnnVectorField.createFieldType(dim, VectorSimilarityFunction.DOT_PRODUCT);
593+
FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
587594
if (quiet == false) {
588595
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
589596
System.out.println("creating index in " + indexPath);
590597
}
591598
long start = System.nanoTime();
592599
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
600+
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
601+
593602
try (FSDirectory dir = FSDirectory.open(indexPath);
594603
IndexWriter iw = new IndexWriter(dir, iwc)) {
595-
int blockSize =
596-
(int)
597-
Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
598604
float[] vector = new float[dim];
599605
try (FileChannel in = FileChannel.open(docsPath)) {
600606
int i = 0;
601607
while (i < numDocs) {
608+
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
602609
FloatBuffer vectors =
603610
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
604611
.order(ByteOrder.LITTLE_ENDIAN)

0 commit comments

Comments
 (0)