|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the "Elastic License |
| 4 | + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side |
| 5 | + * Public License v 1"; you may not use this file except in compliance with, at |
| 6 | + * your election, the "Elastic License 2.0", the "GNU Affero General Public |
| 7 | + * License v3.0 only", or the "Server Side Public License, v 1". |
| 8 | + */ |
| 9 | + |
| 10 | +package org.elasticsearch.search.vectors; |
| 11 | + |
| 12 | + |
| 13 | +import org.apache.lucene.document.Document; |
| 14 | +import org.apache.lucene.document.KnnFloatVectorField; |
| 15 | +import org.apache.lucene.index.DirectoryReader; |
| 16 | +import org.apache.lucene.index.FloatVectorValues; |
| 17 | +import org.apache.lucene.index.IndexReader; |
| 18 | +import org.apache.lucene.index.IndexWriter; |
| 19 | +import org.apache.lucene.index.KnnVectorValues; |
| 20 | +import org.apache.lucene.index.LeafReaderContext; |
| 21 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 22 | +import org.apache.lucene.search.IndexSearcher; |
| 23 | +import org.apache.lucene.search.MatchAllDocsQuery; |
| 24 | +import org.apache.lucene.search.TopDocs; |
| 25 | +import org.apache.lucene.store.Directory; |
| 26 | +import org.elasticsearch.test.ESTestCase; |
| 27 | + |
| 28 | +import java.util.ArrayList; |
| 29 | +import java.util.Arrays; |
| 30 | +import java.util.Collection; |
| 31 | +import java.util.Map; |
| 32 | +import java.util.PriorityQueue; |
| 33 | +import java.util.stream.Collectors; |
| 34 | + |
| 35 | +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; |
| 36 | +import static org.hamcrest.Matchers.equalTo; |
| 37 | + |
| 38 | +public class RescoreKnnVectorQueryTests extends ESTestCase { |
| 39 | + |
| 40 | + public static final String FIELD_NAME = "float_vector"; |
| 41 | + |
| 42 | + public void testRescoresTopK() throws Exception { |
| 43 | + int numDocs = randomIntBetween(10, 100); |
| 44 | + testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1)); |
| 45 | + } |
| 46 | + |
| 47 | + public void testRescoresNoKParameter() throws Exception { |
| 48 | + testRescoreDocs(randomIntBetween(10, 100), null); |
| 49 | + } |
| 50 | + |
| 51 | + private void testRescoreDocs(int numDocs, Integer k) throws Exception { |
| 52 | + int numDims = randomIntBetween(5, 100); |
| 53 | + |
| 54 | + if (k == null) { |
| 55 | + k = numDocs; |
| 56 | + } |
| 57 | + |
| 58 | + try (Directory d = newDirectory()) { |
| 59 | + try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { |
| 60 | + for (int i = 0; i < numDocs; i++) { |
| 61 | + Document document = new Document(); |
| 62 | + float[] vector = randomVector(numDims); |
| 63 | + KnnFloatVectorField vectorField = new KnnFloatVectorField( |
| 64 | + FIELD_NAME, vector); |
| 65 | + document.add(vectorField); |
| 66 | + w.addDocument(document); |
| 67 | + } |
| 68 | + w.commit(); |
| 69 | + w.forceMerge(1); |
| 70 | + } |
| 71 | + |
| 72 | + try (IndexReader reader = DirectoryReader.open(d)) { |
| 73 | + float[] queryVector = randomVector(numDims); |
| 74 | + |
| 75 | + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( |
| 76 | + FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, new MatchAllDocsQuery()); |
| 77 | + |
| 78 | + IndexSearcher searcher = newSearcher(reader, true, false); |
| 79 | + TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); |
| 80 | + Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs).collect(Collectors.toMap( |
| 81 | + scoreDoc -> scoreDoc.doc, |
| 82 | + scoreDoc -> scoreDoc.score) |
| 83 | + ); |
| 84 | + |
| 85 | + assertThat(rescoredDocs.size(), equalTo(k)); |
| 86 | + |
| 87 | + Collection<Float> rescoredScores = new ArrayList<>(rescoredDocs.values()); |
| 88 | + PriorityQueue<Float> topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); |
| 89 | + |
| 90 | + for (LeafReaderContext leafReaderContext : reader.leaves()) { |
| 91 | + FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); |
| 92 | + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); |
| 93 | + while (iterator.nextDoc() != NO_MORE_DOCS) { |
| 94 | + float[] vector = floatVectorValues.vectorValue(iterator.index()); |
| 95 | + float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector); |
| 96 | + topK.add(score); |
| 97 | + int docId = iterator.docID(); |
| 98 | + if (rescoredDocs.containsKey(docId)) { |
| 99 | + assertThat(rescoredDocs.get(docId), equalTo(score)); |
| 100 | + rescoredDocs.remove(docId); |
| 101 | + } |
| 102 | + } |
| 103 | + } |
| 104 | + |
| 105 | + assertThat(rescoredDocs.size(), equalTo(0)); |
| 106 | + |
| 107 | + // Check top scoring docs are contained in rescored docs |
| 108 | + for (int i = 0; i < k; i++) { |
| 109 | + Float topScore = topK.poll(); |
| 110 | + if (rescoredScores.contains(topScore) == false ) { |
| 111 | + fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + private static float[] randomVector(int numDims) { |
| 119 | + float[] vector = new float[numDims]; |
| 120 | + for (int j = 0; j < numDims; j++) { |
| 121 | + vector[j] = randomFloatBetween(0, 1, true); |
| 122 | + } |
| 123 | + return vector; |
| 124 | + } |
| 125 | + |
| 126 | +} |
0 commit comments