Skip to content

Commit 2eb61a8

Browse files
committed
Add a semantic text field that uses a dense model to the test
1 parent 334de43 commit 2eb61a8

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
2020
import org.elasticsearch.index.query.NestedQueryBuilder;
2121
import org.elasticsearch.inference.Model;
22+
import org.elasticsearch.inference.SimilarityMeasure;
2223
import org.elasticsearch.inference.TaskType;
2324
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
2425
import org.elasticsearch.test.rest.ObjectPath;
@@ -46,15 +47,19 @@
4647

4748
public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
4849
private static final String INDEX_BASE_NAME = "semantic_text_test_index";
49-
private static final String SEMANTIC_TEXT_FIELD = "semantic_field";
50+
private static final String SPARSE_FIELD = "sparse_field";
51+
private static final String DENSE_FIELD = "dense_field";
5052

5153
private static Model SPARSE_MODEL;
54+
private static Model DENSE_MODEL;
5255

5356
private final boolean useLegacyFormat;
5457

5558
@BeforeClass
5659
public static void beforeClass() {
5760
SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
61+
// Exclude dot product similarity because we do not produce unit-length vectors
62+
DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
5863
}
5964

6065
public SemanticTextUpgradeIT(boolean useLegacyFormat) {
@@ -79,13 +84,17 @@ private void createAndPopulateIndex() throws IOException {
7984
final String mapping = Strings.format("""
8085
{
8186
"properties": {
87+
"%s": {
88+
"type": "semantic_text",
89+
"inference_id": "%s"
90+
},
8291
"%s": {
8392
"type": "semantic_text",
8493
"inference_id": "%s"
8594
}
8695
}
8796
}
88-
""", SEMANTIC_TEXT_FIELD, SPARSE_MODEL.getInferenceEntityId());
97+
""", SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());
8998

9099
CreateIndexResponse response = createIndex(
91100
indexName,
@@ -99,8 +108,8 @@ private void createAndPopulateIndex() throws IOException {
99108

100109
private void performIndexQueryHighlightOps() throws IOException {
101110
indexDoc("doc_2", List.of("another test value"));
102-
ObjectPath queryObjectPath = semanticQuery("test value", 3);
103-
assertQueryResponse(queryObjectPath);
111+
ObjectPath queryObjectPath = semanticQuery(SPARSE_FIELD, "test value", 3);
112+
assertQueryResponse(queryObjectPath, SPARSE_FIELD);
104113
}
105114

106115
private String getIndexName() {
@@ -109,21 +118,30 @@ private String getIndexName() {
109118

110119
private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
111120
final String indexName = getIndexName();
112-
final SemanticTextField semanticTextField = randomSemanticText(
121+
final SemanticTextField sparseFieldValue = randomSemanticText(
113122
useLegacyFormat,
114-
SEMANTIC_TEXT_FIELD,
123+
SPARSE_FIELD,
115124
SPARSE_MODEL,
116125
null,
117126
semanticTextFieldValue,
118127
XContentType.JSON
119128
);
129+
final SemanticTextField denseFieldValue = randomSemanticText(
130+
useLegacyFormat,
131+
DENSE_FIELD,
132+
DENSE_MODEL,
133+
null,
134+
semanticTextFieldValue,
135+
XContentType.JSON
136+
);
120137

121138
XContentBuilder builder = XContentFactory.jsonBuilder();
122139
builder.startObject();
123140
if (useLegacyFormat == false) {
124-
builder.field(semanticTextField.fieldName(), semanticTextFieldValue);
141+
builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
142+
builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
125143
}
126-
addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(semanticTextField));
144+
addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
127145
builder.endObject();
128146

129147
RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
@@ -135,20 +153,20 @@ private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOE
135153
assertOK(response);
136154
}
137155

138-
private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments) throws IOException {
156+
private ObjectPath semanticQuery(String field, String query, Integer numOfHighlightFragments) throws IOException {
139157
// We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
140158
// query
141159
List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
142160
SparseVectorQueryBuilder sparseVectorQueryBuilder = new SparseVectorQueryBuilder(
143-
SemanticTextField.getEmbeddingsFieldName(SEMANTIC_TEXT_FIELD),
161+
SemanticTextField.getEmbeddingsFieldName(field),
144162
weightedTokens,
145163
null,
146164
null,
147165
null,
148166
null
149167
);
150168
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
151-
SemanticTextField.getChunksFieldName(SEMANTIC_TEXT_FIELD),
169+
SemanticTextField.getChunksFieldName(field),
152170
sparseVectorQueryBuilder,
153171
ScoreMode.Max
154172
);
@@ -157,7 +175,7 @@ private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments)
157175
builder.startObject();
158176
builder.field("query", nestedQueryBuilder);
159177
if (numOfHighlightFragments != null) {
160-
HighlightBuilder.Field highlightField = new HighlightBuilder.Field(SEMANTIC_TEXT_FIELD);
178+
HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
161179
highlightField.numOfFragments(numOfHighlightFragments);
162180

163181
HighlightBuilder highlightBuilder = new HighlightBuilder();
@@ -175,7 +193,7 @@ private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments)
175193
}
176194

177195
@SuppressWarnings("unchecked")
178-
private static void assertQueryResponse(ObjectPath queryObjectPath) throws IOException {
196+
private static void assertQueryResponse(ObjectPath queryObjectPath, String field) throws IOException {
179197
final Map<String, List<String>> expectedHighlights = Map.of(
180198
"doc_1",
181199
List.of("a test value", "with multiple test values"),
@@ -198,7 +216,7 @@ private static void assertQueryResponse(ObjectPath queryObjectPath) throws IOExc
198216

199217
List<String> expectedHighlight = expectedHighlights.get(id);
200218
assertThat(expectedHighlight, notNullValue());
201-
assertThat(((Map<String, Object>) hitMap.get("highlight")).get(SEMANTIC_TEXT_FIELD), equalTo(expectedHighlight));
219+
assertThat(((Map<String, Object>) hitMap.get("highlight")).get(field), equalTo(expectedHighlight));
202220
}
203221

204222
assertThat(docIds, equalTo(Set.of("doc_1", "doc_2")));

0 commit comments

Comments
 (0)