|
11 | 11 | import org.apache.lucene.codecs.hnsw.FlatVectorsReader; |
12 | 12 | import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; |
13 | 13 | import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; |
| 14 | +import org.apache.lucene.codecs.lucene95.HasIndexSlice; |
14 | 15 | import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader; |
15 | 16 | import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter; |
| 17 | +import org.apache.lucene.index.ByteVectorValues; |
| 18 | +import org.apache.lucene.index.FieldInfo; |
| 19 | +import org.apache.lucene.index.FloatVectorValues; |
| 20 | +import org.apache.lucene.index.KnnVectorValues; |
16 | 21 | import org.apache.lucene.index.SegmentReadState; |
17 | 22 | import org.apache.lucene.index.SegmentWriteState; |
| 23 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 24 | +import org.apache.lucene.search.ConjunctionUtils; |
| 25 | +import org.apache.lucene.search.DocAndFloatFeatureBuffer; |
| 26 | +import org.apache.lucene.search.DocIdSetIterator; |
| 27 | +import org.apache.lucene.search.VectorScorer; |
18 | 28 | import org.apache.lucene.store.FlushInfo; |
19 | 29 | import org.apache.lucene.store.IOContext; |
| 30 | +import org.apache.lucene.store.IndexInput; |
20 | 31 | import org.apache.lucene.store.MergeInfo; |
| 32 | +import org.apache.lucene.util.Bits; |
| 33 | +import org.apache.lucene.util.hnsw.RandomVectorScorer; |
21 | 34 | import org.elasticsearch.common.util.set.Sets; |
| 35 | +import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues; |
| 36 | +import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues; |
22 | 37 | import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; |
23 | 38 | import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; |
24 | 39 | import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; |
25 | 40 | import org.elasticsearch.index.store.FsDirectoryFactory; |
26 | 41 |
|
27 | 42 | import java.io.IOException; |
| 43 | +import java.util.List; |
28 | 44 | import java.util.Set; |
29 | 45 |
|
30 | 46 | public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat { |
@@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI |
71 | 87 | ); |
72 | 88 | // Use mmap for merges and direct I/O for searches. |
73 | 89 | return new MergeReaderWrapper( |
74 | | - new Lucene99FlatVectorsReader(directIOState, vectorsScorer), |
| 90 | + new Lucene99FlatBulkScoringVectorsReader( |
| 91 | + directIOState, |
| 92 | + new Lucene99FlatVectorsReader(directIOState, vectorsScorer), |
| 93 | + vectorsScorer |
| 94 | + ), |
75 | 95 | new Lucene99FlatVectorsReader(state, vectorsScorer) |
76 | 96 | ); |
77 | 97 | } else { |
@@ -113,4 +133,203 @@ public IOContext withHints(FileOpenHint... hints) { |
113 | 133 | return new DirectIOContext(Set.of(hints)); |
114 | 134 | } |
115 | 135 | } |
| 136 | + |
| 137 | + static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader { |
| 138 | + private final Lucene99FlatVectorsReader inner; |
| 139 | + private final SegmentReadState state; |
| 140 | + |
| 141 | + Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) { |
| 142 | + super(scorer); |
| 143 | + this.inner = inner; |
| 144 | + this.state = state; |
| 145 | + } |
| 146 | + |
| 147 | + @Override |
| 148 | + public void close() throws IOException { |
| 149 | + inner.close(); |
| 150 | + } |
| 151 | + |
| 152 | + @Override |
| 153 | + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { |
| 154 | + return inner.getRandomVectorScorer(field, target); |
| 155 | + } |
| 156 | + |
| 157 | + @Override |
| 158 | + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { |
| 159 | + return inner.getRandomVectorScorer(field, target); |
| 160 | + } |
| 161 | + |
| 162 | + @Override |
| 163 | + public void checkIntegrity() throws IOException { |
| 164 | + inner.checkIntegrity(); |
| 165 | + } |
| 166 | + |
| 167 | + @Override |
| 168 | + public FloatVectorValues getFloatVectorValues(String field) throws IOException { |
| 169 | + FloatVectorValues vectorValues = inner.getFloatVectorValues(field); |
| 170 | + if (vectorValues == null || vectorValues.size() == 0) { |
| 171 | + return null; |
| 172 | + } |
| 173 | + FieldInfo info = state.fieldInfos.fieldInfo(field); |
| 174 | + return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer); |
| 175 | + } |
| 176 | + |
| 177 | + @Override |
| 178 | + public ByteVectorValues getByteVectorValues(String field) throws IOException { |
| 179 | + return inner.getByteVectorValues(field); |
| 180 | + } |
| 181 | + |
| 182 | + @Override |
| 183 | + public long ramBytesUsed() { |
| 184 | + return inner.ramBytesUsed(); |
| 185 | + } |
| 186 | + } |
| 187 | + |
| 188 | + static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues { |
| 189 | + VectorSimilarityFunction similarityFunction; |
| 190 | + FloatVectorValues inner; |
| 191 | + IndexInput inputSlice; |
| 192 | + FlatVectorsScorer scorer; |
| 193 | + |
| 194 | + RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) { |
| 195 | + this.inner = inner; |
| 196 | + if (inner instanceof HasIndexSlice slice) { |
| 197 | + this.inputSlice = slice.getSlice(); |
| 198 | + } else { |
| 199 | + this.inputSlice = null; |
| 200 | + } |
| 201 | + this.similarityFunction = similarityFunction; |
| 202 | + this.scorer = scorer; |
| 203 | + } |
| 204 | + |
| 205 | + @Override |
| 206 | + public float[] vectorValue(int ord) throws IOException { |
| 207 | + return inner.vectorValue(ord); |
| 208 | + } |
| 209 | + |
| 210 | + @Override |
| 211 | + public int dimension() { |
| 212 | + return inner.dimension(); |
| 213 | + } |
| 214 | + |
| 215 | + @Override |
| 216 | + public int size() { |
| 217 | + return inner.size(); |
| 218 | + } |
| 219 | + |
| 220 | + @Override |
| 221 | + public RescorerOffHeapVectorValues copy() throws IOException { |
| 222 | + return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer); |
| 223 | + } |
| 224 | + |
| 225 | + @Override |
| 226 | + public BulkVectorScorer bulkRescorer(float[] target) throws IOException { |
| 227 | + return bulkScorer(target); |
| 228 | + } |
| 229 | + |
| 230 | + @Override |
| 231 | + public BulkVectorScorer bulkScorer(float[] target) throws IOException { |
| 232 | + DocIndexIterator indexIterator = inner.iterator(); |
| 233 | + RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target); |
| 234 | + return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES); |
| 235 | + } |
| 236 | + |
| 237 | + @Override |
| 238 | + public VectorScorer scorer(float[] target) throws IOException { |
| 239 | + return inner.scorer(target); |
| 240 | + } |
| 241 | + } |
| 242 | + |
| 243 | + private record PreFetchingFloatBulkScorer( |
| 244 | + RandomVectorScorer inner, |
| 245 | + KnnVectorValues.DocIndexIterator indexIterator, |
| 246 | + IndexInput inputSlice, |
| 247 | + int byteSize |
| 248 | + ) implements BulkScorableVectorValues.BulkVectorScorer { |
| 249 | + |
| 250 | + @Override |
| 251 | + public float score() throws IOException { |
| 252 | + return inner.score(indexIterator.index()); |
| 253 | + } |
| 254 | + |
| 255 | + @Override |
| 256 | + public DocIdSetIterator iterator() { |
| 257 | + return indexIterator; |
| 258 | + } |
| 259 | + |
| 260 | + @Override |
| 261 | + public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException { |
| 262 | + DocIdSetIterator conjunctionScorer = matchingDocs == null |
| 263 | + ? indexIterator |
| 264 | + : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator)); |
| 265 | + if (conjunctionScorer.docID() == -1) { |
| 266 | + conjunctionScorer.nextDoc(); |
| 267 | + } |
| 268 | + return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer); |
| 269 | + } |
| 270 | + } |
| 271 | + |
| 272 | + private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer { |
| 273 | + private final KnnVectorValues.DocIndexIterator indexIterator; |
| 274 | + private final DocIdSetIterator matchingDocs; |
| 275 | + private final RandomVectorScorer inner; |
| 276 | + private final int bulkSize; |
| 277 | + private final IndexInput inputSlice; |
| 278 | + private final int byteSize; |
| 279 | + private final int[] docBuffer; |
| 280 | + private final float[] scoreBuffer; |
| 281 | + |
| 282 | + FloatBulkScorer( |
| 283 | + RandomVectorScorer fvv, |
| 284 | + IndexInput inputSlice, |
| 285 | + int byteSize, |
| 286 | + int bulkSize, |
| 287 | + KnnVectorValues.DocIndexIterator iterator, |
| 288 | + DocIdSetIterator matchingDocs |
| 289 | + ) { |
| 290 | + this.indexIterator = iterator; |
| 291 | + this.matchingDocs = matchingDocs; |
| 292 | + this.inner = fvv; |
| 293 | + this.bulkSize = bulkSize; |
| 294 | + this.inputSlice = inputSlice; |
| 295 | + this.docBuffer = new int[bulkSize]; |
| 296 | + this.scoreBuffer = new float[bulkSize]; |
| 297 | + this.byteSize = byteSize; |
| 298 | + } |
| 299 | + |
| 300 | + @Override |
| 301 | + public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException { |
| 302 | + buffer.growNoCopy(nextCount); |
| 303 | + int size = 0; |
| 304 | + for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) { |
| 305 | + if (liveDocs == null || liveDocs.get(doc)) { |
| 306 | + buffer.docs[size++] = indexIterator.index(); |
| 307 | + } |
| 308 | + } |
| 309 | + int loopBound = size - (size % bulkSize); |
| 310 | + int i = 0; |
| 311 | + for (; i < loopBound; i += bulkSize) { |
| 312 | + for (int j = 0; j < bulkSize; j++) { |
| 313 | + long ord = buffer.docs[i + j]; |
| 314 | + inputSlice.prefetch(ord * byteSize, byteSize); |
| 315 | + } |
| 316 | + System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize); |
| 317 | + inner.bulkScore(docBuffer, scoreBuffer, bulkSize); |
| 318 | + System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize); |
| 319 | + } |
| 320 | + int countLeft = size - i; |
| 321 | + for (int j = i; j < size; j++) { |
| 322 | + long ord = buffer.docs[j]; |
| 323 | + inputSlice.prefetch(ord * byteSize, byteSize); |
| 324 | + } |
| 325 | + System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft); |
| 326 | + inner.bulkScore(docBuffer, scoreBuffer, countLeft); |
| 327 | + System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft); |
| 328 | + buffer.size = size; |
| 329 | + // fix the docIds in buffer |
| 330 | + for (int j = 0; j < size; j++) { |
| 331 | + buffer.docs[j] = inner.ordToDoc(buffer.docs[j]); |
| 332 | + } |
| 333 | + } |
| 334 | + } |
116 | 335 | } |
0 commit comments