Skip to content

Commit 7d9f278

Browse files
committed
fixing after merge
1 parent 53f2f01 commit 7d9f278

File tree

1 file changed

+220
-1
lines changed

1 file changed

+220
-1
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,36 @@
1111
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1212
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
1313
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
14+
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
1415
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
1516
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
17+
import org.apache.lucene.index.ByteVectorValues;
18+
import org.apache.lucene.index.FieldInfo;
19+
import org.apache.lucene.index.FloatVectorValues;
20+
import org.apache.lucene.index.KnnVectorValues;
1621
import org.apache.lucene.index.SegmentReadState;
1722
import org.apache.lucene.index.SegmentWriteState;
23+
import org.apache.lucene.index.VectorSimilarityFunction;
24+
import org.apache.lucene.search.ConjunctionUtils;
25+
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
26+
import org.apache.lucene.search.DocIdSetIterator;
27+
import org.apache.lucene.search.VectorScorer;
1828
import org.apache.lucene.store.FlushInfo;
1929
import org.apache.lucene.store.IOContext;
30+
import org.apache.lucene.store.IndexInput;
2031
import org.apache.lucene.store.MergeInfo;
32+
import org.apache.lucene.util.Bits;
33+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2134
import org.elasticsearch.common.util.set.Sets;
35+
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
36+
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
2237
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
2338
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
2439
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
2540
import org.elasticsearch.index.store.FsDirectoryFactory;
2641

2742
import java.io.IOException;
43+
import java.util.List;
2844
import java.util.Set;
2945

