Skip to content

Commit a3a64ea

Browse files
[8.19] Semantic Text Rolling Upgrade Tests (elastic#126548) (elastic#127748)
* Semantic Text Rolling Upgrade Tests (elastic#126548) * Fix test failures --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent feb8beb commit a3a64ea

File tree

4 files changed

+271
-1
lines changed

4 files changed

+271
-1
lines changed

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams
99
apply plugin: 'elasticsearch.internal-es-plugin'
1010
apply plugin: 'elasticsearch.internal-cluster-test'
1111
apply plugin: 'elasticsearch.internal-yaml-rest-test'
12+
apply plugin: 'elasticsearch.internal-test-artifact'
1213

1314
restResources {
1415
restApi {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.List;
3232
import java.util.Map;
3333

34+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
3435
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
3536
import static org.elasticsearch.test.ESTestCase.randomFrom;
3637
import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -46,9 +47,14 @@ public static TestModel createRandomInstance(TaskType taskType) {
4647
}
4748

4849
public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
50+
// Use a max dimension count that has a reasonable probability of being compatible with BBQ
51+
return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
52+
}
53+
54+
public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
4955
var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
5056
var dimensions = taskType == TaskType.TEXT_EMBEDDING
51-
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
57+
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
5258
: null;
5359

5460
SimilarityMeasure similarity = null;

x-pack/qa/rolling-upgrade/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test'
88
apply plugin: 'elasticsearch.rest-resources'
99

1010
dependencies {
  // Server test artifact pulled in for shared mapper test utilities (e.g. dense vector helpers)
  testImplementation testArtifact(project(':server'))
  testImplementation testArtifact(project(xpackModule('core')))
  testImplementation project(':x-pack:qa')
  testImplementation project(':modules:reindex')
  // Inference test artifact provides TestModel and semantic_text test fixtures used by the
  // rolling-upgrade semantic text tests
  testImplementation testArtifact(project(xpackModule('inference')))
}
1517

1618
restResources {
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.upgrades;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11+
12+
import org.apache.lucene.search.join.ScoreMode;
13+
import org.elasticsearch.Version;
14+
import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
15+
import org.elasticsearch.client.Request;
16+
import org.elasticsearch.client.RequestOptions;
17+
import org.elasticsearch.client.Response;
18+
import org.elasticsearch.common.Strings;
19+
import org.elasticsearch.common.settings.Settings;
20+
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
21+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
22+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
23+
import org.elasticsearch.index.query.NestedQueryBuilder;
24+
import org.elasticsearch.index.query.QueryBuilder;
25+
import org.elasticsearch.inference.Model;
26+
import org.elasticsearch.inference.SimilarityMeasure;
27+
import org.elasticsearch.inference.TaskType;
28+
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
29+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
30+
import org.elasticsearch.test.rest.ObjectPath;
31+
import org.elasticsearch.xcontent.XContentBuilder;
32+
import org.elasticsearch.xcontent.XContentFactory;
33+
import org.elasticsearch.xcontent.XContentType;
34+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
35+
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
36+
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
37+
import org.elasticsearch.xpack.inference.model.TestModel;
38+
import org.junit.BeforeClass;
39+
40+
import java.io.IOException;
41+
import java.util.ArrayList;
42+
import java.util.Arrays;
43+
import java.util.HashSet;
44+
import java.util.List;
45+
import java.util.Map;
46+
import java.util.Set;
47+
48+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
49+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
50+
import static org.hamcrest.CoreMatchers.equalTo;
51+
import static org.hamcrest.CoreMatchers.notNullValue;
52+
53+
/**
 * Rolling-upgrade test for {@code semantic_text} fields. On the OLD cluster it creates an
 * index with one sparse and one dense semantic_text field and indexes a document; on MIXED
 * and UPGRADED clusters it indexes another document, then queries and highlights both fields,
 * verifying results survive the upgrade. Parameterized over the legacy vs. new (8.18+)
 * semantic_text index format.
 */
public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
    private static final String SPARSE_FIELD = "sparse_field";
    private static final String DENSE_FIELD = "dense_field";
    // Parsed once so version-gating checks below can use onOrAfter comparisons
    private static final Version UPGRADE_FROM_VERSION_PARSED = Version.fromString(UPGRADE_FROM_VERSION);

    private static final String DOC_1_ID = "doc_1";
    private static final String DOC_2_ID = "doc_2";
    // Expected field values per doc id; also serves as the expected highlight fragments
    private static final Map<String, List<String>> DOC_VALUES = Map.of(
        DOC_1_ID,
        List.of("a test value", "with multiple test values"),
        DOC_2_ID,
        List.of("another test value")
    );

    // Initialized in beforeClass so the same random models are reused across test phases
    private static Model SPARSE_MODEL;
    private static Model DENSE_MODEL;

    // true => legacy semantic_text document format; false => new format (8.18+)
    private final boolean useLegacyFormat;

    @BeforeClass
    public static void beforeClass() {
        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
        // Exclude dot product because we are not producing unit length vectors
        DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
    }

    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    /** Runs the suite once with the legacy format, and also with the new format when upgrading from 8.18+. */
    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        List<Object[]> parameters = new ArrayList<>();
        parameters.add(new Object[] { true });
        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
            // New semantic text format added in 8.18
            parameters.add(new Object[] { false });
        }
        return parameters;
    }

    public void testSemanticTextOperations() throws Exception {
        // semantic_text was introduced in 8.15; skip older upgrade-from versions
        assumeTrue("Upgrade from version supports semantic text", UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_15_0));
        switch (CLUSTER_TYPE) {
            case OLD -> createAndPopulateIndex();
            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
        }
    }

    /**
     * OLD-cluster phase: creates the test index with sparse and dense semantic_text mappings
     * and indexes the first document.
     */
    private void createAndPopulateIndex() throws IOException {
        final String indexName = getIndexName();
        final String mapping = Strings.format("""
            {
              "properties": {
                "%s": {
                  "type": "semantic_text",
                  "inference_id": "%s"
                },
                "%s": {
                  "type": "semantic_text",
                  "inference_id": "%s"
                }
              }
            }
            """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());

        Settings.Builder settingsBuilder = Settings.builder();
        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
            // The format setting only exists on 8.18+; setting it on older clusters would fail
            settingsBuilder.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
        }

        CreateIndexResponse response = createIndex(indexName, settingsBuilder.build(), mapping);
        assertThat(response.isAcknowledged(), equalTo(true));

        indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
    }

    /**
     * MIXED/UPGRADED-cluster phase: indexes the second document, then queries and highlights
     * both semantic_text fields, asserting on the responses.
     */
    private void performIndexQueryHighlightOps() throws IOException {
        indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));

        ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
        assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);

        ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
        assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
    }

    // Separate index per format parameter so the two parameterized runs don't collide
    private String getIndexName() {
        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
    }

    /**
     * Indexes a document with pre-generated (random) inference results for both fields,
     * bypassing real inference. In the new format the raw text is written to the field itself
     * and inference results go to the metadata field; in the legacy format only the inference
     * results are written.
     */
    private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
        final String indexName = getIndexName();
        final SemanticTextField sparseFieldValue = randomSemanticText(
            useLegacyFormat,
            SPARSE_FIELD,
            SPARSE_MODEL,
            null,
            semanticTextFieldValue,
            XContentType.JSON
        );
        final SemanticTextField denseFieldValue = randomSemanticText(
            useLegacyFormat,
            DENSE_FIELD,
            DENSE_MODEL,
            null,
            semanticTextFieldValue,
            XContentType.JSON
        );

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        if (useLegacyFormat == false) {
            builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
            builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
        }
        addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
        builder.endObject();

        // refresh=true so the doc is immediately visible to the searches that follow
        RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
        Request request = new Request("POST", indexName + "/_doc/" + id);
        request.setJsonEntity(Strings.toString(builder));
        request.setOptions(requestOptions);

        Response response = client().performRequest(request);
        assertOK(response);
    }

    /**
     * Searches the index with a nested query equivalent to a semantic query, optionally
     * requesting highlights on the semantic_text field.
     *
     * @param numOfHighlightFragments number of highlight fragments, or null to skip highlighting
     */
    private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException {
        // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
        // query
        final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field);
        final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) {
            case SPARSE_EMBEDDING -> {
                // One weighted token per whitespace-separated query term, all with weight 1
                List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
                yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null);
            }
            case TEXT_EMBEDDING -> {
                DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType();
                int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
                    elementType,
                    fieldModel.getServiceSettings().dimensions()
                );

                // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document
                // vector
                float[] queryVector = new float[embeddingLength];
                if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
                    // BIT vectors use byte semantics; -128 sets every bit in each byte
                    Arrays.fill(queryVector, -128.0f);
                } else {
                    Arrays.fill(queryVector, 1.0f);
                }

                // k = total doc count so all docs can match
                yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null);
            }
            default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]");
        };

        NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
            SemanticTextField.getChunksFieldName(field),
            innerQueryBuilder,
            ScoreMode.Max
        );

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        builder.field("query", nestedQueryBuilder);
        if (numOfHighlightFragments != null) {
            HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
            highlightField.numOfFragments(numOfHighlightFragments);

            HighlightBuilder highlightBuilder = new HighlightBuilder();
            highlightBuilder.field(highlightField);

            builder.field("highlight", highlightBuilder);
        }
        builder.endObject();

        Request request = new Request("GET", getIndexName() + "/_search");
        request.setJsonEntity(Strings.toString(builder));

        Response response = client().performRequest(request);
        return assertOKAndCreateObjectPath(response);
    }

    /**
     * Asserts that both documents were returned and, where supported, that each hit's
     * highlight matches the original field values.
     */
    private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));

        Set<String> docIds = new HashSet<>();
        List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
        for (Map<String, Object> hit : hits) {
            String id = ObjectPath.evaluate(hit, "_id");
            assertThat(id, notNullValue());
            docIds.add(id);

            if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0) || CLUSTER_TYPE == ClusterType.UPGRADED) {
                // Semantic highlighting only functions reliably on clusters where all nodes are 8.18.0 or later
                List<String> expectedHighlight = DOC_VALUES.get(id);
                assertThat(expectedHighlight, notNullValue());
                assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
            }
        }

        assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
    }
}

0 commit comments

Comments (0)