diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java
new file mode 100644
index 0000000000000..9478508da88d0
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.mapper.vectors;
+
+import com.carrotsearch.randomizedtesting.RandomizedContext;
+import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
+
+import org.elasticsearch.inference.SimilarityMeasure;
+
+import java.util.List;
+import java.util.Random;
+
+public class DenseVectorFieldMapperTestUtils {
+    private DenseVectorFieldMapperTestUtils() {}
+
+    public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
+        return switch (elementType) {
+            case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
+            case BIT -> List.of(SimilarityMeasure.L2_NORM);
+        };
+    }
+
+    public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
+        return switch (elementType) {
+            case FLOAT, BYTE -> dimensions;
+            case BIT -> {
+                assert dimensions % Byte.SIZE == 0;
+                yield dimensions / Byte.SIZE;
+            }
+        };
+    }
+
+    public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
+        if (max < 1) {
+            throw new IllegalArgumentException("max must be at least 1");
+        }
+
+        return switch (elementType) {
+            case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
+            case BIT -> {
+                if (max < 8) {
+                    throw new IllegalArgumentException("max must be at least 8 for bit vectors");
+                }
+
+                // Generate a random dimension count that is a multiple of 8
+                int maxEmbeddingLength = max / 8;
+                yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8;
+            }
+        };
+    }
+
+    private static Random random() {
+        return RandomizedContext.current().getRandom();
+    }
+}
diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index 8864c2b6650ba..b0d7ce76d2533 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams
 apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
 apply plugin: 'elasticsearch.internal-yaml-rest-test'
+apply plugin: 'elasticsearch.internal-test-artifact'
 
 restResources {
   restApi {
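Note on the BIT arithmetic in DenseVectorFieldMapperTestUtils: bit vectors count dimensions in bits but store embeddings as bytes, which is why randomCompatibleDimensions only returns multiples of 8 and getEmbeddingLength divides by Byte.SIZE. A minimal usage sketch (values are illustrative, assuming the utilities above are on the test classpath):

    // With max = 64, BIT dimensions are drawn from 8, 16, ..., 64 (multiples of 8 only)
    int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(DenseVectorFieldMapper.ElementType.BIT, 64);
    // A vector of `dimensions` bits is stored as dimensions / 8 bytes
    int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(DenseVectorFieldMapper.ElementType.BIT, dimensions);
    assert embeddingLength == dimensions / Byte.SIZE;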
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
index 779a98e023455..43ad8b8be0cdd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
@@ -25,9 +26,13 @@
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
 import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -39,9 +44,60 @@ public static TestModel createRandomInstance() {
     }
 
     public static TestModel createRandomInstance(TaskType taskType) {
-        var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomInt(64) : null;
-        var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null;
-        var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
+        return createRandomInstance(taskType, null, null);
+    }
+
+    public static TestModel createRandomInstance(
+        TaskType taskType,
+        List<DenseVectorFieldMapper.ElementType> excludedElementTypes,
+        List<SimilarityMeasure> excludedSimilarities
+    ) {
+        // Use a max dimension count that has a reasonable probability of being compatible with BBQ
+        return createRandomInstance(taskType, excludedElementTypes, excludedSimilarities, BBQ_MIN_DIMS * 2);
+    }
+
+    public static TestModel createRandomInstance(
+        TaskType taskType,
+        List<DenseVectorFieldMapper.ElementType> excludedElementTypes,
+        List<SimilarityMeasure> excludedSimilarities,
+        int maxDimensions
+    ) {
+        List<DenseVectorFieldMapper.ElementType> supportedElementTypes = new ArrayList<>(
+            Arrays.asList(DenseVectorFieldMapper.ElementType.values())
+        );
+        if (excludedElementTypes != null) {
+            supportedElementTypes.removeAll(excludedElementTypes);
+            if (supportedElementTypes.isEmpty()) {
+                throw new IllegalArgumentException("No supported element types with excluded element types " + excludedElementTypes);
+            }
+        }
+
+        var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(supportedElementTypes) : null;
+        var dimensions = taskType == TaskType.TEXT_EMBEDDING
+            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
+            : null;
+
+        SimilarityMeasure similarity = null;
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            List<SimilarityMeasure> supportedSimilarities = new ArrayList<>(
+                DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)
+            );
+            if (excludedSimilarities != null) {
+                supportedSimilarities.removeAll(excludedSimilarities);
+            }
+
+            if (supportedSimilarities.isEmpty()) {
+                throw new IllegalArgumentException(
+                    "No supported similarities for combination of element type ["
+                        + elementType
+                        + "] and excluded similarities "
+                        + (excludedSimilarities == null ? List.of() : excludedSimilarities)
+                );
+            }
+
+            similarity = randomFrom(supportedSimilarities);
+        }
+
         return new TestModel(
             randomAlphaOfLength(4),
             taskType,
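A usage sketch of the new overloads (arguments are illustrative): the exclusion lists prune the random choice space before drawing, and the default maxDimensions of BBQ_MIN_DIMS * 2 means roughly half of the generated dimension counts land at or above BBQ_MIN_DIMS, i.e. are BBQ-compatible:

    // Never generate bit vectors or dot-product similarity; dimensions fall in [1, BBQ_MIN_DIMS * 2]
    TestModel model = TestModel.createRandomInstance(
        TaskType.TEXT_EMBEDDING,
        List.of(DenseVectorFieldMapper.ElementType.BIT),
        List.of(SimilarityMeasure.DOT_PRODUCT)
    );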
diff --git a/x-pack/qa/rolling-upgrade/build.gradle b/x-pack/qa/rolling-upgrade/build.gradle
index ac3361cb2a19c..74eb9c6ccf089 100644
--- a/x-pack/qa/rolling-upgrade/build.gradle
+++ b/x-pack/qa/rolling-upgrade/build.gradle
@@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test'
 apply plugin: 'elasticsearch.rest-resources'
 
 dependencies {
+  testImplementation testArtifact(project(':server'))
   testImplementation testArtifact(project(xpackModule('core')))
   testImplementation project(':x-pack:qa')
   testImplementation project(':modules:reindex')
+  testImplementation testArtifact(project(xpackModule('inference')))
 }
 
 restResources {
diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
new file mode 100644
index 0000000000000..c8d7d6aa0e482
--- /dev/null
+++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.upgrades;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
+import org.elasticsearch.index.query.NestedQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
+import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
+import org.elasticsearch.test.rest.ObjectPath;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.inference.model.TestModel;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.notNullValue;
+
+public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
+    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
+    private static final String SPARSE_FIELD = "sparse_field";
+    private static final String DENSE_FIELD = "dense_field";
+    private static final Version UPGRADE_FROM_VERSION_PARSED = Version.fromString(UPGRADE_FROM_VERSION);
+
+    private static final String DOC_1_ID = "doc_1";
+    private static final String DOC_2_ID = "doc_2";
+    private static final Map<String, List<String>> DOC_VALUES = Map.of(
+        DOC_1_ID,
+        List.of("a test value", "with multiple test values"),
+        DOC_2_ID,
+        List.of("another test value")
+    );
+
+    private static Model SPARSE_MODEL;
+    private static Model DENSE_MODEL;
+
+    private final boolean useLegacyFormat;
+
+    @BeforeClass
+    public static void beforeClass() {
+        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        // Exclude bit vectors because semantic text does not fully support them
+        // Exclude dot product because we are not producing unit length vectors
+        DENSE_MODEL = TestModel.createRandomInstance(
+            TaskType.TEXT_EMBEDDING,
+            List.of(DenseVectorFieldMapper.ElementType.BIT),
+            List.of(SimilarityMeasure.DOT_PRODUCT)
+        );
+    }
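The DOT_PRODUCT exclusion above follows from how Elasticsearch validates dot-product similarity: float vectors must be unit length, and the randomly generated test embeddings are not normalized. If normalized vectors were generated instead, e.g. with a helper along the lines of this hypothetical sketch (not part of this change), the exclusion could be dropped:

    // Hypothetical helper: scale a vector to unit length so DOT_PRODUCT would be valid
    static float[] toUnitLength(float[] v) {
        double norm = 0;
        for (float x : v) {
            norm += x * x;
        }
        norm = Math.sqrt(norm);
        float[] unit = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            unit[i] = (float) (v[i] / norm);
        }
        return unit;
    }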
+
+    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
+        this.useLegacyFormat = useLegacyFormat;
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        List<Object[]> parameters = new ArrayList<>();
+        parameters.add(new Object[] { true });
+        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
+            // New semantic text format added in 8.18
+            parameters.add(new Object[] { false });
+        }
+        return parameters;
+    }
+
+    public void testSemanticTextOperations() throws Exception {
+        assumeTrue("Upgrade from version supports semantic text", UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_15_0));
+        switch (CLUSTER_TYPE) {
+            case OLD -> createAndPopulateIndex();
+            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
+            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
+        }
+    }
+
+    private void createAndPopulateIndex() throws IOException {
+        final String indexName = getIndexName();
+        final String mapping = Strings.format("""
+            {
+              "properties": {
+                "%s": {
+                  "type": "semantic_text",
+                  "inference_id": "%s"
+                },
+                "%s": {
+                  "type": "semantic_text",
+                  "inference_id": "%s"
+                }
+              }
+            }
+            """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());
+
+        Settings.Builder settingsBuilder = Settings.builder();
+        if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
+            settingsBuilder.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
+        }
+
+        CreateIndexResponse response = createIndex(indexName, settingsBuilder.build(), mapping);
+        assertThat(response.isAcknowledged(), equalTo(true));
+
+        indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
+    }
+
+    private void performIndexQueryHighlightOps() throws IOException {
+        indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));
+
+        ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
+        assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);
+
+        ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
+        assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
+    }
+
+    private String getIndexName() {
+        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
+    }
+
+    private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
+        final String indexName = getIndexName();
+        final SemanticTextField sparseFieldValue = randomSemanticText(
+            useLegacyFormat,
+            SPARSE_FIELD,
+            SPARSE_MODEL,
+            semanticTextFieldValue,
+            XContentType.JSON
+        );
+        final SemanticTextField denseFieldValue = randomSemanticText(
+            useLegacyFormat,
+            DENSE_FIELD,
+            DENSE_MODEL,
+            semanticTextFieldValue,
+            XContentType.JSON
+        );
+
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        if (useLegacyFormat == false) {
+            builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
+            builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
+        }
+        addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
+        builder.endObject();
+
+        RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
+        Request request = new Request("POST", indexName + "/_doc/" + id);
+        request.setJsonEntity(Strings.toString(builder));
+        request.setOptions(requestOptions);
+
+        Response response = client().performRequest(request);
+        assertOK(response);
+    }
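For orientation, the two formats lay out the indexed document differently, which is why indexDoc writes the raw values to the field only when useLegacyFormat == false. A rough, hedged sketch of the two layouts (key names simplified; the authoritative shape comes from randomSemanticText and addSemanticTextInferenceResults): the legacy format nests the original text and inference results under the semantic_text field itself, while the 8.18+ format keeps plain text in the field and moves inference results to the _inference_fields metadata field:

    // Hedged sketch of the two document shapes, as Java text blocks:
    String legacyFormatDoc = """
        { "sparse_field": { "text": ["a test value"], "inference": { "chunks": [ ... ] } } }
        """;
    String newFormatDoc = """
        { "sparse_field": ["a test value"], "_inference_fields": { "sparse_field": { ... } } }
        """;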
"_legacy" : "_new"); + } + + private void indexDoc(String id, List semanticTextFieldValue) throws IOException { + final String indexName = getIndexName(); + final SemanticTextField sparseFieldValue = randomSemanticText( + useLegacyFormat, + SPARSE_FIELD, + SPARSE_MODEL, + semanticTextFieldValue, + XContentType.JSON + ); + final SemanticTextField denseFieldValue = randomSemanticText( + useLegacyFormat, + DENSE_FIELD, + DENSE_MODEL, + semanticTextFieldValue, + XContentType.JSON + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + if (useLegacyFormat == false) { + builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue); + builder.field(denseFieldValue.fieldName(), semanticTextFieldValue); + } + addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue)); + builder.endObject(); + + RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build(); + Request request = new Request("POST", indexName + "/_doc/" + id); + request.setJsonEntity(Strings.toString(builder)); + request.setOptions(requestOptions); + + Response response = client().performRequest(request); + assertOK(response); + } + + private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException { + // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested + // query + final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field); + final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) { + case SPARSE_EMBEDDING -> { + List weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList(); + yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null); + } + case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType(); + int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength( + elementType, + fieldModel.getServiceSettings().dimensions() + ); + + // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document + // vector + float[] queryVector = new float[embeddingLength]; + if (elementType == DenseVectorFieldMapper.ElementType.BIT) { + Arrays.fill(queryVector, -128.0f); + } else { + Arrays.fill(queryVector, 1.0f); + } + + yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null); + } + default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]"); + }; + + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( + SemanticTextField.getChunksFieldName(field), + innerQueryBuilder, + ScoreMode.Max + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("query", nestedQueryBuilder); + if (numOfHighlightFragments != null) { + HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field); + highlightField.numOfFragments(numOfHighlightFragments); + + HighlightBuilder highlightBuilder = new HighlightBuilder(); + highlightBuilder.field(highlightField); + + builder.field("highlight", highlightBuilder); + } + builder.endObject(); + + Request request = new Request("GET", getIndexName() + "/_search"); + request.setJsonEntity(Strings.toString(builder)); + + Response response = 
+
+    private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
+        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
+        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));
+
+        Set<String> docIds = new HashSet<>();
+        List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
+        for (Map<String, Object> hit : hits) {
+            String id = ObjectPath.evaluate(hit, "_id");
+            assertThat(id, notNullValue());
+            docIds.add(id);
+
+            if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0) || CLUSTER_TYPE == ClusterType.UPGRADED) {
+                // Semantic highlighting only functions reliably on clusters where all nodes are 8.18.0 or later
+                List<String> expectedHighlight = DOC_VALUES.get(id);
+                assertThat(expectedHighlight, notNullValue());
+                assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
+            }
+        }
+
+        assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
+    }
+}