Skip to content

Commit 2eb9fab

Browse files
authored
Semantic Text Rolling Upgrade Tests (elastic#126548)
1 parent a2e580f commit 2eb9fab

File tree

4 files changed

+257
-1
lines changed

4 files changed

+257
-1
lines changed

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
apply plugin: 'elasticsearch.internal-es-plugin'
99
apply plugin: 'elasticsearch.internal-cluster-test'
1010
apply plugin: 'elasticsearch.internal-yaml-rest-test'
11+
apply plugin: 'elasticsearch.internal-test-artifact'
1112

1213
restResources {
1314
restApi {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.List;
3232
import java.util.Map;
3333

34+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
3435
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
3536
import static org.elasticsearch.test.ESTestCase.randomFrom;
3637
import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -46,9 +47,14 @@ public static TestModel createRandomInstance(TaskType taskType) {
4647
}
4748

4849
/**
 * Creates a random {@link TestModel} for the given task type, excluding the given similarity measures.
 * <p>
 * Delegates to the three-arg overload with a max dimension count of {@code BBQ_MIN_DIMS * 2} so that
 * randomly generated text-embedding models have a reasonable probability of being BBQ-compatible.
 */
public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
    // Use a max dimension count that has a reasonable probability of being compatible with BBQ
    return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
}
53+
54+
public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
4955
var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
5056
var dimensions = taskType == TaskType.TEXT_EMBEDDING
51-
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
57+
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
5258
: null;
5359

5460
SimilarityMeasure similarity = null;

x-pack/qa/rolling-upgrade/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ apply plugin: 'elasticsearch.bwc-test'
1414
apply plugin: 'elasticsearch.rest-resources'
1515

1616
dependencies {
17+
testImplementation testArtifact(project(':server'))
1718
testImplementation testArtifact(project(xpackModule('core')))
1819
testImplementation project(':x-pack:qa')
1920
testImplementation project(':modules:reindex')
21+
testImplementation testArtifact(project(xpackModule('inference')))
2022
}
2123

2224
restResources {
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.upgrades;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11+
12+
import org.apache.lucene.search.join.ScoreMode;
13+
import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
14+
import org.elasticsearch.client.Request;
15+
import org.elasticsearch.client.RequestOptions;
16+
import org.elasticsearch.client.Response;
17+
import org.elasticsearch.common.Strings;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
20+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
21+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
22+
import org.elasticsearch.index.query.NestedQueryBuilder;
23+
import org.elasticsearch.index.query.QueryBuilder;
24+
import org.elasticsearch.inference.Model;
25+
import org.elasticsearch.inference.SimilarityMeasure;
26+
import org.elasticsearch.inference.TaskType;
27+
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
28+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
29+
import org.elasticsearch.test.rest.ObjectPath;
30+
import org.elasticsearch.xcontent.XContentBuilder;
31+
import org.elasticsearch.xcontent.XContentFactory;
32+
import org.elasticsearch.xcontent.XContentType;
33+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
34+
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
35+
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
36+
import org.elasticsearch.xpack.inference.model.TestModel;
37+
import org.junit.BeforeClass;
38+
39+
import java.io.IOException;
40+
import java.util.Arrays;
41+
import java.util.HashSet;
42+
import java.util.List;
43+
import java.util.Map;
44+
import java.util.Set;
45+
46+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
47+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
48+
import static org.hamcrest.CoreMatchers.equalTo;
49+
import static org.hamcrest.CoreMatchers.notNullValue;
50+
51+
public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
52+
private static final String INDEX_BASE_NAME = "semantic_text_test_index";
53+
private static final String SPARSE_FIELD = "sparse_field";
54+
private static final String DENSE_FIELD = "dense_field";
55+
56+
private static final String DOC_1_ID = "doc_1";
57+
private static final String DOC_2_ID = "doc_2";
58+
private static final Map<String, List<String>> DOC_VALUES = Map.of(
59+
DOC_1_ID,
60+
List.of("a test value", "with multiple test values"),
61+
DOC_2_ID,
62+
List.of("another test value")
63+
);
64+
65+
private static Model SPARSE_MODEL;
66+
private static Model DENSE_MODEL;
67+
68+
private final boolean useLegacyFormat;
69+
70+
@BeforeClass
71+
public static void beforeClass() {
72+
SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
73+
// Exclude dot product because we are not producing unit length vectors
74+
DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
75+
}
76+
77+
public SemanticTextUpgradeIT(boolean useLegacyFormat) {
78+
this.useLegacyFormat = useLegacyFormat;
79+
}
80+
81+
@ParametersFactory
82+
public static Iterable<Object[]> parameters() {
83+
return List.of(new Object[] { true }, new Object[] { false });
84+
}
85+
86+
public void testSemanticTextOperations() throws Exception {
87+
switch (CLUSTER_TYPE) {
88+
case OLD -> createAndPopulateIndex();
89+
case MIXED, UPGRADED -> performIndexQueryHighlightOps();
90+
default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
91+
}
92+
}
93+
94+
private void createAndPopulateIndex() throws IOException {
95+
final String indexName = getIndexName();
96+
final String mapping = Strings.format("""
97+
{
98+
"properties": {
99+
"%s": {
100+
"type": "semantic_text",
101+
"inference_id": "%s"
102+
},
103+
"%s": {
104+
"type": "semantic_text",
105+
"inference_id": "%s"
106+
}
107+
}
108+
}
109+
""", SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());
110+
111+
CreateIndexResponse response = createIndex(
112+
indexName,
113+
Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat).build(),
114+
mapping
115+
);
116+
assertThat(response.isAcknowledged(), equalTo(true));
117+
118+
indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
119+
}
120+
121+
private void performIndexQueryHighlightOps() throws IOException {
122+
indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));
123+
124+
ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
125+
assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);
126+
127+
ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
128+
assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
129+
}
130+
131+
private String getIndexName() {
132+
return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
133+
}
134+
135+
private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
136+
final String indexName = getIndexName();
137+
final SemanticTextField sparseFieldValue = randomSemanticText(
138+
useLegacyFormat,
139+
SPARSE_FIELD,
140+
SPARSE_MODEL,
141+
null,
142+
semanticTextFieldValue,
143+
XContentType.JSON
144+
);
145+
final SemanticTextField denseFieldValue = randomSemanticText(
146+
useLegacyFormat,
147+
DENSE_FIELD,
148+
DENSE_MODEL,
149+
null,
150+
semanticTextFieldValue,
151+
XContentType.JSON
152+
);
153+
154+
XContentBuilder builder = XContentFactory.jsonBuilder();
155+
builder.startObject();
156+
if (useLegacyFormat == false) {
157+
builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
158+
builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
159+
}
160+
addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
161+
builder.endObject();
162+
163+
RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
164+
Request request = new Request("POST", indexName + "/_doc/" + id);
165+
request.setJsonEntity(Strings.toString(builder));
166+
request.setOptions(requestOptions);
167+
168+
Response response = client().performRequest(request);
169+
assertOK(response);
170+
}
171+
172+
private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException {
173+
// We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
174+
// query
175+
final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field);
176+
final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) {
177+
case SPARSE_EMBEDDING -> {
178+
List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
179+
yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null);
180+
}
181+
case TEXT_EMBEDDING -> {
182+
DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType();
183+
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
184+
elementType,
185+
fieldModel.getServiceSettings().dimensions()
186+
);
187+
188+
// Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document
189+
// vector
190+
float[] queryVector = new float[embeddingLength];
191+
if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
192+
Arrays.fill(queryVector, -128.0f);
193+
} else {
194+
Arrays.fill(queryVector, 1.0f);
195+
}
196+
197+
yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null);
198+
}
199+
default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]");
200+
};
201+
202+
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
203+
SemanticTextField.getChunksFieldName(field),
204+
innerQueryBuilder,
205+
ScoreMode.Max
206+
);
207+
208+
XContentBuilder builder = XContentFactory.jsonBuilder();
209+
builder.startObject();
210+
builder.field("query", nestedQueryBuilder);
211+
if (numOfHighlightFragments != null) {
212+
HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
213+
highlightField.numOfFragments(numOfHighlightFragments);
214+
215+
HighlightBuilder highlightBuilder = new HighlightBuilder();
216+
highlightBuilder.field(highlightField);
217+
218+
builder.field("highlight", highlightBuilder);
219+
}
220+
builder.endObject();
221+
222+
Request request = new Request("GET", getIndexName() + "/_search");
223+
request.setJsonEntity(Strings.toString(builder));
224+
225+
Response response = client().performRequest(request);
226+
return assertOKAndCreateObjectPath(response);
227+
}
228+
229+
private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
230+
assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
231+
assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));
232+
233+
Set<String> docIds = new HashSet<>();
234+
List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
235+
for (Map<String, Object> hit : hits) {
236+
String id = ObjectPath.evaluate(hit, "_id");
237+
assertThat(id, notNullValue());
238+
docIds.add(id);
239+
240+
List<String> expectedHighlight = DOC_VALUES.get(id);
241+
assertThat(expectedHighlight, notNullValue());
242+
assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
243+
}
244+
245+
assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
246+
}
247+
}

0 commit comments

Comments
 (0)