Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2008,6 +2008,7 @@ public Query createKnnQuery(
VectorData queryVector,
Integer k,
int numCands,
int requestSize,
Float numCandsFactor,
Query filter,
Float similarityThreshold,
Expand All @@ -2024,6 +2025,7 @@ public Query createKnnQuery(
queryVector.asFloatVector(),
k,
numCands,
requestSize,
numCandsFactor,
filter,
similarityThreshold,
Expand Down Expand Up @@ -2090,6 +2092,7 @@ private Query createKnnFloatQuery(
float[] queryVector,
Integer k,
int numCands,
int requestSize,
Float numCandsFactor,
Query filter,
Float similarityThreshold,
Expand Down Expand Up @@ -2127,7 +2130,7 @@ && isNotUnitVector(squaredMagnitude)) {
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
k,
k == null ? requestSize : k,
knnQuery
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -29,12 +27,11 @@
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
* original vector values stored in the index
*/
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Profiling can be simplified, as we always know the number of results to return


private final String field;
private final float[] target;
private final VectorSimilarityFunction vectorSimilarityFunction;
private long vectorOpsCount;

public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
this.field = field;
Expand All @@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
vectorOpsCount++;
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
}

Expand All @@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
String parentPath = context.nestedLookup().getNestedParent(fieldName);
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();

BitSetProducer parentBitSet = null;
if (parentPath != null) {
final BitSetProducer parentBitSet;
final Query parentFilter;
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
if (originalObjectMapper != null) {
Expand Down Expand Up @@ -558,17 +558,18 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
// Now join the filterQuery & parentFilter to provide the matching blocks of children
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);

return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
requestSize,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
private final String fieldName;
private final float[] floatTarget;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final Integer k;
private final int k;
private final Query innerQuery;

private QueryProfilerProvider vectorProfiling;
private long vectorOperations = 0;

public RescoreKnnVectorQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
Integer k,
int k,
Query innerQuery
) {
this.fieldName = fieldName;
Expand All @@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
@Override
public Query rewrite(IndexSearcher searcher) throws IOException {
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
// to calculate top k and return directly the query to understand how many comparisons were done
vectorProfiling = (QueryProfilerProvider) valueSource;
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
Query query = searcher.rewrite(functionScoreQuery);

if (k == null) {
// No need to calculate top k - let the request size limit the results.
return query;
}

// Retrieve top k documents from the rescored query
TopDocs topDocs = searcher.search(query, k);
vectorOperations = topDocs.totalHits.value();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We know in advance the number of comparisons done

ScoreDoc[] scoreDocs = topDocs.scoreDocs;
int[] docIds = new int[scoreDocs.length];
float[] scores = new float[scoreDocs.length];
Expand All @@ -82,7 +74,7 @@ public Query innerQuery() {
return innerQuery;
}

public Integer k() {
public int k() {
return k;
}

Expand All @@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
queryProfilerProvider.profile(queryProfiler);
}

if (vectorProfiling == null) {
throw new IllegalStateException("Query should have been rewritten");
}
vectorProfiling.profile(queryProfiler);
queryProfiler.addVectorOpsCount(vectorOperations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException {

Exception e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, 3, null, null, null, null)
);
assertThat(
e.getMessage(),
Expand All @@ -1687,6 +1687,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1700,7 +1701,16 @@ public void testByteVectorQueryBoundaries() throws IOException {

e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }),
3,
3,
3,
null,
null,
null,
null
)
);
assertThat(
e.getMessage(),
Expand All @@ -1709,7 +1719,16 @@ public void testByteVectorQueryBoundaries() throws IOException {

e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }),
3,
3,
3,
null,
null,
null,
null
)
);
assertThat(
e.getMessage(),
Expand All @@ -1722,6 +1741,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1736,6 +1756,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1753,6 +1774,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand Down Expand Up @@ -1787,6 +1809,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1801,6 +1824,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1818,6 +1842,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand Down
Loading