|
22 | 22 | import org.apache.lucene.codecs.hnsw.FlatVectorsReader; |
23 | 23 | import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; |
24 | 24 | import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; |
| 25 | +import org.apache.lucene.codecs.lucene95.HasIndexSlice; |
25 | 26 | import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader; |
26 | 27 | import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter; |
| 28 | +import org.apache.lucene.index.ByteVectorValues; |
| 29 | +import org.apache.lucene.index.FieldInfo; |
| 30 | +import org.apache.lucene.index.FloatVectorValues; |
| 31 | +import org.apache.lucene.index.KnnVectorValues; |
27 | 32 | import org.apache.lucene.index.SegmentReadState; |
28 | 33 | import org.apache.lucene.index.SegmentWriteState; |
| 34 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 35 | +import org.apache.lucene.search.ConjunctionUtils; |
| 36 | +import org.apache.lucene.search.DocAndFloatFeatureBuffer; |
| 37 | +import org.apache.lucene.search.DocIdSetIterator; |
| 38 | +import org.apache.lucene.search.VectorScorer; |
29 | 39 | import org.apache.lucene.store.FlushInfo; |
30 | 40 | import org.apache.lucene.store.IOContext; |
| 41 | +import org.apache.lucene.store.IndexInput; |
31 | 42 | import org.apache.lucene.store.MergeInfo; |
| 43 | +import org.apache.lucene.util.Bits; |
| 44 | +import org.apache.lucene.util.hnsw.RandomVectorScorer; |
32 | 45 | import org.elasticsearch.common.util.set.Sets; |
33 | 46 | import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; |
| 47 | +import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues; |
| 48 | +import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues; |
34 | 49 | import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; |
35 | 50 | import org.elasticsearch.index.store.FsDirectoryFactory; |
36 | 51 |
|
37 | 52 | import java.io.IOException; |
| 53 | +import java.util.List; |
38 | 54 | import java.util.Set; |
39 | 55 |
|
40 | 56 | /** |
@@ -86,8 +102,13 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException |
86 | 102 | ); |
87 | 103 | // Use mmap for merges and direct I/O for searches. |
88 | 104 | // TODO: Open the mmap file with sequential access instead of random (current behavior). |
| 105 | + // TODO: maybe we should force completely RANDOM access always for inner reader formats (outside of merges)? |
89 | 106 | return new MergeReaderWrapper( |
90 | | - new Lucene99FlatVectorsReader(directIOState, vectorsScorer), |
| 107 | + new Lucene99FlatBulkScoringVectorsReader( |
| 108 | + directIOState, |
| 109 | + new Lucene99FlatVectorsReader(directIOState, vectorsScorer), |
| 110 | + vectorsScorer |
| 111 | + ), |
91 | 112 | new Lucene99FlatVectorsReader(state, vectorsScorer) |
92 | 113 | ); |
93 | 114 | } else { |
@@ -129,4 +150,204 @@ public IOContext withHints(FileOpenHint... hints) { |
129 | 150 | return new DirectIOContext(Set.of(hints)); |
130 | 151 | } |
131 | 152 | } |
| 153 | + |
| 154 | + static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader { |
| 155 | + private final Lucene99FlatVectorsReader inner; |
| 156 | + private final SegmentReadState state; |
| 157 | + |
| 158 | + Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) { |
| 159 | + super(scorer); |
| 160 | + this.inner = inner; |
| 161 | + this.state = state; |
| 162 | + } |
| 163 | + |
| 164 | + @Override |
| 165 | + public void close() throws IOException { |
| 166 | + inner.close(); |
| 167 | + } |
| 168 | + |
| 169 | + @Override |
| 170 | + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { |
| 171 | + return inner.getRandomVectorScorer(field, target); |
| 172 | + } |
| 173 | + |
| 174 | + @Override |
| 175 | + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { |
| 176 | + return inner.getRandomVectorScorer(field, target); |
| 177 | + } |
| 178 | + |
| 179 | + @Override |
| 180 | + public void checkIntegrity() throws IOException { |
| 181 | + inner.checkIntegrity(); |
| 182 | + } |
| 183 | + |
| 184 | + @Override |
| 185 | + public FloatVectorValues getFloatVectorValues(String field) throws IOException { |
| 186 | + FloatVectorValues vectorValues = inner.getFloatVectorValues(field); |
| 187 | + if (vectorValues == null || vectorValues.size() == 0) { |
| 188 | + return null; |
| 189 | + } |
| 190 | + FieldInfo info = state.fieldInfos.fieldInfo(field); |
| 191 | + return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer); |
| 192 | + } |
| 193 | + |
| 194 | + @Override |
| 195 | + public ByteVectorValues getByteVectorValues(String field) throws IOException { |
| 196 | + return inner.getByteVectorValues(field); |
| 197 | + } |
| 198 | + |
| 199 | + @Override |
| 200 | + public long ramBytesUsed() { |
| 201 | + return inner.ramBytesUsed(); |
| 202 | + } |
| 203 | + |
| 204 | + static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues { |
| 205 | + VectorSimilarityFunction similarityFunction; |
| 206 | + FloatVectorValues inner; |
| 207 | + IndexInput inputSlice; |
| 208 | + FlatVectorsScorer scorer; |
| 209 | + |
| 210 | + RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) { |
| 211 | + this.inner = inner; |
| 212 | + if (inner instanceof HasIndexSlice slice) { |
| 213 | + this.inputSlice = slice.getSlice(); |
| 214 | + } else { |
| 215 | + this.inputSlice = null; |
| 216 | + } |
| 217 | + this.similarityFunction = similarityFunction; |
| 218 | + this.scorer = scorer; |
| 219 | + } |
| 220 | + |
| 221 | + @Override |
| 222 | + public float[] vectorValue(int ord) throws IOException { |
| 223 | + return inner.vectorValue(ord); |
| 224 | + } |
| 225 | + |
| 226 | + @Override |
| 227 | + public int dimension() { |
| 228 | + return inner.dimension(); |
| 229 | + } |
| 230 | + |
| 231 | + @Override |
| 232 | + public int size() { |
| 233 | + return inner.size(); |
| 234 | + } |
| 235 | + |
| 236 | + @Override |
| 237 | + public RescorerOffHeapVectorValues copy() throws IOException { |
| 238 | + return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer); |
| 239 | + } |
| 240 | + |
| 241 | + @Override |
| 242 | + public BulkVectorScorer bulkRescorer(float[] target) throws IOException { |
| 243 | + return bulkScorer(target); |
| 244 | + } |
| 245 | + |
| 246 | + @Override |
| 247 | + public BulkVectorScorer bulkScorer(float[] target) throws IOException { |
| 248 | + DocIndexIterator indexIterator = inner.iterator(); |
| 249 | + RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target); |
| 250 | + return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES); |
| 251 | + } |
| 252 | + |
| 253 | + @Override |
| 254 | + public VectorScorer scorer(float[] target) throws IOException { |
| 255 | + return inner.scorer(target); |
| 256 | + } |
| 257 | + } |
| 258 | + |
| 259 | + private record PreFetchingFloatBulkScorer( |
| 260 | + RandomVectorScorer inner, |
| 261 | + KnnVectorValues.DocIndexIterator indexIterator, |
| 262 | + IndexInput inputSlice, |
| 263 | + int byteSize |
| 264 | + ) implements BulkScorableVectorValues.BulkVectorScorer { |
| 265 | + |
| 266 | + @Override |
| 267 | + public float score() throws IOException { |
| 268 | + return inner.score(indexIterator.index()); |
| 269 | + } |
| 270 | + |
| 271 | + @Override |
| 272 | + public DocIdSetIterator iterator() { |
| 273 | + return indexIterator; |
| 274 | + } |
| 275 | + |
| 276 | + @Override |
| 277 | + public Bulk bulk(DocIdSetIterator matchingDocs) throws IOException { |
| 278 | + DocIdSetIterator conjunctionScorer = matchingDocs == null |
| 279 | + ? indexIterator |
| 280 | + : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator)); |
| 281 | + if (conjunctionScorer.docID() == -1) { |
| 282 | + conjunctionScorer.nextDoc(); |
| 283 | + } |
| 284 | + return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer); |
| 285 | + } |
| 286 | + } |
| 287 | + |
| 288 | + private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.Bulk { |
| 289 | + private final KnnVectorValues.DocIndexIterator indexIterator; |
| 290 | + private final DocIdSetIterator matchingDocs; |
| 291 | + private final RandomVectorScorer inner; |
| 292 | + private final int bulkSize; |
| 293 | + private final IndexInput inputSlice; |
| 294 | + private final int byteSize; |
| 295 | + private final int[] docBuffer; |
| 296 | + private final float[] scoreBuffer; |
| 297 | + |
| 298 | + FloatBulkScorer( |
| 299 | + RandomVectorScorer fvv, |
| 300 | + IndexInput inputSlice, |
| 301 | + int byteSize, |
| 302 | + int bulkSize, |
| 303 | + KnnVectorValues.DocIndexIterator iterator, |
| 304 | + DocIdSetIterator matchingDocs |
| 305 | + ) { |
| 306 | + this.indexIterator = iterator; |
| 307 | + this.matchingDocs = matchingDocs; |
| 308 | + this.inner = fvv; |
| 309 | + this.bulkSize = bulkSize; |
| 310 | + this.inputSlice = inputSlice; |
| 311 | + this.docBuffer = new int[bulkSize]; |
| 312 | + this.scoreBuffer = new float[bulkSize]; |
| 313 | + this.byteSize = byteSize; |
| 314 | + } |
| 315 | + |
| 316 | + @Override |
| 317 | + public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException { |
| 318 | + buffer.growNoCopy(nextCount); |
| 319 | + int size = 0; |
| 320 | + for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs |
| 321 | + .nextDoc()) { |
| 322 | + if (liveDocs == null || liveDocs.get(doc)) { |
| 323 | + buffer.docs[size++] = indexIterator.index(); |
| 324 | + } |
| 325 | + } |
| 326 | + int loopBound = size - (size % bulkSize); |
| 327 | + int i = 0; |
| 328 | + for (; i < loopBound; i += bulkSize) { |
| 329 | + for (int j = 0; j < bulkSize; j++) { |
| 330 | + long ord = buffer.docs[i + j]; |
| 331 | + inputSlice.prefetch(ord * byteSize, byteSize); |
| 332 | + } |
| 333 | + System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize); |
| 334 | + inner.bulkScore(docBuffer, scoreBuffer, bulkSize); |
| 335 | + System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize); |
| 336 | + } |
| 337 | + int countLeft = size - i; |
| 338 | + for (int j = i; j < size; j++) { |
| 339 | + long ord = buffer.docs[j]; |
| 340 | + inputSlice.prefetch(ord * byteSize, byteSize); |
| 341 | + } |
| 342 | + System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft); |
| 343 | + inner.bulkScore(docBuffer, scoreBuffer, countLeft); |
| 344 | + System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft); |
| 345 | + buffer.size = size; |
| 346 | + // fix the docIds in buffer |
| 347 | + for (int j = 0; j < size; j++) { |
| 348 | + buffer.docs[j] = inner.ordToDoc(buffer.docs[j]); |
| 349 | + } |
| 350 | + } |
| 351 | + } |
| 352 | + } |
132 | 353 | } |
0 commit comments