Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.query;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;

public class RescoreKnnVectorQueryIT extends ESIntegTestCase {

public static final String INDEX_NAME = "test";
public static final String VECTOR_FIELD = "vector";
public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
public static final String QUERY_VECTOR_PARAM = "query_vector";

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(CustomScriptPlugin.class);
}

public static class CustomScriptPlugin extends MockScriptPlugin {
private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
.vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);

@Override
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
return Map.of(VECTOR_SCORE_SCRIPT, vars -> {
Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
return SIMILARITY_FUNCTION.compare(
((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(),
(float[]) vars.get(QUERY_VECTOR_PARAM)
);
});
}
}

@Before
public void setup() throws IOException {
String type = randomFrom(
Arrays.stream(VectorIndexType.values())
.filter(VectorIndexType::isQuantized)
.map(t -> t.name().toLowerCase(Locale.ROOT))
.collect(Collectors.toCollection(ArrayList::new))
);
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(VECTOR_FIELD)
.field("type", "dense_vector")
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", type)
.endObject()
.endObject()
.endObject()
.endObject();

Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
.build();
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
ensureGreen(INDEX_NAME);
}

private record TestParams(
int numDocs,
int numDims,
float[] queryVector,
int k,
int numCands,
RescoreVectorBuilder rescoreVectorBuilder
) {
public static TestParams generate() {
int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions
int numDocs = randomIntBetween(10, 100);
int k = randomIntBetween(1, numDocs - 5);
return new TestParams(
numDocs,
numDims,
randomVector(numDims),
k,
(int) (k * randomFloatBetween(1.0f, 10.0f, true)),
new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
);
}
}

public void testKnnSearchRescore() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnSearchGenerator = (testParams, requestBuilder) -> {
KnnSearchBuilder knnSearch = new KnnSearchBuilder(
VECTOR_FIELD,
testParams.queryVector,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setKnnSearch(List.of(knnSearch));
};
testKnnRescore(knnSearchGenerator);
}

public void testKnnQueryRescore() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
VECTOR_FIELD,
testParams.queryVector,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setQuery(knnQuery);
};
testKnnRescore(knnQueryGenerator);
}

public void testKnnRetriever() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
VECTOR_FIELD,
testParams.queryVector,
null,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever));
};
testKnnRescore(knnQueryGenerator);
}

private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
TestParams testParams = TestParams.generate();

int numDocs = testParams.numDocs;
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];

for (int i = 0; i < numDocs; i++) {
docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
}
indexRandom(true, docs);

float[] queryVector = testParams.queryVector;
float oversample = randomFloatBetween(1.0f, 100f, true);
RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample);

SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
testParams,
prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
);

assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); });
}

private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
// Do an exact query and compare
Script script = new Script(
ScriptType.INLINE,
CustomScriptPlugin.NAME,
VECTOR_SCORE_SCRIPT,
Map.of(QUERY_VECTOR_PARAM, queryVector)
);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script);
assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> {
assertHitCount(exactResponse, docCount);

int i = 0;
SearchHit[] exactHits = exactResponse.getHits().getHits();
for (SearchHit knnHit : knnResponse.getHits().getHits()) {
while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) {
i++;
}
if (i >= exactHits.length) {
fail("Knn doc not found in exact search");
}
assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore()));
}
});
}

private static float[] randomVector(int numDimensions) {
float[] vector = new float[numDimensions];
for (int j = 0; j < numDimensions; j++) {
vector[j] = randomFloatBetween(0, 1, true);
}
return vector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ public final int hashCode() {
}
}

private enum VectorIndexType {
public enum VectorIndexType {
HNSW("hnsw", false) {
@Override
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.elasticsearch.core.Assertions;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -55,6 +56,13 @@ public class KnnScoreDocQuery extends Query {
* @param reader IndexReader
*/
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
if (Assertions.ENABLED) {
assert docs.length == scores.length;
for (int i = 1; i < docs.length; i++) {
assert docs[i - 1] < docs[i] : "doc ids are not in order: " + Arrays.toString(docs);
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it sounds like this is redundant if we have appropriate test coverage? I was also wondering if it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here. I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it sounds like this is redundant if we have appropriate test coverage?

My thinking was to provide a way to understand a test failure in an easier way in case someone provided a non-sorted array, instead of going through all the investigations that you had to do 😓

I'm happy with removing the assertion in case you think it's unnecessary, but I think it helps to understand what the preconditions for this constructor are.

I was also wondering if it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here. I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

I think that's a good idea. I will give it a try.

I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

It already diverges a bit in terms of making it easier to create - as long as it's on the constructor stuff I think we should be good for doing the change.

I'll give it a go and come back for feedback.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a test for RescoreKnnVectorQuery that indexes a bunch of random vectors, searches with a random vector and asserts the rewrite is a KnnScoreDocQuery with the appropriately ordered values.

It seems we are almost there in RescoreKnnVectorQueryTests, but maybe add some assertions there. Maybe via package private methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here

@javanna I gave it a try in a073f43 - I like it more, it simplifies how clients create this query plus we enforce the invariant in the constructor itself 💯

We should have a test for RescoreKnnVectorQuery that indexes a bunch of random vectors, searches with a random vector and asserts the rewrite is a KnnScoreDocQuery with the appropriately ordered values.

@benwtrent I think the change in a073f43 makes it unnecessary. We're already checking via random insertions in the test. Do you think we need to add something else to make sure this doesn't bite us again?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doing the sort in the ctor is fine and as long as we have tests that fill fail if somebody removes that sort, I am happy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RescoreKnnVectorQueryIT add those tests. I checked by removing the sort that Luca added back in #122653 that this was caught by the newly added tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it. It also allows to share some code between the two consumers. Perhaps make it clear in the javadocs that this is no longer a straight copy of its lucene sibling. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps make it clear in the javadocs that this is no longer a straight copy of its lucene sibling.

👍 I've clarified that in ee464fe


this.docs = docs;
this.scores = scores;
this.segmentStarts = findSegmentStarts(reader, docs);
Expand Down
Loading