Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
*/
public class IVFVectorsFormat extends KnnVectorsFormat {

static final FeatureFlag IVF_FORMAT_FEATURE_FLAG = new FeatureFlag("ivf_format");
public static final FeatureFlag IVF_FORMAT_FEATURE_FLAG = new FeatureFlag("ivf_format");
public static final String IVF_VECTOR_COMPONENT = "IVF";
public static final String NAME = "IVFVectorsFormat";
// centroid ordinals -> centroid values, offsets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

import java.io.IOException;
import java.util.function.IntPredicate;
Expand Down Expand Up @@ -243,8 +244,11 @@ public final void search(String field, float[] target, KnnCollector knnCollector
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
return;
}
// TODO add new ivf search strategy
int nProbe = 10;
if (fieldInfo.getVectorDimension() != target.length) {
throw new IllegalArgumentException(
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
);
}
float percentFiltered = 1f;
if (acceptDocs instanceof BitSet bitSet) {
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
Expand All @@ -257,6 +261,13 @@ public final void search(String field, float[] target, KnnCollector knnCollector
}
return visitedDocs.getAndSet(docId) == false;
};
final int nProbe;
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
nProbe = ivfSearchStrategy.getNProbe();
} else {
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
nProbe = 10;
}

FieldEntry entry = fields.get(fieldInfo.number);
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.elasticsearch.search.profile.query.QueryProfiler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;

abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {

static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

protected final String field;
protected final int nProbe;
protected final int k;
protected final Query filter;
protected final KnnSearchStrategy searchStrategy;
protected int vectorOpsCount;

protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
this.field = field;
this.nProbe = nProbe;
this.k = k;
this.filter = filter;
this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
}

@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
return k == that.k
&& Objects.equals(field, that.field)
&& Objects.equals(filter, that.filter)
&& Objects.equals(nProbe, that.nProbe);
}

@Override
public int hashCode() {
return Objects.hash(field, k, filter, nProbe);
}

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
vectorOpsCount = 0;
IndexReader reader = indexSearcher.getIndexReader();

final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
if (rewritten.getClass() == MatchNoDocsQuery.class) {
return rewritten;
}
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
vectorOpsCount = (int) topK.totalHits.value();
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return new KnnScoreDocQuery(topK.scoreDocs, reader);
}

private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
return results;
}

TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();

if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
}

Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
final int cost = acceptDocs.cardinality();
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
}

abstract TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager
) throws IOException;

protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new IVFCollectorManager(k, nProbe);
}

@Override
public final void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return bitSetIterator.getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}

static class IVFCollectorManager implements KnnCollectorManager {
private final int k;
private final int nprobe;

IVFCollectorManager(int k, int nprobe) {
this.k = k;
this.nprobe = nprobe;
}

@Override
public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;

import java.io.IOException;
import java.util.Arrays;

/** A {@link IVFKnnFloatVectorQuery} that uses the IVF search strategy. */
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {

private final float[] query;

/**
* Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
* @param field the field to search
* @param query the query vector
* @param k the number of nearest neighbors to return
* @param filter the filter to apply to the results
* @param nProbe the number of probes to use for the IVF search strategy
*/
public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
super(field, nProbe, k, filter);
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
if (nProbe < 1) {
throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
}
this.query = query;
}

@Override
public String toString(String field) {
StringBuilder buffer = new StringBuilder();
buffer.append(getClass().getSimpleName())
.append(":")
.append(this.field)
.append("[")
.append(query[0])
.append(",...]")
.append("[")
.append(k)
.append("]");
if (this.filter != null) {
buffer.append("[").append(this.filter).append("]");
}
return buffer.toString();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (super.equals(o) == false) return false;
IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
return Arrays.equals(query, that.query);
}

@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(query);
return result;
}

@Override
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager
) throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
LeafReader reader = context.reader();
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(reader, field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
return NO_RESULTS;
}
reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;

import org.apache.lucene.search.knn.KnnSearchStrategy;

import java.util.Objects;

public class IVFKnnSearchStrategy extends KnnSearchStrategy {
private final int nProbe;

IVFKnnSearchStrategy(int nProbe) {
this.nProbe = nProbe;
}

public int getNProbe() {
return nProbe;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o;
return nProbe == that.nProbe;
}

@Override
public int hashCode() {
return Objects.hashCode(nProbe);
}

@Override
public void nextVectorsBlock() {
// do nothing
}
}
Loading