|
9 | 9 |
|
10 | 10 | package org.elasticsearch.search.vectors; |
11 | 11 |
|
| 12 | +import org.apache.lucene.index.IndexReader; |
12 | 13 | import org.apache.lucene.index.LeafReaderContext; |
13 | 14 | import org.apache.lucene.search.DocIdSetIterator; |
14 | 15 | import org.apache.lucene.search.Explanation; |
|
37 | 38 | public class KnnScoreDocQuery extends Query { |
38 | 39 | private final int[] docs; |
39 | 40 | private final float[] scores; |
| 41 | + |
| 42 | + // the indexes in docs and scores corresponding to the first matching document in each segment. |
| 43 | + // If a segment has no matching documents, it should be assigned the index of the next segment that does. |
| 44 | + // There should be a final entry that is always docs.length-1. |
40 | 45 | private final int[] segmentStarts; |
| 46 | + // an object identifying the reader context that was used to build this query |
| 47 | + |
41 | 48 | private final Object contextIdentity; |
42 | 49 |
|
43 | 50 | /** |
44 | 51 | * Creates a query. |
45 | 52 | * |
46 | 53 | * @param docs the global doc IDs of documents that match, in ascending order |
47 | 54 | * @param scores the scores of the matching documents |
48 | | - * @param segmentStarts the indexes in docs and scores corresponding to the first matching |
49 | | - * document in each segment. If a segment has no matching documents, it should be assigned |
50 | | - * the index of the next segment that does. There should be a final entry that is always |
51 | | - * docs.length-1. |
52 | | - * @param contextIdentity an object identifying the reader context that was used to build this |
53 | | - * query |
| 55 | + * @param reader IndexReader |
54 | 56 | */ |
55 | | - KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { |
| 57 | + KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { |
56 | 58 | this.docs = docs; |
57 | 59 | this.scores = scores; |
58 | | - this.segmentStarts = segmentStarts; |
59 | | - this.contextIdentity = contextIdentity; |
| 60 | + this.segmentStarts = findSegmentStarts(reader, docs); |
| 61 | + this.contextIdentity = reader.getContext().id(); |
| 62 | + } |
| 63 | + |
| 64 | + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { |
| 65 | + int[] starts = new int[reader.leaves().size() + 1]; |
| 66 | + starts[starts.length - 1] = docs.length; |
| 67 | + if (starts.length == 2) { |
| 68 | + return starts; |
| 69 | + } |
| 70 | + int resultIndex = 0; |
| 71 | + for (int i = 1; i < starts.length - 1; i++) { |
| 72 | + int upper = reader.leaves().get(i).docBase; |
| 73 | + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); |
| 74 | + if (resultIndex < 0) { |
| 75 | + resultIndex = -1 - resultIndex; |
| 76 | + } |
| 77 | + starts[i] = resultIndex; |
| 78 | + } |
| 79 | + return starts; |
60 | 80 | } |
61 | 81 |
|
62 | 82 | @Override |
|
0 commit comments