Skip to content

Commit 31faca1

Browse files
committed
Add IT
1 parent 0c18eba commit 31faca1

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.vector;
9+
10+
import org.apache.lucene.index.VectorSimilarityFunction;
11+
import org.elasticsearch.action.index.IndexRequestBuilder;
12+
import org.elasticsearch.cluster.metadata.IndexMetadata;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xpack.esql.EsqlTestUtils;
17+
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
18+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
19+
import org.junit.Before;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.Set;
25+
26+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
27+
28+
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
29+
30+
private static final Set<String> DENSE_VECTOR_INDEX_TYPES = Set.of(
31+
/* "int8_hnsw",
32+
"hnsw",
33+
"int4_hnsw",
34+
"bbq_hnsw",
35+
"int8_flat",
36+
"int4_flat",
37+
"bbq_flat",*/
38+
"flat"
39+
);
40+
41+
@SuppressWarnings("unchecked")
42+
public void testCosineSimilarity() {
43+
var query = """
44+
FROM test
45+
| EVAL similarity = v_cosine_similarity(left_vector, right_vector)
46+
| KEEP id, left_vector, right_vector, similarity
47+
""";
48+
49+
try (var resp = run(query)) {
50+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
51+
valuesList.forEach(values -> {
52+
List<Float> leftVector = (List<Float>) values.get(1);
53+
float[] leftScratch = new float[leftVector.size()];
54+
for (int i = 0; i < leftVector.size(); i++) {
55+
leftScratch[i] = leftVector.get(i);
56+
}
57+
List<Float> rightVector = (List<Float>) values.get(2);
58+
float[] rightScratch = new float[rightVector.size()];
59+
for (int i = 0; i < rightVector.size(); i++) {
60+
rightScratch[i] = rightVector.get(i);
61+
}
62+
Double similarity = (Double) values.get(3);
63+
assertNotNull(similarity);
64+
65+
float expectedSimilarity = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
66+
assertEquals(expectedSimilarity, similarity, 0.0001);
67+
});
68+
}
69+
}
70+
71+
@Before
72+
public void setup() throws IOException {
73+
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
74+
75+
createIndexWithDenseVector("test");
76+
77+
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
78+
int numDocs = randomIntBetween(10, 100);
79+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
80+
for (int i = 0; i < numDocs; i++) {
81+
List<Float> leftVector = new ArrayList<>(numDims);
82+
for (int j = 0; j < numDims; j++) {
83+
leftVector.add(randomFloat());
84+
}
85+
List<Float> rightVector = new ArrayList<>(numDims);
86+
for (int j = 0; j < numDims; j++) {
87+
rightVector.add(randomFloat());
88+
}
89+
docs[i] = prepareIndex("test").setId("" + i)
90+
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
91+
}
92+
93+
indexRandom(true, docs);
94+
}
95+
96+
private void createIndexWithDenseVector(String indexName) throws IOException {
97+
var client = client().admin().indices();
98+
XContentBuilder mapping = XContentFactory.jsonBuilder()
99+
.startObject()
100+
.startObject("properties")
101+
.startObject("id")
102+
.field("type", "integer")
103+
.endObject();
104+
createDenseVectorField(mapping, "left_vector");
105+
createDenseVectorField(mapping, "right_vector");
106+
mapping.endObject().endObject();
107+
Settings.Builder settingsBuilder = Settings.builder()
108+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
109+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
110+
111+
var CreateRequest = client.prepareCreate(indexName)
112+
.setSettings(Settings.builder().put("index.number_of_shards", 1))
113+
.setMapping(mapping)
114+
.setSettings(settingsBuilder.build());
115+
assertAcked(CreateRequest);
116+
}
117+
118+
private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
119+
mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
120+
mapping.endObject();
121+
}
122+
}

0 commit comments

Comments
 (0)