diff --git a/docs/changelog/129440.yaml b/docs/changelog/129440.yaml
new file mode 100644
index 0000000000000..f4999f8c627d3
--- /dev/null
+++ b/docs/changelog/129440.yaml
@@ -0,0 +1,5 @@
+pr: 129440
+summary: Fix filtered knn vector search when query timeouts are enabled
+area: Vector Search
+type: bug
+issues: []
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java
new file mode 100644
index 0000000000000..eda9beece4396
--- /dev/null
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.query;
+
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.vectors.KnnSearchBuilder;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
+
+/**
+ * Integration test for filtered kNN search over an HNSW-indexed {@code dense_vector} field.
+ */
+public class VectorIT extends ESIntegTestCase {
+
+    private static final String INDEX_NAME = "test";
+    private static final String VECTOR_FIELD = "vector";
+    private static final String NUM_ID_FIELD = "num_id";
+
+    /** Overwrites every element of {@code vector} with a fresh {@code randomFloat()}. */
+    private static void randomVector(float[] vector) {
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = randomFloat();
+        }
+    }
+
+    /**
+     * Creates a single-shard, zero-replica index with an HNSW vector field and a long id field,
+     * indexes 150 documents with random 8-dimensional vectors, then force-merges and refreshes.
+     */
+    @Before
+    public void setup() throws IOException {
+        XContentBuilder mapping = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("properties")
+            .startObject(VECTOR_FIELD)
+            .field("type", "dense_vector")
+            .startObject("index_options")
+            .field("type", "hnsw")
+            .endObject()
+            .endObject()
+            .startObject(NUM_ID_FIELD)
+            .field("type", "long")
+            .endObject()
+            .endObject()
+            .endObject();
+
+        Settings settings = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+            .build();
+        prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
+        ensureGreen(INDEX_NAME);
+        for (int i = 0; i < 150; i++) {
+            float[] vector = new float[8];
+            randomVector(vector);
+            prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get();
+        }
+        forceMerge(true);
+        refresh(INDEX_NAME);
+    }
+
+    /**
+     * Runs the same filtered kNN search twice with profiling enabled: once with the index's initial
+     * settings, and once after switching {@code HNSW_FILTER_HEURISTIC} to {@code FANOUT}. Asserts
+     * that both searches return hits and that the FANOUT run reports strictly more vector operations.
+     */
+    public void testFilteredQueryStrategy() {
+        float[] vector = new float[8];
+        randomVector(vector);
+        var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery(
+            QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30)
+        );
+        assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> {
+            assertNotEquals(0, acornResponse.getHits().getHits().length);
+            var profileResults = acornResponse.getProfileResults();
+            // total vector operations reported by the DFS-phase query profiles across all profile results
+            long vectorOpsSum = profileResults.values()
+                .stream()
+                .mapToLong(
+                    pr -> pr.getQueryPhase()
+                        .getSearchProfileDfsPhaseResult()
+                        .getQueryProfileShardResult()
+                        .stream()
+                        .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
+                        .sum()
+                )
+                .sum();
+            // flip the index to the FANOUT filter heuristic before re-running the identical search
+            client().admin()
+                .indices()
+                .prepareUpdateSettings(INDEX_NAME)
+                .setSettings(
+                    Settings.builder()
+                        .put(
+                            DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC.getKey(),
+                            DenseVectorFieldMapper.FilterHeuristic.FANOUT.toString()
+                        )
+                )
+                .get();
+            assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), fanoutResponse -> {
+                assertNotEquals(0, fanoutResponse.getHits().getHits().length);
+                var fanoutProfileResults = fanoutResponse.getProfileResults();
+                long fanoutVectorOpsSum = fanoutProfileResults.values()
+                    .stream()
+                    .mapToLong(
+                        pr -> pr.getQueryPhase()
+                            .getSearchProfileDfsPhaseResult()
+                            .getQueryProfileShardResult()
+                            .stream()
+                            .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
+                            .sum()
+                    )
+                    .sum();
+                assertTrue(
+                    "fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]",
+                    fanoutVectorOpsSum > vectorOpsSum
+                );
+            });
+        });
+    }
+
+}
diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java
index 64b54d3623f04..9c998eb920dc9 100644
--- a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java
+++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java
@@ -26,6 +26,7 @@
 import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.search.VectorScorer;
 import org.apache.lucene.search.suggest.document.CompletionTerms;
