Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.search.query;

import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
Expand All @@ -30,10 +31,11 @@ public class VectorIT extends ESIntegTestCase {
private static final String VECTOR_FIELD = "vector";
private static final String NUM_ID_FIELD = "num_id";

private static void randomVector(float[] vector) {
private static void randomVector(float[] vector, int constant) {
for (int i = 0; i < vector.length; i++) {
vector[i] = randomFloat();
vector[i] = randomFloat() * constant;
}
VectorUtil.l2normalize(vector);
}

@Before
Expand All @@ -43,6 +45,7 @@ public void setup() throws IOException {
.startObject("properties")
.startObject(VECTOR_FIELD)
.field("type", "dense_vector")
.field("similarity", "dot_product")
.startObject("index_options")
.field("type", "hnsw")
.endObject()
Expand All @@ -59,20 +62,21 @@ public void setup() throws IOException {
.build();
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
ensureGreen(INDEX_NAME);
float[] vector = new float[16];
for (int i = 0; i < 150; i++) {
float[] vector = new float[8];
randomVector(vector);
randomVector(vector, i % 25 + 1);
prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get();
}
forceMerge(true);
refresh(INDEX_NAME);
}

public void testFilteredQueryStrategy() {
float[] vector = new float[8];
randomVector(vector);
float[] vector = new float[16];
randomVector(vector, 25);
int upperLimit = 35;
var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery(
QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30)
QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(35)
);
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> {
assertNotEquals(0, acornResponse.getHits().getHits().length);
Expand Down Expand Up @@ -116,6 +120,8 @@ public void testFilteredQueryStrategy() {
assertTrue(
"fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]",
fanoutVectorOpsSum > vectorOpsSum
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
|| (fanoutVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
);
});
});
Expand Down