diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index ebdc67f2..e0160f3a 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -18,12 +18,11 @@ package knn; import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.IntBuffer; +import java.io.Serializable; import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; @@ -33,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Deque; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -83,6 +83,8 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DoubleValuesSourceRescorer; +import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -93,7 +95,6 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.CheckJoinIndex; import org.apache.lucene.store.Directory; @@ -180,6 +181,8 @@ enum IndexType { private IndexType indexType; // oversampling, e.g. 
the multiple * k to gather before checking recall private float overSample; + // rerank using full precision vectors + private boolean rerank; private KnnGraphTester() { // set defaults @@ -203,6 +206,7 @@ private KnnGraphTester() { queryStartIndex = 0; indexType = IndexType.HNSW; overSample = 1f; + rerank = false; } private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncoding vectorEncoding, boolean noisy) throws IOException { @@ -284,6 +288,9 @@ private void run(String... args) throws Exception { throw new IllegalArgumentException("-overSample must be >= 1"); } break; + case "-rerank": + rerank = true; + break; case "-fanout": if (iarg == args.length - 1) { throw new IllegalArgumentException("-fanout requires a following number"); @@ -839,12 +846,12 @@ private void printHist(int[] hist, int max, int count, int nbuckets) { } @SuppressForbidden(reason = "Prints stuff") - private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, int[][] nn) + private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, ResultIds[][] nn) throws IOException { Result[] results = new Result[numQueryVectors]; - int[][] resultIds = new int[numQueryVectors][]; + ResultIds[][] resultIds = new ResultIds[numQueryVectors][]; long elapsedMS, totalCpuTimeMS, totalVisited = 0; - int topK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; + int annTopK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; int fanout = (overSample > 1) ? 
(int) (this.fanout * overSample) : this.fanout; ExecutorService executorService; if (numSearchThread > 0) { @@ -860,7 +867,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat if (targetReader instanceof VectorReaderByte b) { targetReaderByte = b; } - log("searching " + numQueryVectors + " query vectors; topK=" + topK + ", fanout=" + fanout + "\n"); + log("searching " + numQueryVectors + " query vectors; ann-topK=" + annTopK + ", fanout=" + fanout + "\n"); long startNS; try (MMapDirectory dir = new MMapDirectory(indexPath)) { dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq")); @@ -874,10 +881,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } targetReader.reset(); @@ -886,10 +893,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + results[i] = doKnnVectorQuery(searcher, 
KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } ThreadDetails endThreadDetails = new ThreadDetails(); @@ -930,18 +937,14 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat executorService.shutdown(); } } + // Do we need to write nn here again? We already wrote it in getExactNN() if (outputPath != null) { - ByteBuffer tmp = - ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(outputPath)) { - for (int i = 0; i < numQueryVectors; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); - } - } + writeExactNN(nn, outputPath); } else { log("checking results\n"); - float recall = checkResults(resultIds, nn); + float recall = checkRecall(resultIds, nn); + double ndcg10 = calculateNDCG(nn, resultIds, 10); + double ndcgK = calculateNDCG(nn, resultIds, topK); totalVisited /= numQueryVectors; String quantizeDesc; if (quantize) { @@ -952,8 +955,11 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat double reindexSec = reindexTimeMsec / 1000.0; System.out.printf( Locale.ROOT, - "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", + "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%s\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", recall, + ndcg10, + ndcgK, + rerank, elapsedMS / (float) numQueryVectors, totalCpuTimeMS / (float) numQueryVectors, totalCpuTimeMS / (float) elapsedMS, @@ -999,7 +1005,7 @@ private static Result doKnnByteVectorQuery( return new Result(docs, profiledQuery.totalVectorCount(), 0); } - private static Result doKnnVectorQuery( + private Result doKnnVectorQuery( IndexSearcher searcher, String field, float[] vector, int k, int fanout, boolean prefilter, Query filter, boolean isParentJoinQuery) throws IOException { if (isParentJoinQuery) { @@ 
-1013,35 +1019,78 @@ private static Result doKnnVectorQuery( .add(filter, BooleanClause.Occur.FILTER) .build(); TopDocs docs = searcher.search(query, k); + if (rerank) { + FullPrecisionFloatVectorSimilarityValuesSource valuesSource = new FullPrecisionFloatVectorSimilarityValuesSource(vector, field); + DoubleValuesSourceRescorer rescorer = new DoubleValuesSourceRescorer(valuesSource) { + @Override + protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { + return valuePresent ? (float) sourceValue : firstPassScore; + } + }; + TopDocs rerankedDocs = rescorer.rescore(searcher, docs, topK); + return new Result(rerankedDocs, profiledQuery.totalVectorCount(), 0); + } return new Result(docs, profiledQuery.totalVectorCount(), 0); } record Result(TopDocs topDocs, long visitedCount, int reentryCount) { } - private float checkResults(int[][] results, int[][] nn) { + /** Holds ids and scores for corpus docs in search results */ + record ResultIds(int id, float score) implements Serializable {} + + private float checkRecall(ResultIds[][] results, ResultIds[][] expected) { int totalMatches = 0; - int totalResults = results.length * topK; - for (int i = 0; i < results.length; i++) { + int totalResults = expected.length * topK; + for (int i = 0; i < expected.length; i++) { // System.out.println("compare " + Arrays.toString(nn[i]) + " to "); // System.out.println(Arrays.toString(results[i])); - totalMatches += compareNN(nn[i], results[i]); + totalMatches += compareNN(expected[i], results[i]); } return totalMatches / (float) totalResults; } - private int compareNN(int[] expected, int[] results) { + /** + * Calculates Normalized Discounted Cumulative Gain (NDCG) at K. + * + *
We use full precision vector similarity scores for relevance. Since actual
+ * knn search results may hold quantized scores, we use scores for the corresponding
+ * document "id" from {@code ideal} search results. If a document is not present
+ * in ideal, it is considered irrelevant, and we assign it a score of 0f.
+ */
+ private double calculateNDCG(ResultIds[][] ideal, ResultIds[][] actual, int k) {
+ double ndcg = 0;
+ for (int i = 0; i < ideal.length; i++) {
+ float[] exactResultsRelevance = new float[ideal[i].length];
+ HashMap