diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index 0539e71dd7ab3..1ef7524a0fc8e 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams
 apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
 apply plugin: 'elasticsearch.internal-yaml-rest-test'
+apply plugin: 'elasticsearch.internal-test-artifact'
 
 restResources {
   restApi {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
index a00f8e55a4e27..c3b50cdb4a670 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
@@ -31,6 +31,7 @@
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
 import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -46,9 +47,14 @@ public static TestModel createRandomInstance(TaskType taskType) {
     }
 
     public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
+        // Use a max dimension count that has a reasonable probability of being compatible with BBQ
+        return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
+    }
+
+    public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
         var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
         var dimensions = taskType == TaskType.TEXT_EMBEDDING
-            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
+            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
             : null;
 
         SimilarityMeasure similarity = null;
diff --git a/x-pack/qa/rolling-upgrade/build.gradle b/x-pack/qa/rolling-upgrade/build.gradle
index 3485277e340af..e0f00aa200f5b 100644
--- a/x-pack/qa/rolling-upgrade/build.gradle
+++ b/x-pack/qa/rolling-upgrade/build.gradle
@@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test'
 apply plugin: 'elasticsearch.rest-resources'
 
 dependencies {
+  testImplementation testArtifact(project(':server'))
   testImplementation testArtifact(project(xpackModule('core')))
   testImplementation project(':x-pack:qa')
   testImplementation project(':modules:reindex')
+  testImplementation testArtifact(project(xpackModule('inference')))
 }
 
 restResources {
diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
new file mode 100644
index 0000000000000..93468fa5c6243
--- /dev/null
+++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
@@ -0,0 +1,265 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.upgrades;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
+import org.elasticsearch.index.query.NestedQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
+import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
+import org.elasticsearch.test.rest.ObjectPath;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.inference.model.TestModel;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.notNullValue;
+
+public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
+    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
+    private static final String SPARSE_FIELD = "sparse_field";
+    private static final String DENSE_FIELD = "dense_field";
+    private static final Version UPGRADE_FROM_VERSION_PARSED = Version.fromString(UPGRADE_FROM_VERSION);
+
+    private static final String DOC_1_ID = "doc_1";
+    private static final String DOC_2_ID = "doc_2";
+    private static final Map<String, List<String>> DOC_VALUES = Map.of(
+        DOC_1_ID,
+        List.of("a test value", "with multiple test values"),
+        DOC_2_ID,
+        List.of("another test value")
+    );
+
+    private static Model SPARSE_MODEL;
+    private static Model DENSE_MODEL;
+
+    private final boolean useLegacyFormat;
+
+    @BeforeClass
+    public static void beforeClass() {
+        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        // Exclude dot product because we are not producing unit length vectors
+        DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
+    }
+
+    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
+        this.useLegacyFormat = useLegacyFormat;
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        List<Object[]> parameters = new ArrayList<>();
+        parameters.add(new Object[] { true });
+        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
+            // New semantic text format added in 8.18
+            parameters.add(new Object[] { false });
+        }
+        return parameters;
+    }
+
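+    // Each upgrade phase runs this test: the old cluster creates and populates the index, while the mixed and
+    // upgraded clusters index an additional document and verify query and highlight behavior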
+    public void testSemanticTextOperations() throws Exception {
+        assumeTrue("Upgrade from version supports semantic text", UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_15_0));
+        switch (CLUSTER_TYPE) {
+            case OLD -> createAndPopulateIndex();
+            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
+            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
+        }
+    }
+
+    private void createAndPopulateIndex() throws IOException {
+        final String indexName = getIndexName();
+        final String mapping = Strings.format("""
+            {
+              "properties": {
+                "%s": {
+                  "type": "semantic_text",
+                  "inference_id": "%s"
+                },
+                "%s": {
+                  "type": "semantic_text",
+                  "inference_id": "%s"
+                }
+              }
+            }
+            """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());
+
+        Settings.Builder settingsBuilder = Settings.builder();
+        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
+            settingsBuilder.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
+        }
+
+        CreateIndexResponse response = createIndex(indexName, settingsBuilder.build(), mapping);
+        assertThat(response.isAcknowledged(), equalTo(true));
+
+        indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
+    }
+
+    private void performIndexQueryHighlightOps() throws IOException {
+        indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));
+
+        ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
+        assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);
+
+        ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
+        assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
+    }
+
+    private String getIndexName() {
+        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
+    }
+
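+    // Indexes a document with randomly generated inference results, so no inference endpoint is required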
"_legacy" : "_new"); + } + + private void indexDoc(String id, List semanticTextFieldValue) throws IOException { + final String indexName = getIndexName(); + final SemanticTextField sparseFieldValue = randomSemanticText( + useLegacyFormat, + SPARSE_FIELD, + SPARSE_MODEL, + null, + semanticTextFieldValue, + XContentType.JSON + ); + final SemanticTextField denseFieldValue = randomSemanticText( + useLegacyFormat, + DENSE_FIELD, + DENSE_MODEL, + null, + semanticTextFieldValue, + XContentType.JSON + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + if (useLegacyFormat == false) { + builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue); + builder.field(denseFieldValue.fieldName(), semanticTextFieldValue); + } + addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue)); + builder.endObject(); + + RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build(); + Request request = new Request("POST", indexName + "/_doc/" + id); + request.setJsonEntity(Strings.toString(builder)); + request.setOptions(requestOptions); + + Response response = client().performRequest(request); + assertOK(response); + } + + private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException { + // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested + // query + final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field); + final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) { + case SPARSE_EMBEDDING -> { + List weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList(); + yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null); + } + case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType(); + int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength( + elementType, + fieldModel.getServiceSettings().dimensions() + ); + + // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document + // vector + float[] queryVector = new float[embeddingLength]; + if (elementType == DenseVectorFieldMapper.ElementType.BIT) { + Arrays.fill(queryVector, -128.0f); + } else { + Arrays.fill(queryVector, 1.0f); + } + + yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null); + } + default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]"); + }; + + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( + SemanticTextField.getChunksFieldName(field), + innerQueryBuilder, + ScoreMode.Max + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("query", nestedQueryBuilder); + if (numOfHighlightFragments != null) { + HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field); + highlightField.numOfFragments(numOfHighlightFragments); + + HighlightBuilder highlightBuilder = new HighlightBuilder(); + highlightBuilder.field(highlightField); + + builder.field("highlight", highlightBuilder); + } + builder.endObject(); + + Request request = new Request("GET", getIndexName() + "/_search"); + request.setJsonEntity(Strings.toString(builder)); + + Response 
+    private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
+        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
+        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));
+
+        Set<String> docIds = new HashSet<>();
+        List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
+        for (Map<String, Object> hit : hits) {
+            String id = ObjectPath.evaluate(hit, "_id");
+            assertThat(id, notNullValue());
+            docIds.add(id);
+
+            if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0) || CLUSTER_TYPE == ClusterType.UPGRADED) {
+                // Semantic highlighting only functions reliably on clusters where all nodes are 8.18.0 or later
+                List<String> expectedHighlight = DOC_VALUES.get(id);
+                assertThat(expectedHighlight, notNullValue());
+                assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
+            }
+        }
+
+        assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
+    }
+}