diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java index eda9beece4396..82f63ebbbee12 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.query; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -30,10 +31,11 @@ public class VectorIT extends ESIntegTestCase { private static final String VECTOR_FIELD = "vector"; private static final String NUM_ID_FIELD = "num_id"; - private static void randomVector(float[] vector) { + private static void randomVector(float[] vector, int constant) { for (int i = 0; i < vector.length; i++) { - vector[i] = randomFloat(); + vector[i] = randomFloat() * constant; } + VectorUtil.l2normalize(vector); } @Before @@ -43,6 +45,7 @@ public void setup() throws IOException { .startObject("properties") .startObject(VECTOR_FIELD) .field("type", "dense_vector") + .field("similarity", "dot_product") .startObject("index_options") .field("type", "hnsw") .endObject() @@ -59,9 +62,9 @@ public void setup() throws IOException { .build(); prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get(); ensureGreen(INDEX_NAME); + float[] vector = new float[16]; for (int i = 0; i < 150; i++) { - float[] vector = new float[8]; - randomVector(vector); + randomVector(vector, i % 25 + 1); prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get(); } forceMerge(true); @@ -69,10 +72,11 @@ public void setup() throws IOException { } public void testFilteredQueryStrategy() { - float[] vector = new float[8]; - randomVector(vector); + float[] vector = new float[16]; + randomVector(vector, 25); + int upperLimit = 35; var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery( - QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30) + QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(35) ); assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> { assertNotEquals(0, acornResponse.getHits().getHits().length); @@ -116,6 +120,8 @@ public void testFilteredQueryStrategy() { assertTrue( "fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]", fanoutVectorOpsSum > vectorOpsSum + // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit + || (fanoutVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1) ); }); });