Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Collections;
import java.util.Map;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.index.query.QueryBuilders.combinedFieldsQuery;
import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery;
Expand Down Expand Up @@ -1359,7 +1360,15 @@ public void testKnnQueryNotSupportedInPercolator() throws IOException {
""");
indicesAdmin().prepareCreate("index1").setMapping(mappings).get();
ensureGreen();
QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, 10f, null, null);
QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder(
"my_vector",
new float[] { 1, 1, 1, 1, 1 },
10,
10,
IVF_FORMAT.isEnabled() ? 10f : null,
null,
null
);

IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1")
.setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.OptionalLong;
import java.util.stream.IntStream;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -124,7 +125,17 @@ public void testDirectIOUsed() {
indexVectors();

// do a search
var knn = List.of(new KnnSearchBuilder("fooVector", new VectorData(null, new byte[64]), 10, 20, 10f, null, null));
var knn = List.of(
new KnnSearchBuilder(
"fooVector",
new VectorData(null, new byte[64]),
10,
20,
IVF_FORMAT.isEnabled() ? 10f : null,
null,
null
)
);
assertHitCount(prepareSearch("foo-vectors").setKnnSearch(knn), 10);
mockLog.assertAllExpectationsMatched();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import java.util.List;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.hamcrest.Matchers.notNullValue;

@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
Expand Down Expand Up @@ -77,29 +78,45 @@ public void testKnnSearchWithScroll() throws Exception {
// test top level knn search
{
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null)));
sourceBuilder.knnSearch(
List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, IVF_FORMAT.isEnabled() ? 10f : null, null, null))
);
executeScrollSearch(client, sourceBuilder, k);
}
// test top level knn search + another query
{
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null)));
sourceBuilder.knnSearch(
List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, IVF_FORMAT.isEnabled() ? 10f : null, null, null))
);
sourceBuilder.query(QueryBuilders.existsQuery("category").boost(10));
executeScrollSearch(client, sourceBuilder, k + 10);
}

// test knn query
{
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null));
sourceBuilder.query(
new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
);
executeScrollSearch(client, sourceBuilder, k * numShards);
}
// test knn query + another query
{
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(
QueryBuilders.boolQuery()
.should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null))
.should(
new KnnVectorQueryBuilder(
VECTOR_FIELD,
new float[] { 0, 0 },
k,
100,
IVF_FORMAT.isEnabled() ? 10f : null,
null,
null
)
)
.should(QueryBuilders.existsQuery("category").boost(10))
);
executeScrollSearch(client, sourceBuilder, k * numShards + 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
Expand Down Expand Up @@ -73,7 +74,8 @@ public void testSimpleNested() throws Exception {
assertResponse(
prepareSearch("test").setKnnSearch(
List.of(
new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, 10f, null, null).innerHit(new InnerHitBuilder())
new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.innerHit(new InnerHitBuilder())
)
).setAllowPartialSearchResults(false),
response -> assertThat(response.getHits().getHits().length, greaterThan(0))
Expand Down Expand Up @@ -155,7 +157,15 @@ private void testNestedWithTwoSegments(boolean flush) {
waitForRelocation(ClusterHealthStatus.GREEN);
refresh();

var knn = new KnnSearchBuilder("nested.vector", new float[] { -0.5f, 90.0f, -10f, 14.8f, -156.0f }, 2, 3, 10f, null, null);
var knn = new KnnSearchBuilder(
"nested.vector",
new float[] { -0.5f, 90.0f, -10f, 14.8f, -156.0f },
2,
3,
IVF_FORMAT.isEnabled() ? 10f : null,
null,
null
);
var request = prepareSearch("test").addFetchField("name").setKnnSearch(List.of(knn));
assertNoFailuresAndResponse(request, response -> {
assertHitCount(response, 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.search.profile.query.RandomQueryGenerator.randomQueryBuilder;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
Expand Down Expand Up @@ -72,7 +73,7 @@ public void testProfileDfs() throws Exception {
new float[] { randomFloat(), randomFloat(), randomFloat() },
randomIntBetween(5, 10),
50,
10f,
IVF_FORMAT.isEnabled() ? 10f : null,
randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)),
randomBoolean() ? null : randomFloat()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -129,7 +130,7 @@ public static TestParams generate() {
randomVector(numDims),
k,
(int) (k * randomFloatBetween(1.0f, 10.0f, true)),
randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true),
IVF_FORMAT.isEnabled() == false ? null : randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true),
new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;

Expand Down Expand Up @@ -84,7 +85,9 @@ public void testTelemetryForRetrievers() throws IOException {
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(
new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, 10f, null, null))
new SearchSourceBuilder().retriever(
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
)
);
}

Expand All @@ -99,7 +102,9 @@ public void testTelemetryForRetrievers() throws IOException {
{
performSearch(
new SearchSourceBuilder().retriever(
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null))
new StandardRetrieverBuilder(
new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
)
)
);
}
Expand All @@ -114,7 +119,9 @@ public void testTelemetryForRetrievers() throws IOException {
// his will record 1 entry for "knn" in `sections`
{
performSearch(
new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null)))
new SearchSourceBuilder().knnSearch(
List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, IVF_FORMAT.isEnabled() ? 10f : null, null, null))
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
builder.field(K_FIELD.getPreferredName(), k);
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);

if (visitPercentage != null) {
if (IVF_FORMAT.isEnabled() && visitPercentage != null) {
builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(K_FIELD.getPreferredName(), k);
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);

if (visitPercentage != null) {
if (IVF_FORMAT.isEnabled() && visitPercentage != null) {
builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -353,8 +354,8 @@ public void testRewriteShardSearchRequestWithRank() {
SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25)
.knnSearch(
List.of(
new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, 10f, null, null),
new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, 10f, null, null)
new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, IVF_FORMAT.isEnabled() ? 10f : null, null, null),
new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
)
)
.rankBuilder(new TestRankBuilder(100));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -63,7 +64,8 @@ public void testKnnSearchRemovedVector() throws IOException {
client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, 10f, null, null).boost(5.0f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.boost(5.0f);
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
Expand Down Expand Up @@ -107,7 +109,9 @@ public void testKnnWithQuery() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).boost(5.0f).queryName("knn");
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.boost(5.0f)
.queryName("knn");
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
Expand Down Expand Up @@ -156,9 +160,8 @@ public void testKnnFilter() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).addFilterQuery(
QueryBuilders.termsQuery("field", "second")
);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.addFilterQuery(QueryBuilders.termsQuery("field", "second"));
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> {
assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
Expand Down Expand Up @@ -199,9 +202,8 @@ public void testKnnFilterWithRewrite() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).addFilterQuery(
QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.addFilterQuery(QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field")));
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> {
assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
Expand Down Expand Up @@ -246,8 +248,10 @@ public void testMultiKnnClauses() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector(20f, 21f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).boost(5.0f);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, 10f, null, null).boost(10.0f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.boost(5.0f);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.boost(10.0f);
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch, knnSearch2))
Expand Down Expand Up @@ -308,8 +312,8 @@ public void testMultiKnnClausesSameDoc() throws IOException {

float[] queryVector = randomVector();
// Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, 10f, null, null);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null);
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
Expand Down Expand Up @@ -383,7 +387,7 @@ public void testKnnFilteredAlias() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, 10f, null, null);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null);
final int expectedHitCount = expectedHits;
assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> {
assertHitCount(response, expectedHitCount);
Expand Down Expand Up @@ -420,7 +424,7 @@ public void testKnnSearchAction() throws IOException {
float[] queryVector = randomVector();
assertResponse(
client().prepareSearch("index1", "index2")
.setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, 10f, null, null))
.setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, IVF_FORMAT.isEnabled() ? 10f : null, null, null))
.setSize(2),
response -> {
// The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard
Expand Down Expand Up @@ -454,7 +458,8 @@ public void testKnnVectorsWith4096Dims() throws IOException {
indicesAdmin().prepareRefresh("index").get();

float[] queryVector = randomVector(4096);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, 10f, null, null).boost(5.0f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, IVF_FORMAT.isEnabled() ? 10f : null, null, null)
.boost(5.0f);
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> {
assertHitCount(response, 3);
assertEquals(3, response.getHits().getHits().length);
Expand Down
Loading