Commit 2aaa7fb

[9.0] Semantic Text Rolling Upgrade Tests (elastic#126548) (elastic#127863)
1 parent 3eb6ffc commit 2aaa7fb

5 files changed, +348 −2 lines
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.mapper.vectors;

import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;

import org.elasticsearch.inference.SimilarityMeasure;

import java.util.List;
import java.util.Random;

public class DenseVectorFieldMapperTestUtils {
    private DenseVectorFieldMapperTestUtils() {}

    public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
        return switch (elementType) {
            case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
            case BIT -> List.of(SimilarityMeasure.L2_NORM);
        };
    }

    public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
        return switch (elementType) {
            case FLOAT, BYTE -> dimensions;
            case BIT -> {
                assert dimensions % Byte.SIZE == 0;
                yield dimensions / Byte.SIZE;
            }
        };
    }

    public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
        if (max < 1) {
            throw new IllegalArgumentException("max must be at least 1");
        }

        return switch (elementType) {
            case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
            case BIT -> {
                if (max < 8) {
                    throw new IllegalArgumentException("max must be at least 8 for bit vectors");
                }

                // Generate a random dimension count that is a multiple of 8
                int maxEmbeddingLength = max / 8;
                yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8;
            }
        };
    }

    private static Random random() {
        return RandomizedContext.current().getRandom();
    }
}

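For orientation, a minimal sketch of how these helpers compose, assuming the calling code sits inside a test method run by the randomized-testing runner (so that RandomizedContext.current() is available); the element type and bound below are illustrative, not part of this commit:

    // BIT vectors pack 8 dimensions per byte, so dimension counts are multiples of 8
    // and only L2_NORM similarity is supported.
    var elementType = DenseVectorFieldMapper.ElementType.BIT;
    int dims = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64); // 8, 16, ..., 64
    int length = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, dims);     // dims / 8
    List<SimilarityMeasure> sims = DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType); // [L2_NORM]
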
x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
 apply plugin: 'elasticsearch.internal-yaml-rest-test'
+apply plugin: 'elasticsearch.internal-test-artifact'

 restResources {
   restApi {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java

Lines changed: 38 additions & 2 deletions
@@ -12,6 +12,7 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
@@ -25,9 +26,12 @@
 import org.elasticsearch.xpack.inference.services.ServiceUtils;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;

+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
 import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -39,9 +43,41 @@ public static TestModel createRandomInstance() {
     }

     public static TestModel createRandomInstance(TaskType taskType) {
-        var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomInt(64) : null;
-        var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null;
+        return createRandomInstance(taskType, null);
+    }
+
+    public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
+        // Use a max dimension count that has a reasonable probability of being compatible with BBQ
+        return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
+    }
+
+    public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
         var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
+        var dimensions = taskType == TaskType.TEXT_EMBEDDING
+            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
+            : null;
+
+        SimilarityMeasure similarity = null;
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            List<SimilarityMeasure> supportedSimilarities = new ArrayList<>(
+                DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)
+            );
+            if (excludedSimilarities != null) {
+                supportedSimilarities.removeAll(excludedSimilarities);
+            }
+
+            if (supportedSimilarities.isEmpty()) {
+                throw new IllegalArgumentException(
+                    "No supported similarities for combination of element type ["
+                        + elementType
+                        + "] and excluded similarities "
+                        + (excludedSimilarities == null ? List.of() : excludedSimilarities)
+                );
+            }
+
+            similarity = randomFrom(supportedSimilarities);
+        }
+
         return new TestModel(
             randomAlphaOfLength(4),
             taskType,
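
The new overloads let callers steer the randomly chosen similarity. The upgrade test added later in this commit uses exactly this to avoid DOT_PRODUCT, since the randomly generated embeddings are not unit length. A minimal sketch, grounded in that test's usage:

    // Dense model whose random similarity will never be DOT_PRODUCT:
    Model dense = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));

    // Sparse models carry no dimensions/similarity, so the single-argument form is unchanged:
    Model sparse = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);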

x-pack/qa/rolling-upgrade/build.gradle

Lines changed: 2 additions & 0 deletions
@@ -14,9 +14,11 @@ apply plugin: 'elasticsearch.bwc-test'
 apply plugin: 'elasticsearch.rest-resources'

 dependencies {
+  testImplementation testArtifact(project(':server'))
   testImplementation testArtifact(project(xpackModule('core')))
   testImplementation project(':x-pack:qa')
   testImplementation project(':modules:reindex')
+  testImplementation testArtifact(project(xpackModule('inference')))
 }

 restResources {
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.upgrades;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.rest.ObjectPath;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;

public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
    private static final String SPARSE_FIELD = "sparse_field";
    private static final String DENSE_FIELD = "dense_field";

    private static final String DOC_1_ID = "doc_1";
    private static final String DOC_2_ID = "doc_2";
    private static final Map<String, List<String>> DOC_VALUES = Map.of(
        DOC_1_ID,
        List.of("a test value", "with multiple test values"),
        DOC_2_ID,
        List.of("another test value")
    );

    private static Model SPARSE_MODEL;
    private static Model DENSE_MODEL;

    private final boolean useLegacyFormat;

    @BeforeClass
    public static void beforeClass() {
        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
        // Exclude dot product because we are not producing unit length vectors
        DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
    }

    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        return List.of(new Object[] { true }, new Object[] { false });
    }

    public void testSemanticTextOperations() throws Exception {
        switch (CLUSTER_TYPE) {
            case OLD -> createAndPopulateIndex();
            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
        }
    }

    private void createAndPopulateIndex() throws IOException {
        final String indexName = getIndexName();
        final String mapping = Strings.format("""
            {
              "properties": {
                "%s": {
                  "type": "semantic_text",
                  "inference_id": "%s"
                },
                "%s": {
                  "type": "semantic_text",
                  "inference_id": "%s"
                }
              }
            }
            """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());

        CreateIndexResponse response = createIndex(
            indexName,
            Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat).build(),
            mapping
        );
        assertThat(response.isAcknowledged(), equalTo(true));

        indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
    }

    private void performIndexQueryHighlightOps() throws IOException {
        indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));

        ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
        assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);

        ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
        assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
    }

    private String getIndexName() {
        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
    }

    private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
        final String indexName = getIndexName();
        final SemanticTextField sparseFieldValue = randomSemanticText(
            useLegacyFormat,
            SPARSE_FIELD,
            SPARSE_MODEL,
            semanticTextFieldValue,
            XContentType.JSON
        );
        final SemanticTextField denseFieldValue = randomSemanticText(
            useLegacyFormat,
            DENSE_FIELD,
            DENSE_MODEL,
            semanticTextFieldValue,
            XContentType.JSON
        );

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        if (useLegacyFormat == false) {
            builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
            builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
        }
        addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
        builder.endObject();

        RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
        Request request = new Request("POST", indexName + "/_doc/" + id);
        request.setJsonEntity(Strings.toString(builder));
        request.setOptions(requestOptions);

        Response response = client().performRequest(request);
        assertOK(response);
    }

    private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException {
        // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
        // query
        final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field);
        final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) {
            case SPARSE_EMBEDDING -> {
                List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
                yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null);
            }
            case TEXT_EMBEDDING -> {
                DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType();
                int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
                    elementType,
                    fieldModel.getServiceSettings().dimensions()
                );

                // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the
                // document vector
                float[] queryVector = new float[embeddingLength];
                if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
                    Arrays.fill(queryVector, -128.0f);
                } else {
                    Arrays.fill(queryVector, 1.0f);
                }

                yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null);
            }
            default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]");
        };

        NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
            SemanticTextField.getChunksFieldName(field),
            innerQueryBuilder,
            ScoreMode.Max
        );

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        builder.field("query", nestedQueryBuilder);
        if (numOfHighlightFragments != null) {
            HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
            highlightField.numOfFragments(numOfHighlightFragments);

            HighlightBuilder highlightBuilder = new HighlightBuilder();
            highlightBuilder.field(highlightField);

            builder.field("highlight", highlightBuilder);
        }
        builder.endObject();

        Request request = new Request("GET", getIndexName() + "/_search");
        request.setJsonEntity(Strings.toString(builder));

        Response response = client().performRequest(request);
        return assertOKAndCreateObjectPath(response);
    }

    private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));

        Set<String> docIds = new HashSet<>();
        List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
        for (Map<String, Object> hit : hits) {
            String id = ObjectPath.evaluate(hit, "_id");
            assertThat(id, notNullValue());
            docIds.add(id);

            List<String> expectedHighlight = DOC_VALUES.get(id);
            assertThat(expectedHighlight, notNullValue());
            assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
        }

        assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
    }
}
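
A note on the "pass-through" comment above: the test only needs the knn query to return both documents (k is DOC_VALUES.size(), i.e. 2), and for a fixed query vector q an l2-style distance varies only with the document vector d, so filling q with a constant still produces distinct, deterministic scores per document. A tiny standalone illustration of that property (plain Java, not part of this commit):

    // With q held constant, the distance is a function of d alone.
    static double l2Distance(float[] d, float[] q) {
        double sum = 0.0;
        for (int i = 0; i < d.length; i++) {
            double diff = d[i] - q[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }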