+import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.automaton.CompiledAutomaton;
@@ -145,7 +146,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle
                 in.searchNearestVectors(field, target, collector, acceptDocs);
                 return;
             }
-            in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
+            in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
         }
 
         @Override
@@ -163,7 +164,110 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll
                 in.searchNearestVectors(field, target, collector, acceptDocs);
                 return;
             }
-            in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
+            in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
+        }
+
+        /**
+         * Wraps {@code acceptDocs} so that bit lookups periodically check for query cancellation.
+         * A {@code null} or {@link BitSet} input gets a {@link BitSet} wrapper, so the delegate reader
+         * still receives a {@code BitSet}; any other {@link Bits} gets the plain {@link Bits} wrapper.
+         */
+        private Bits createTimeOutCheckingBits(Bits acceptDocs) {
+            if (acceptDocs == null || acceptDocs instanceof BitSet) {
+                return new TimeOutCheckingBitSet((BitSet) acceptDocs);
+            }
+            return new TimeOutCheckingBits(acceptDocs);
+        }
+
+        /**
+         * Read-only {@link BitSet} view over an optional inner bit set. {@link #get(int)} calls
+         * {@code queryCancellation.checkCancelled()} on its first invocation and every 10th one after.
+         * A {@code null} inner set means every doc in {@code [0, maxDoc)} is accepted. Mutators, bit
+         * scans and {@link #ramBytesUsed()} throw {@link UnsupportedOperationException}.
+         */
+        private class TimeOutCheckingBitSet extends BitSet {
+            private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
+            private int calls;
+            private final BitSet inner;
+            private final int maxDoc;
+
+            private TimeOutCheckingBitSet(BitSet inner) {
+                this.inner = inner;
+                this.maxDoc = maxDoc();
+            }
+
+            @Override
+            public void set(int i) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public boolean getAndSet(int i) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public void clear(int i) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public void clear(int startIndex, int endIndex) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public int cardinality() {
+                if (inner == null) {
+                    return maxDoc;
+                }
+                return inner.cardinality();
+            }
+
+            @Override
+            public int approximateCardinality() {
+                if (inner == null) {
+                    return maxDoc;
+                }
+                return inner.approximateCardinality();
+            }
+
+            @Override
+            public int prevSetBit(int index) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public int nextSetBit(int start, int end) {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public long ramBytesUsed() {
+                throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
+            }
+
+            @Override
+            public boolean get(int index) {
+                if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                    queryCancellation.checkCancelled();
+                }
+                if (inner == null) {
+                    // if acceptDocs is null, we assume all docs are accepted
+                    return index >= 0 && index < maxDoc;
+                }
+                return inner.get(index);
+            }
+
+            @Override
+            public int length() {
+                if (inner == null) {
+                    // if acceptDocs is null, we assume all docs are accepted
+                    return maxDoc;
+                }
+                // report the wrapped set's real size; returning 0 would contradict get() and cardinality()
+                return inner.length();
+            }
         }
 
         private class TimeOutCheckingBits implements Bits {
@@ -171,7 +275,7 @@ private class TimeOutCheckingBits implements Bits {
             private final Bits updatedAcceptDocs;
             private int calls;
 
-            TimeOutCheckingBits(Bits acceptDocs) {
+            private TimeOutCheckingBits(Bits acceptDocs) {
                 // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
                 // match all docs to allow timeout checking.
                 this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java
index b754adc6c9620..50bb1b1a913e1 100644
--- a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java
+++ b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java
@@ -137,4 +137,9 @@ public int hashCode() {
     public String toString() {
         return Strings.toString(this);
     }
+
+    /** Returns the raw {@code vectorOperationsCount} field (boxed; presumably nullable, verify against the constructor). */
+    public Long getVectorOperationsCount() {
+        return vectorOperationsCount;
+    }
 }