Skip to content

Commit 2754697

Browse files
committed
Improve testing to add random indexing, and fix similarity
1 parent 9ff5ed8 commit 2754697

File tree

1 file changed: 65 additions and 32 deletions

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1212

13-
import org.elasticsearch.action.bulk.BulkRequestBuilder;
14-
import org.elasticsearch.action.index.IndexRequest;
15-
import org.elasticsearch.action.support.WriteRequest;
13+
import org.elasticsearch.action.index.IndexRequestBuilder;
14+
import org.elasticsearch.cluster.metadata.IndexMetadata;
1615
import org.elasticsearch.common.settings.Settings;
1716
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
1817
import org.junit.Before;
1918

19+
import java.util.ArrayList;
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Locale;
@@ -39,6 +39,8 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
3939
);
4040

4141
private final String indexType;
42+
private int numDims;
43+
private int numDocs;
4244

4345
@ParametersFactory
4446
public static Iterable<Object[]> parameters() throws Exception {
@@ -49,15 +51,7 @@ public DenseVectorFieldTypeIT(@Name("indexType") String indexType) {
4951
this.indexType = indexType;
5052
}
5153

52-
private static Map<Integer, List<Float>> DOC_VALUES = new HashMap<>();
53-
static {
54-
DOC_VALUES.put(1, List.of(1.0f, 2.0f, 3.0f));
55-
DOC_VALUES.put(2, List.of(4.0f, 5.0f, 6.0f));
56-
DOC_VALUES.put(3, List.of(7.0f, 8.0f, 9.0f));
57-
DOC_VALUES.put(4, List.of(10.0f, 11.0f, 12.0f));
58-
DOC_VALUES.put(5, List.of(13.0f, 14.0f, 15.0f));
59-
DOC_VALUES.put(6, List.of(16.0f, 17.0f, 18.0f));
60-
}
54+
private Map<Integer, List<Float>> indexedDocs = new HashMap<>();
6155

6256
public void testRetrieveFieldType() {
6357
var query = """
@@ -66,57 +60,96 @@ public void testRetrieveFieldType() {
6660

6761
try (var resp = run(query)) {
6862
assertColumnNames(resp.columns(), List.of("id", "vector"));
69-
assertColumnTypes(resp.columns(), List.of("long", "double"));
63+
assertColumnTypes(resp.columns(), List.of("integer", "dense_vector"));
7064
}
7165
}
7266

7367
@SuppressWarnings("unchecked")
74-
public void testRetrieveDenseVectorFieldData() {
68+
public void testRetrieveOrderedDenseVectorFieldData() {
7569
var query = """
7670
FROM test
7771
| SORT id ASC
7872
""";
7973

8074
try (var resp = run(query)) {
8175
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
82-
DOC_VALUES.forEach((id, vector) -> {
83-
var values = valuesList.get(id - 1);
84-
assertEquals(id.intValue(), ((Long) values.get(0)).intValue());
76+
indexedDocs.forEach((id, vector) -> {
77+
var values = valuesList.get(id);
78+
assertEquals(id, values.get(0));
8579
List<Double> vectors = (List<Double>) values.get(1);
8680
assertEquals(vector.size(), vectors.size());
8781
for (int i = 0; i < vector.size(); i++) {
88-
assertEquals((float) vector.get(i), vectors.get(i).floatValue(), 0F);
82+
assertEquals(vector.get(i), vectors.get(i).floatValue(), 0F);
83+
}
84+
});
85+
}
86+
}
87+
88+
@SuppressWarnings("unchecked")
89+
public void testRetrieveUnOrderedDenseVectorFieldData() {
90+
var query = "FROM test";
91+
92+
try (var resp = run(query)) {
93+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
94+
assertEquals(indexedDocs.size(), valuesList.size());
95+
valuesList.forEach(value -> {
96+
assertEquals(2, value.size());
97+
Integer id = (Integer) value.get(0);
98+
List<Double> vector = (List<Double>) value.get(1);
99+
100+
List<Float> expectedVector = indexedDocs.get(id);
101+
for (int i = 0; i < vector.size(); i++) {
102+
assertEquals(expectedVector.get(i), vector.get(i).floatValue(), 0F);
89103
}
90104
});
91105
}
92106
}
93107

94108
@Before
95109
public void setup() {
110+
numDims = randomIntBetween(64, 256);
111+
numDocs = randomIntBetween(10, 100);
112+
for (int i = 0; i < numDocs; i++) {
113+
List<Float> vector = new ArrayList<>(numDims);
114+
for (int j = 0; j < numDims; j++) {
115+
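// vector.add(randomFloat()); NOTE(review): random values are disabled, so every component of every indexed vector is 1.0f — confirm this is intended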
116+
vector.add(1.0f);
117+
}
118+
indexedDocs.put(i, vector);
119+
}
120+
96121
var indexName = "test";
97122
var client = client().admin().indices();
98123
var mapping = String.format(Locale.ROOT, """
99-
"id": integer,
124+
{
125+
"properties": {
126+
"id": {
127+
"type": "integer"
128+
},
100129
"vector": {
101-
"type": "dense_vector",
102-
"index_options": {
103-
"type": "%s"
104-
}
130+
"type": "dense_vector",
131+
"similarity": "l2_norm",
132+
"index_options": {
133+
"type": "%s"
134+
}
105135
}
136+
}
137+
}
106138
""", indexType);
139+
Settings settings = Settings.builder()
140+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
141+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
142+
.build();
107143
var CreateRequest = client.prepareCreate(indexName)
108144
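.setSettings(Settings.builder().put("index.number_of_shards", 1)) // NOTE(review): a second setSettings(settings) call follows on this builder, so this one is probably superseded — confirm and remove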
109-
.setMapping(mapping);
145+
.setMapping(mapping)
146+
.setSettings(settings);
110147
assertAcked(CreateRequest);
111148

112-
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
113-
for (var entry : DOC_VALUES.entrySet()) {
114-
bulkRequestBuilder.add(
115-
new IndexRequest(indexName).id(entry.getKey().toString()).source("id", entry.getKey(), "vector", entry.getValue())
116-
);
149+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
150+
for (int i = 0; i < numDocs; i++) {
151+
docs[i] = prepareIndex("test").setId("" + i).setSource("id", i, "vector", indexedDocs.get(i));
117152
}
118-
119-
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
120-
ensureYellow(indexName);
153+
indexRandom(true, docs);
121154
}
122155
}

0 commit comments

Comments
 (0)