3046
public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
@@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
7187
);
7288
// Use mmap for merges and direct I/O for searches.
7389
return new MergeReaderWrapper(
74-
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
90+
new Lucene99FlatBulkScoringVectorsReader(
91+
directIOState,
92+
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
93+
vectorsScorer
94+
),
7595
new Lucene99FlatVectorsReader(state, vectorsScorer)
7696
);
7797
} else {
@@ -113,4 +133,203 @@ public IOContext withHints(FileOpenHint... hints) {
113133
return new DirectIOContext(Set.of(hints));
114134
}
115135
}
136+
137+
static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
138+
private final Lucene99FlatVectorsReader inner;
139+
private final SegmentReadState state;
140+
141+
Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
142+
super(scorer);
143+
this.inner = inner;
144+
this.state = state;
145+
}
146+
147+
@Override
148+
public void close() throws IOException {
149+
inner.close();
150+
}
151+
152+
@Override
153+
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
154+
return inner.getRandomVectorScorer(field, target);
155+
}
156+
157+
@Override
158+
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
159+
return inner.getRandomVectorScorer(field, target);
160+
}
161+
162+
@Override
163+
public void checkIntegrity() throws IOException {
164+
inner.checkIntegrity();
165+
}
166+
167+
@Override
168+
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
169+
FloatVectorValues vectorValues = inner.getFloatVectorValues(field);
170+
if (vectorValues == null || vectorValues.size() == 0) {
171+
return null;
172+
}
173+
FieldInfo info = state.fieldInfos.fieldInfo(field);
174+
return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer);
175+
}
176+
177+
@Override
178+
public ByteVectorValues getByteVectorValues(String field) throws IOException {
179+
return inner.getByteVectorValues(field);
180+
}
181+
182+
@Override
183+
public long ramBytesUsed() {
184+
return inner.ramBytesUsed();
185+
}
186+
}
187+
188+
static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
189+
VectorSimilarityFunction similarityFunction;
190+
FloatVectorValues inner;
191+
IndexInput inputSlice;
192+
FlatVectorsScorer scorer;
193+
194+
RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
195+
this.inner = inner;
196+
if (inner instanceof HasIndexSlice slice) {
197+
this.inputSlice = slice.getSlice();
198+
} else {
199+
this.inputSlice = null;
200+
}
201+
this.similarityFunction = similarityFunction;
202+
this.scorer = scorer;
203+
}
204+
205+
@Override
206+
public float[] vectorValue(int ord) throws IOException {
207+
return inner.vectorValue(ord);
208+
}
209+
210+
@Override
211+
public int dimension() {
212+
return inner.dimension();
213+
}
214+
215+
@Override
216+
public int size() {
217+
return inner.size();
218+
}
219+
220+
@Override
221+
public RescorerOffHeapVectorValues copy() throws IOException {
222+
return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
223+
}
224+
225+
@Override
226+
public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
227+
return bulkScorer(target);
228+
}
229+
230+
@Override
231+
public BulkVectorScorer bulkScorer(float[] target) throws IOException {
232+
DocIndexIterator indexIterator = inner.iterator();
233+
RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
234+
return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
235+
}
236+
237+
@Override
238+
public VectorScorer scorer(float[] target) throws IOException {
239+
return inner.scorer(target);
240+
}
241+
}
242+
243+
private record PreFetchingFloatBulkScorer(
244+
RandomVectorScorer inner,
245+
KnnVectorValues.DocIndexIterator indexIterator,
246+
IndexInput inputSlice,
247+
int byteSize
248+
) implements BulkScorableVectorValues.BulkVectorScorer {
249+
250+
@Override
251+
public float score() throws IOException {
252+
return inner.score(indexIterator.index());
253+
}
254+
255+
@Override
256+
public DocIdSetIterator iterator() {
257+
return indexIterator;
258+
}
259+
260+
@Override
261+
public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
262+
DocIdSetIterator conjunctionScorer = matchingDocs == null
263+
? indexIterator
264+
: ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
265+
if (conjunctionScorer.docID() == -1) {
266+
conjunctionScorer.nextDoc();
267+
}
268+
return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer);
269+
}
270+
}
271+
272+
private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
273+
private final KnnVectorValues.DocIndexIterator indexIterator;
274+
private final DocIdSetIterator matchingDocs;
275+
private final RandomVectorScorer inner;
276+
private final int bulkSize;
277+
private final IndexInput inputSlice;
278+
private final int byteSize;
279+
private final int[] docBuffer;
280+
private final float[] scoreBuffer;
281+
282+
FloatBulkScorer(
283+
RandomVectorScorer fvv,
284+
IndexInput inputSlice,
285+
int byteSize,
286+
int bulkSize,
287+
KnnVectorValues.DocIndexIterator iterator,
288+
DocIdSetIterator matchingDocs
289+
) {
290+
this.indexIterator = iterator;
291+
this.matchingDocs = matchingDocs;
292+
this.inner = fvv;
293+
this.bulkSize = bulkSize;
294+
this.inputSlice = inputSlice;
295+
this.docBuffer = new int[bulkSize];
296+
this.scoreBuffer = new float[bulkSize];
297+
this.byteSize = byteSize;
298+
}
299+
300+
@Override
301+
public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
302+
buffer.growNoCopy(nextCount);
303+
int size = 0;
304+
for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) {
305+
if (liveDocs == null || liveDocs.get(doc)) {
306+
buffer.docs[size++] = indexIterator.index();
307+
}
308+
}
309+
int loopBound = size - (size % bulkSize);
310+
int i = 0;
311+
for (; i < loopBound; i += bulkSize) {
312+
for (int j = 0; j < bulkSize; j++) {
313+
long ord = buffer.docs[i + j];
314+
inputSlice.prefetch(ord * byteSize, byteSize);
315+
}
316+
System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize);
317+
inner.bulkScore(docBuffer, scoreBuffer, bulkSize);
318+
System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize);
319+
}
320+
int countLeft = size - i;
321+
for (int j = i; j < size; j++) {
322+
long ord = buffer.docs[j];
323+
inputSlice.prefetch(ord * byteSize, byteSize);
324+
}
325+
System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft);
326+
inner.bulkScore(docBuffer, scoreBuffer, countLeft);
327+
System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft);
328+
buffer.size = size;
329+
// fix the docIds in buffer
330+
for (int j = 0; j < size; j++) {
331+
buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
332+
}
333+
}
334+
}
116335
}

0 commit comments

Comments
 (0)