Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
Expand Down Expand Up @@ -128,13 +129,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
}

private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs, RandomVectorScorer scorer) throws IOException {
private void collectAllMatchingDocs(KnnCollector knnCollector, AcceptDocs acceptDocs, RandomVectorScorer scorer)
throws IOException {
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
collector.collect(i, scorer.score(i));
Expand All @@ -145,7 +147,7 @@ private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs,
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
Expand Down Expand Up @@ -136,13 +137,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
}

private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs, RandomVectorScorer scorer) throws IOException {
private void collectAllMatchingDocs(KnnCollector knnCollector, AcceptDocs acceptDocs, RandomVectorScorer scorer)
throws IOException {
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
collector.collect(i, scorer.score(i));
Expand All @@ -153,7 +155,7 @@ private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs,
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
Expand Down Expand Up @@ -212,7 +212,7 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti
}

@Override
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public final void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
Expand All @@ -223,11 +223,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
);
}
float percentFiltered = 1f;
if (acceptDocs instanceof BitSet bitSet) {
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
}
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
float percentFiltered = Math.max(0f, Math.min(1f, (float) acceptDocs.cost() / numVectors));
float visitRatio = DYNAMIC_VISIT_RATIO;
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
Expand Down Expand Up @@ -255,7 +252,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
target,
postListSlice
);
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
Bits acceptDocsBits = acceptDocs.bits();
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocsBits);
long expectedDocs = 0;
long actualDocs = 0;
// initially we visit only the "centroids to search"
Expand All @@ -271,7 +269,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
actualDocs += scorer.visit(knnCollector);
}
if (acceptDocs != null) {
if (acceptDocsBits != null) {
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
Expand All @@ -284,7 +282,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
}

@Override
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public final void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
for (int i = 0; i < values.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.DataAccessHint;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
Expand Down Expand Up @@ -302,11 +302,11 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
try (
IndexInput vectors = mergeState.segmentInfo.dir.openInput(
tempRawVectorsFileName,
IOContext.DEFAULT.withReadAdvice(ReadAdvice.SEQUENTIAL)
IOContext.DEFAULT.withHints(DataAccessHint.SEQUENTIAL)
);
IndexInput docs = docsFileName == null
? null
: mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT.withReadAdvice(ReadAdvice.SEQUENTIAL))
: mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT.withHints(DataAccessHint.SEQUENTIAL))
) {
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand Down Expand Up @@ -226,17 +227,17 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
if (knnCollector.k() == 0) return;
final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
if (scorer == null) return;
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
collector.collect(i, scorer.score(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand Down Expand Up @@ -240,17 +241,17 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
if (knnCollector.k() == 0) return;
final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
if (scorer == null) return;
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
collector.collect(i, scorer.score(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.core.IOUtils;

Expand Down Expand Up @@ -60,12 +60,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
mainReader.search(field, target, knnCollector, acceptDocs);
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
mainReader.search(field, target, knnCollector, acceptDocs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ByteBuffersDirectory;
Expand Down Expand Up @@ -447,12 +448,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void searchNearestVectors(String field, float[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, float[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
getDelegate().searchNearestVectors(field, target, collector, acceptDocs);
}

@Override
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
getDelegate().searchNearestVectors(field, target, collector, acceptDocs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -210,7 +211,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
}

@Override
public void searchNearestVectors(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {
public void searchNearestVectors(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
throw new UnsupportedOperationException();
}

Expand Down Expand Up @@ -255,7 +256,7 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public void searchNearestVectors(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {
public void searchNearestVectors(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.suggest.document.CompletionTerms;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
Expand Down Expand Up @@ -141,7 +141,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
if (queryCancellation.isEnabled() == false) {
in.searchNearestVectors(field, target, collector, acceptDocs);
return;
Expand All @@ -159,7 +159,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
}

@Override
public void searchNearestVectors(String field, float[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, float[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
if (queryCancellation.isEnabled() == false) {
in.searchNearestVectors(field, target, collector, acceptDocs);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.apache.lucene.index.TermVectors;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.suggest.document.CompletionTerms;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
Expand Down Expand Up @@ -221,15 +221,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
super.searchNearestVectors(field, target, collector, acceptDocs);
if (collector.visitedCount() > 0) {
notifier.onKnnVectorsUsed(field);
}
}

@Override
public void searchNearestVectors(String field, float[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
public void searchNearestVectors(String field, float[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
super.searchNearestVectors(field, target, collector, acceptDocs);
if (collector.visitedCount() > 0) {
notifier.onKnnVectorsUsed(field);
Expand Down
Loading