Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
apply plugin: 'elasticsearch.internal-es-plugin'
apply plugin: 'elasticsearch.internal-cluster-test'
apply plugin: 'elasticsearch.internal-yaml-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact'

restResources {
restApi {
Expand Down
1 change: 1 addition & 0 deletions x-pack/qa/rolling-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies {
testImplementation testArtifact(project(xpackModule('core')))
testImplementation project(':x-pack:qa')
testImplementation project(':modules:reindex')
testImplementation testArtifact(project(xpackModule('inference')))
}

restResources {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.upgrades;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.test.rest.ObjectPath;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.notNullValue;

/**
 * Rolling-upgrade integration test for {@code semantic_text} fields.
 *
 * <p>Verifies that an index with a {@code semantic_text} mapping created on the old cluster
 * remains fully usable through the upgrade: documents indexed before the upgrade can still be
 * queried and highlighted on mixed and fully-upgraded clusters, and new documents can be indexed
 * alongside them.
 *
 * <p>The test is parameterized over both index formats — the legacy {@code semantic_text} source
 * format and the newer inference-metadata-fields format (selected via
 * {@link InferenceMetadataFieldsMapper#USE_LEGACY_SEMANTIC_TEXT_FORMAT}). Each format gets its own
 * index (suffixed {@code _legacy} / {@code _new}) so the two parameterized runs do not interfere.
 */
public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
    private static final String SEMANTIC_TEXT_FIELD = "semantic_field";

    // Sparse-embedding model used to generate inference results for indexing. Initialized in
    // beforeClass() (rather than at field-init time) so the randomized instance is created inside
    // the test lifecycle.
    // NOTE(review): each upgrade phase runs this suite afresh, so the random model — including its
    // inference entity id — is re-created per phase; presumably the per-document inference results
    // written by addSemanticTextInferenceResults carry enough model information for this to be
    // safe across phases. Confirm against the semantic_text field format.
    private static Model SPARSE_MODEL;

    // true -> index uses the legacy semantic_text source format; false -> the new
    // inference-metadata-fields format. Set per parameterized run (see parameters()).
    private final boolean useLegacyFormat;

    @BeforeClass
    public static void beforeClass() {
        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
    }

    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    /** Runs the whole suite twice: once with the legacy format, once with the new format. */
    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        return List.of(new Object[] { true }, new Object[] { false });
    }

    /**
     * Entry point executed in every upgrade phase: the index is created and seeded only on the
     * old cluster; mixed and upgraded clusters exercise indexing, querying and highlighting
     * against that pre-existing index.
     */
    public void testSemanticTextOperations() throws Exception {
        switch (CLUSTER_TYPE) {
            case OLD -> createAndPopulateIndex();
            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
        }
    }

    /**
     * Creates the format-specific index with a {@code semantic_text} mapping bound to the random
     * sparse model's inference id, then indexes the first document.
     */
    private void createAndPopulateIndex() throws IOException {
        final String indexName = getIndexName();
        final String mapping = Strings.format("""
            {
              "properties": {
                "%s": {
                  "type": "semantic_text",
                  "inference_id": "%s"
                }
              }
            }
            """, SEMANTIC_TEXT_FIELD, SPARSE_MODEL.getInferenceEntityId());

        // The legacy/new format choice is fixed at index creation time via this index setting.
        CreateIndexResponse response = createIndex(
            indexName,
            Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat).build(),
            mapping
        );
        assertThat(response.isAcknowledged(), equalTo(true));

        indexDoc("doc_1", List.of("a test value", "with multiple test values"));
    }

    /**
     * Post-upgrade-start operations: index a second document (re-indexed with the same id on each
     * subsequent phase, so the total document count stays at two) and verify query + highlight
     * results over both documents.
     */
    private void performIndexQueryHighlightOps() throws IOException {
        indexDoc("doc_2", List.of("another test value"));
        ObjectPath queryObjectPath = semanticQuery("test value", 3);
        assertQueryResponse(queryObjectPath);
    }

    // Separate index per format so the two parameterized runs never share state.
    private String getIndexName() {
        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
    }

    /**
     * Indexes a document with pre-computed (random) inference results for the semantic_text field,
     * bypassing real inference. Refreshes immediately so the document is searchable.
     *
     * @param id                     document id ({@code doc_1} / {@code doc_2})
     * @param semanticTextFieldValue raw text values for the semantic_text field
     */
    private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
        final String indexName = getIndexName();
        final SemanticTextField semanticTextField = randomSemanticText(
            useLegacyFormat,
            SEMANTIC_TEXT_FIELD,
            SPARSE_MODEL,
            null,
            semanticTextFieldValue,
            XContentType.JSON
        );

        // Builder order is wire-format-significant: in the new format the raw value is written
        // under the field name and the inference results are added separately; in the legacy
        // format addSemanticTextInferenceResults writes everything.
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        if (useLegacyFormat == false) {
            builder.field(semanticTextField.fieldName(), semanticTextFieldValue);
        }
        addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(semanticTextField));
        builder.endObject();

        // refresh=true so the follow-up search in the same phase sees this document.
        RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
        Request request = new Request("POST", indexName + "/_doc/" + id);
        request.setJsonEntity(Strings.toString(builder));
        request.setOptions(requestOptions);

        Response response = client().performRequest(request);
        assertOK(response);
    }

    /**
     * Executes a search that mimics a semantic query over the semantic_text field, optionally
     * requesting highlights.
     *
     * @param query                   whitespace-separated terms, each turned into a weighted token
     * @param numOfHighlightFragments highlight fragment count, or {@code null} to skip highlighting
     * @return parsed search response
     */
    private ObjectPath semanticQuery(String query, Integer numOfHighlightFragments) throws IOException {
        // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
        // query
        List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
        SparseVectorQueryBuilder sparseVectorQueryBuilder = new SparseVectorQueryBuilder(
            SemanticTextField.getEmbeddingsFieldName(SEMANTIC_TEXT_FIELD),
            weightedTokens,
            null,
            null,
            null,
            null
        );
        // semantic_text stores per-chunk embeddings in a nested field, hence the nested wrapper.
        NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
            SemanticTextField.getChunksFieldName(SEMANTIC_TEXT_FIELD),
            sparseVectorQueryBuilder,
            ScoreMode.Max
        );

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        builder.field("query", nestedQueryBuilder);
        if (numOfHighlightFragments != null) {
            HighlightBuilder.Field highlightField = new HighlightBuilder.Field(SEMANTIC_TEXT_FIELD);
            highlightField.numOfFragments(numOfHighlightFragments);

            HighlightBuilder highlightBuilder = new HighlightBuilder();
            highlightBuilder.field(highlightField);

            builder.field("highlight", highlightBuilder);
        }
        builder.endObject();

        Request request = new Request("GET", getIndexName() + "/_search");
        request.setJsonEntity(Strings.toString(builder));

        Response response = client().performRequest(request);
        return assertOKAndCreateObjectPath(response);
    }

    /**
     * Asserts the search returned exactly the two indexed documents and that each hit's highlight
     * for the semantic_text field equals that document's original raw values (i.e. highlighting
     * surfaces the chunk text indexed in indexDoc).
     */
    @SuppressWarnings("unchecked")
    private static void assertQueryResponse(ObjectPath queryObjectPath) throws IOException {
        // Must mirror the values passed to indexDoc in createAndPopulateIndex / performIndexQueryHighlightOps.
        final Map<String, List<String>> expectedHighlights = Map.of(
            "doc_1",
            List.of("a test value", "with multiple test values"),
            "doc_2",
            List.of("another test value")
        );

        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));

        // Collect ids as we go so we can assert both documents matched, regardless of hit order.
        Set<String> docIds = new HashSet<>();
        List<Object> hits = queryObjectPath.evaluate("hits.hits");
        for (Object hit : hits) {
            assertThat(hit, instanceOf(Map.class));
            Map<String, Object> hitMap = (Map<String, Object>) hit;

            String id = (String) hitMap.get("_id");
            assertThat(id, notNullValue());
            docIds.add(id);

            List<String> expectedHighlight = expectedHighlights.get(id);
            assertThat(expectedHighlight, notNullValue());
            assertThat(((Map<String, Object>) hitMap.get("highlight")).get(SEMANTIC_TEXT_FIELD), equalTo(expectedHighlight));
        }

        assertThat(docIds, equalTo(Set.of("doc_1", "doc_2")));
    }
}