diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index ebdc67f2..e0160f3a 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -18,12 +18,11 @@ package knn; import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.IntBuffer; +import java.io.Serializable; import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; @@ -33,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Deque; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -83,6 +83,8 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DoubleValuesSourceRescorer; +import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -93,7 +95,6 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.CheckJoinIndex; import org.apache.lucene.store.Directory; @@ -180,6 +181,8 @@ enum IndexType { private IndexType indexType; // oversampling, e.g. the multiple * k to gather before checking recall private float overSample; + // rerank using full precision vectors + private boolean rerank; private KnnGraphTester() { // set defaults @@ -203,6 +206,7 @@ private KnnGraphTester() { queryStartIndex = 0; indexType = IndexType.HNSW; overSample = 1f; + rerank = false; } private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncoding vectorEncoding, boolean noisy) throws IOException { @@ -284,6 +288,9 @@ private void run(String... args) throws Exception { throw new IllegalArgumentException("-overSample must be >= 1"); } break; + case "-rerank": + rerank = true; + break; case "-fanout": if (iarg == args.length - 1) { throw new IllegalArgumentException("-fanout requires a following number"); @@ -839,12 +846,12 @@ private void printHist(int[] hist, int max, int count, int nbuckets) { } @SuppressForbidden(reason = "Prints stuff") - private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, int[][] nn) + private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, ResultIds[][] nn) throws IOException { Result[] results = new Result[numQueryVectors]; - int[][] resultIds = new int[numQueryVectors][]; + ResultIds[][] resultIds = new ResultIds[numQueryVectors][]; long elapsedMS, totalCpuTimeMS, totalVisited = 0; - int topK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; + int annTopK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; int fanout = (overSample > 1) ? (int) (this.fanout * overSample) : this.fanout; ExecutorService executorService; if (numSearchThread > 0) { @@ -860,7 +867,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat if (targetReader instanceof VectorReaderByte b) { targetReaderByte = b; } - log("searching " + numQueryVectors + " query vectors; topK=" + topK + ", fanout=" + fanout + "\n"); + log("searching " + numQueryVectors + " query vectors; ann-topK=" + annTopK + ", fanout=" + fanout + "\n"); long startNS; try (MMapDirectory dir = new MMapDirectory(indexPath)) { dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq")); @@ -874,10 +881,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } targetReader.reset(); @@ -886,10 +893,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } ThreadDetails endThreadDetails = new ThreadDetails(); @@ -930,18 +937,14 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat executorService.shutdown(); } } + // Do we need to write nn here again? We already wrote it in getExactNN() if (outputPath != null) { - ByteBuffer tmp = - ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(outputPath)) { - for (int i = 0; i < numQueryVectors; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); - } - } + writeExactNN(nn, outputPath); } else { log("checking results\n"); - float recall = checkResults(resultIds, nn); + float recall = checkRecall(resultIds, nn); + double ndcg10 = calculateNDCG(nn, resultIds, 10); + double ndcgK = calculateNDCG(nn, resultIds, topK); totalVisited /= numQueryVectors; String quantizeDesc; if (quantize) { @@ -952,8 +955,11 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat double reindexSec = reindexTimeMsec / 1000.0; System.out.printf( Locale.ROOT, - "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", + "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%s\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", recall, + ndcg10, + ndcgK, + rerank, elapsedMS / (float) numQueryVectors, totalCpuTimeMS / (float) numQueryVectors, totalCpuTimeMS / (float) elapsedMS, @@ -999,7 +1005,7 @@ private static Result doKnnByteVectorQuery( return new Result(docs, profiledQuery.totalVectorCount(), 0); } - private static Result doKnnVectorQuery( + private Result doKnnVectorQuery( IndexSearcher searcher, String field, float[] vector, int k, int fanout, boolean prefilter, Query filter, boolean isParentJoinQuery) throws IOException { if (isParentJoinQuery) { @@ -1013,35 +1019,78 @@ private static Result doKnnVectorQuery( .add(filter, BooleanClause.Occur.FILTER) .build(); TopDocs docs = searcher.search(query, k); + if (rerank) { + FullPrecisionFloatVectorSimilarityValuesSource valuesSource = new FullPrecisionFloatVectorSimilarityValuesSource(vector, field); + DoubleValuesSourceRescorer rescorer = new DoubleValuesSourceRescorer(valuesSource) { + @Override + protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { + return valuePresent ? (float) sourceValue : firstPassScore; + } + }; + TopDocs rerankedDocs = rescorer.rescore(searcher, docs, topK); + return new Result(rerankedDocs, profiledQuery.totalVectorCount(), 0); + } return new Result(docs, profiledQuery.totalVectorCount(), 0); } record Result(TopDocs topDocs, long visitedCount, int reentryCount) { } - private float checkResults(int[][] results, int[][] nn) { + /** Holds ids and scores for corpus docs in search results */ + record ResultIds(int id, float score) implements Serializable {} + + private float checkRecall(ResultIds[][] results, ResultIds[][] expected) { int totalMatches = 0; - int totalResults = results.length * topK; - for (int i = 0; i < results.length; i++) { + int totalResults = expected.length * topK; + for (int i = 0; i < expected.length; i++) { // System.out.println("compare " + Arrays.toString(nn[i]) + " to "); // System.out.println(Arrays.toString(results[i])); - totalMatches += compareNN(nn[i], results[i]); + totalMatches += compareNN(expected[i], results[i]); } return totalMatches / (float) totalResults; } - private int compareNN(int[] expected, int[] results) { + /** + * Calculates Normalized Discounted Cumulative Gain (NDCG) at K. + * + *

We use full precision vector similarity scores for relevance. Since actual + * knn search result may hold quantized scores, we use scores for the corresponding + * document "id" from {@code ideal} search results. If a document is not present + * in ideal, it is considered irrelevant, and we assign it a score of 0f. + */ + private double calculateNDCG(ResultIds[][] ideal, ResultIds[][] actual, int k) { + double ndcg = 0; + for (int i = 0; i < ideal.length; i++) { + float[] exactResultsRelevance = new float[ideal[i].length]; + HashMap idToRelevance = new HashMap(ideal[i].length); + for (int rank = 0; rank < ideal[i].length; rank++) { + exactResultsRelevance[rank] = ideal[i][rank].score(); + idToRelevance.put(ideal[i][rank].id(), ideal[i][rank].score()); + } + float[] actualResultsRelevance = new float[actual[i].length]; + for (int rank = 0; rank < actual[i].length; rank++) { + actualResultsRelevance[rank] = idToRelevance.getOrDefault(actual[i][rank].id(), 0f); + } + double idealDCG = KnnTesterUtils.dcg(exactResultsRelevance, k); + double actualDCG = KnnTesterUtils.dcg(actualResultsRelevance, k); + ndcg += (actualDCG / idealDCG); + } + ndcg /= ideal.length; + return ndcg; + } + + private int compareNN(ResultIds[] expected, ResultIds[] results) { int matched = 0; Set expectedSet = new HashSet<>(); Set alreadySeen = new HashSet<>(); for (int i = 0; i < topK; i++) { - expectedSet.add(expected[i]); + expectedSet.add(expected[i].id); } - for (int docId : results) { - if (alreadySeen.add(docId) == false) { - throw new IllegalStateException("duplicate docId=" + docId); + for (ResultIds r : results) { + if (alreadySeen.add(r.id) == false) { + throw new IllegalStateException("duplicate docId=" + r.id); } - if (expectedSet.contains(docId)) { + if (expectedSet.contains(r.id)) { ++matched; } } @@ -1053,7 +1102,7 @@ private int compareNN(int[] expected, int[] results) { * The method runs "numQueryVectors" target queries and returns "topK" nearest neighbors * for each of them. Nearest Neighbors are computed using exact match. */ - private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException { + private ResultIds[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException { // look in working directory for cached nn file String hash = Integer.toString(Objects.hash(docPath, indexPath, queryPath, numDocs, numQueryVectors, topK, similarityFunction.ordinal(), parentJoin, queryStartIndex, prefilter ? selectivity : 1f, prefilter ? randomSeed : 0f), 36); String nnFileName = "nn-" + hash + ".bin"; @@ -1066,7 +1115,7 @@ private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int que long startNS = System.nanoTime(); // TODO: enable computing NN from high precision vectors when // checking low-precision recall - int[][] nn; + ResultIds[][] nn; if (vectorEncoding.equals(VectorEncoding.BYTE)) { nn = computeExactNNByte(queryPath, queryStartIndex); } else { @@ -1089,35 +1138,32 @@ private boolean isNewer(Path path, Path... others) throws IOException { return true; } - private int[][] readExactNN(Path nnPath) throws IOException { - int[][] result = new int[numQueryVectors][]; - try (FileChannel in = FileChannel.open(nnPath)) { - IntBuffer intBuffer = - in.map(FileChannel.MapMode.READ_ONLY, 0, numQueryVectors * topK * Integer.BYTES) - .order(ByteOrder.LITTLE_ENDIAN) - .asIntBuffer(); + private ResultIds[][] readExactNN(Path nnPath) throws IOException { + log("reading true nearest neighbors from file \"" + nnPath + "\"\n"); + ResultIds[][] nn = new ResultIds[numQueryVectors][]; + try (InputStream in = Files.newInputStream(nnPath); + ObjectInputStream ois = new ObjectInputStream(in)) { for (int i = 0; i < numQueryVectors; i++) { - result[i] = new int[topK]; - intBuffer.get(result[i]); + nn[i] = (ResultIds[]) ois.readObject(); } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); } - return result; + return nn; } - private void writeExactNN(int[][] nn, Path nnPath) throws IOException { - log("writing true nearest neighbors to cache file \"" + nnPath + "\"\n"); - ByteBuffer tmp = - ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(nnPath)) { + private void writeExactNN(ResultIds[][] nn, Path nnPath) throws IOException { + log("\nwriting true nearest neighbors to cache file \"" + nnPath + "\"\n"); + try (OutputStream fileOutputStream = Files.newOutputStream(nnPath); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream)) { for (int i = 0; i < numQueryVectors; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); + objectOutputStream.writeObject(nn[i]); } } } - private int[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { - int[][] result = new int[numQueryVectors][]; + private ResultIds[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { + ResultIds[][] result = new ResultIds[numQueryVectors][]; log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n"); List tasks = new ArrayList<>(); try (MMapDirectory dir = new MMapDirectory(indexPath)) { @@ -1143,10 +1189,10 @@ class ComputeNNByteTask implements Callable { private final int queryOrd; private final byte[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader) { + ComputeNNByteTask(int queryOrd, byte[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; @@ -1176,9 +1222,9 @@ public Void call() { } /** Brute force computation of "true" nearest neighhbors. */ - private int[][] computeExactNN(Path queryPath, int queryStartIndex) + private ResultIds[][] computeExactNN(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { - int[][] result = new int[numQueryVectors][]; + ResultIds[][] result = new ResultIds[numQueryVectors][]; log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n"); log("parentJoin = %s\n", parentJoin); try (MMapDirectory dir = new MMapDirectory(indexPath)) { @@ -1216,10 +1262,10 @@ class ComputeNNFloatTask implements Callable { private final int queryOrd; private final float[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) { + ComputeNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; @@ -1255,10 +1301,10 @@ class ComputeExactSearchNNFloatTask implements Callable { private final int queryOrd; private final float[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeExactSearchNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) { + ComputeExactSearchNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; diff --git a/src/main/knn/KnnTesterUtils.java b/src/main/knn/KnnTesterUtils.java index 15f6159b..c9093ba2 100644 --- a/src/main/knn/KnnTesterUtils.java +++ b/src/main/knn/KnnTesterUtils.java @@ -23,14 +23,16 @@ import java.io.IOException; +import static knn.KnnGraphTester.ID_FIELD; +import static knn.KnnGraphTester.ResultIds; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; public class KnnTesterUtils { /** Fetches values for the "id" field from search results */ - public static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOException { - int[] resultIds = new int[topDocs.scoreDocs.length]; + public static ResultIds[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOException { + ResultIds[] resultIds = new ResultIds[topDocs.scoreDocs.length]; int i = 0; // TODO: switch to doc values for this id field? more efficent than stored fields // TODO: or, at least load the stored documents in index (Lucene docid) order to @@ -39,8 +41,22 @@ public static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) thr // queries have run) for (ScoreDoc doc : topDocs.scoreDocs) { assert doc.doc != NO_MORE_DOCS: "illegal docid " + doc.doc + " returned from KNN search?"; - resultIds[i++] = Integer.parseInt(storedFields.document(doc.doc).get(KnnGraphTester.ID_FIELD)); + resultIds[i++] = new ResultIds(Integer.parseInt(storedFields.document(doc.doc).get(ID_FIELD)), doc.score); } return resultIds; } + + /** + * Calculates Discounted Cumulative Gain @k + * @param relevance Relevance scores sorted by rank of search results. + * @param k DCG is calculated up to this rank + */ + public static double dcg(float[] relevance, int k) { + double dcg = 0; + k = Math.min(relevance.length, k); + for (int i = 0; i < k; i++) { + dcg += relevance[i] / (Math.log(2 + i) / Math.log(2)); // rank = (i+1) + } + return dcg; + } } diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index 374f5f5a..2bd4bcbe 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -89,11 +89,15 @@ "queryStartIndex": (0,), # seek to this start vector before searching, to sample different vectors # "forceMerge": (True, False), #'niter': (10,), + # "rerank": (False, True), } OUTPUT_HEADERS = [ "recall", + "ndcg@10", + "ndcg@K", + "rerank", "latency(ms)", "netCPU", "avgCpuCount", @@ -280,6 +284,8 @@ def run_knn_benchmark(checkout, values): if "-indexType" in this_cmd and "flat" in this_cmd: skip_headers.add("maxConn") skip_headers.add("beamWidth") + if "-rerank" not in this_cmd: + skip_headers.add("rerank") print_fixed_width(all_results, skip_headers) print_chart(all_results)