From 0a06056fb98f4d5122e79114dca521ee2715a296 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 7 May 2025 13:47:30 -0400 Subject: [PATCH 1/3] [8.19] Semantic Text Rolling Upgrade Tests (#126548) (#127748) * Semantic Text Rolling Upgrade Tests (#126548) * Fix test failures --------- Co-authored-by: Elastic Machine (cherry picked from commit a3a64ea2fbb7107ada66f645ca0701260b8377b0) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java --- x-pack/plugin/inference/build.gradle | 1 + .../xpack/inference/model/TestModel.java | 37 ++- x-pack/qa/rolling-upgrade/build.gradle | 2 + .../upgrades/SemanticTextUpgradeIT.java | 261 ++++++++++++++++++ 4 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 8864c2b6650ba..b0d7ce76d2533 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-yaml-rest-test' +apply plugin: 'elasticsearch.internal-test-artifact' restResources { restApi { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 779a98e023455..315c05d86b368 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; @@ -39,9 +40,41 @@ public static TestModel createRandomInstance() { } public static TestModel createRandomInstance(TaskType taskType) { - var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomInt(64) : null; - var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null; + return createRandomInstance(taskType, null); + } + + public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities) { + // Use a max dimension count that has a reasonable probability of being compatible with BBQ + return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2); + } + + public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null; + var dimensions = taskType == TaskType.TEXT_EMBEDDING + ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) + : null; + + SimilarityMeasure similarity = null; + if (taskType == TaskType.TEXT_EMBEDDING) { + List supportedSimilarities = new ArrayList<>( + DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType) + ); + if (excludedSimilarities != null) { + supportedSimilarities.removeAll(excludedSimilarities); + } + + if (supportedSimilarities.isEmpty()) { + throw new IllegalArgumentException( + "No supported similarities for combination of element type [" + + elementType + + "] and excluded similarities " + + (excludedSimilarities == null ? List.of() : excludedSimilarities) + ); + } + + similarity = randomFrom(supportedSimilarities); + } + return new TestModel( randomAlphaOfLength(4), taskType, diff --git a/x-pack/qa/rolling-upgrade/build.gradle b/x-pack/qa/rolling-upgrade/build.gradle index ac3361cb2a19c..74eb9c6ccf089 100644 --- a/x-pack/qa/rolling-upgrade/build.gradle +++ b/x-pack/qa/rolling-upgrade/build.gradle @@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test' apply plugin: 'elasticsearch.rest-resources' dependencies { + testImplementation testArtifact(project(':server')) testImplementation testArtifact(project(xpackModule('core'))) testImplementation project(':x-pack:qa') testImplementation project(':modules:reindex') + testImplementation testArtifact(project(xpackModule('inference'))) } restResources { diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java new file mode 100644 index 0000000000000..93468fa5c6243 --- /dev/null +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -0,0 +1,261 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.upgrades; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.Version; +import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; +import org.elasticsearch.index.query.NestedQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.rest.ObjectPath; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; + +public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase { + private static final String INDEX_BASE_NAME = "semantic_text_test_index"; + private static final String SPARSE_FIELD = "sparse_field"; + private static final String DENSE_FIELD = "dense_field"; + private static final Version UPGRADE_FROM_VERSION_PARSED = Version.fromString(UPGRADE_FROM_VERSION); + + private static final String DOC_1_ID = "doc_1"; + private static final String DOC_2_ID = "doc_2"; + private static final Map> DOC_VALUES = Map.of( + DOC_1_ID, + List.of("a test value", "with multiple test values"), + DOC_2_ID, + List.of("another test value") + ); + + private static Model SPARSE_MODEL; + private static Model DENSE_MODEL; + + private final boolean useLegacyFormat; + + @BeforeClass + public static void beforeClass() { + SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + // Exclude dot product because we are not producing unit length vectors + DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT)); + } + + public SemanticTextUpgradeIT(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() { + List parameters = new ArrayList<>(); + parameters.add(new Object[] { true }); + if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) { + // New semantic text format added in 8.18 + parameters.add(new Object[] { false }); + } + return parameters; + } + + public void testSemanticTextOperations() throws Exception { + assumeTrue("Upgrade from version supports semantic text", UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_15_0)); + switch (CLUSTER_TYPE) { + case OLD -> createAndPopulateIndex(); + case MIXED, UPGRADED -> performIndexQueryHighlightOps(); + default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]"); + } + } + + private void createAndPopulateIndex() throws IOException { + final String indexName = getIndexName(); + final String mapping = Strings.format(""" + { + "properties": { + "%s": { + "type": "semantic_text", + "inference_id": "%s" + }, + "%s": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId()); + + Settings.Builder settingsBuilder = Settings.builder(); + if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) { + settingsBuilder.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); + } + + CreateIndexResponse response = createIndex(indexName, settingsBuilder.build(), mapping); + assertThat(response.isAcknowledged(), equalTo(true)); + + indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID)); + } + + private void performIndexQueryHighlightOps() throws IOException { + indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID)); + + ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3); + assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD); + + ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3); + assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD); + } + + private String getIndexName() { + return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new"); + } + + private void indexDoc(String id, List semanticTextFieldValue) throws IOException { + final String indexName = getIndexName(); + final SemanticTextField sparseFieldValue = randomSemanticText( + useLegacyFormat, + SPARSE_FIELD, + SPARSE_MODEL, + null, + semanticTextFieldValue, + XContentType.JSON + ); + final SemanticTextField denseFieldValue = randomSemanticText( + useLegacyFormat, + DENSE_FIELD, + DENSE_MODEL, + null, + semanticTextFieldValue, + XContentType.JSON + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + if (useLegacyFormat == false) { + builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue); + builder.field(denseFieldValue.fieldName(), semanticTextFieldValue); + } + addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue)); + builder.endObject(); + + RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build(); + Request request = new Request("POST", indexName + "/_doc/" + id); + request.setJsonEntity(Strings.toString(builder)); + request.setOptions(requestOptions); + + Response response = client().performRequest(request); + assertOK(response); + } + + private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException { + // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested + // query + final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field); + final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) { + case SPARSE_EMBEDDING -> { + List weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList(); + yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null); + } + case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType(); + int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength( + elementType, + fieldModel.getServiceSettings().dimensions() + ); + + // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document + // vector + float[] queryVector = new float[embeddingLength]; + if (elementType == DenseVectorFieldMapper.ElementType.BIT) { + Arrays.fill(queryVector, -128.0f); + } else { + Arrays.fill(queryVector, 1.0f); + } + + yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null); + } + default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]"); + }; + + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( + SemanticTextField.getChunksFieldName(field), + innerQueryBuilder, + ScoreMode.Max + ); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("query", nestedQueryBuilder); + if (numOfHighlightFragments != null) { + HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field); + highlightField.numOfFragments(numOfHighlightFragments); + + HighlightBuilder highlightBuilder = new HighlightBuilder(); + highlightBuilder.field(highlightField); + + builder.field("highlight", highlightBuilder); + } + builder.endObject(); + + Request request = new Request("GET", getIndexName() + "/_search"); + request.setJsonEntity(Strings.toString(builder)); + + Response response = client().performRequest(request); + return assertOKAndCreateObjectPath(response); + } + + private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException { + assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2)); + assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2)); + + Set docIds = new HashSet<>(); + List> hits = queryObjectPath.evaluate("hits.hits"); + for (Map hit : hits) { + String id = ObjectPath.evaluate(hit, "_id"); + assertThat(id, notNullValue()); + docIds.add(id); + + if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0) || CLUSTER_TYPE == ClusterType.UPGRADED) { + // Semantic highlighting only functions reliably on clusters where all nodes are 8.18.0 or later + List expectedHighlight = DOC_VALUES.get(id); + assertThat(expectedHighlight, notNullValue()); + assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight)); + } + } + + assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID))); + } +} From 6a48e1265e670dd122c74faed0dab80aa190aa5e Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 7 May 2025 19:19:44 -0400 Subject: [PATCH 2/3] Fix build errors --- .../DenseVectorFieldMapperTestUtils.java | 62 +++++++++++++++++++ .../xpack/inference/model/TestModel.java | 8 ++- .../upgrades/SemanticTextUpgradeIT.java | 2 - 3 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java new file mode 100644 index 0000000000000..9478508da88d0 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import com.carrotsearch.randomizedtesting.RandomizedContext; +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.elasticsearch.inference.SimilarityMeasure; + +import java.util.List; +import java.util.Random; + +public class DenseVectorFieldMapperTestUtils { + private DenseVectorFieldMapperTestUtils() {} + + public static List getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) { + return switch (elementType) { + case FLOAT, BYTE -> List.of(SimilarityMeasure.values()); + case BIT -> List.of(SimilarityMeasure.L2_NORM); + }; + } + + public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { + return switch (elementType) { + case FLOAT, BYTE -> dimensions; + case BIT -> { + assert dimensions % Byte.SIZE == 0; + yield dimensions / Byte.SIZE; + } + }; + } + + public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) { + if (max < 1) { + throw new IllegalArgumentException("max must be at least 1"); + } + + return switch (elementType) { + case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); + case BIT -> { + if (max < 8) { + throw new IllegalArgumentException("max must be at least 8 for bit vectors"); + } + + // Generate a random dimension count that is a multiple of 8 + int maxEmbeddingLength = max / 8; + yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8; + } + }; + } + + private static Random random() { + return RandomizedContext.current().getRandom(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 315c05d86b368..d1dc4ee3d0b31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -25,7 +26,9 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; @@ -49,7 +52,10 @@ public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { - var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null; + // Don't use bit vectors because of incomplete support + var elementType = taskType == TaskType.TEXT_EMBEDDING + ? randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BYTE) + : null; var dimensions = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) : null; diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index 93468fa5c6243..d40899d384f56 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -149,7 +149,6 @@ private void indexDoc(String id, List semanticTextFieldValue) throws IOE useLegacyFormat, SPARSE_FIELD, SPARSE_MODEL, - null, semanticTextFieldValue, XContentType.JSON ); @@ -157,7 +156,6 @@ private void indexDoc(String id, List semanticTextFieldValue) throws IOE useLegacyFormat, DENSE_FIELD, DENSE_MODEL, - null, semanticTextFieldValue, XContentType.JSON ); From c361ae1eacca972c58aed4f869de5c59d500f6f8 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 8 May 2025 08:04:44 -0400 Subject: [PATCH 3/3] Exclude bit vectors only when running semantic text rolling upgrade tests --- .../xpack/inference/model/TestModel.java | 33 ++++++++++++++----- .../upgrades/SemanticTextUpgradeIT.java | 7 +++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index d1dc4ee3d0b31..43ad8b8be0cdd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,19 +44,35 @@ public static TestModel createRandomInstance() { } public static TestModel createRandomInstance(TaskType taskType) { - return createRandomInstance(taskType, null); + return createRandomInstance(taskType, null, null); } - public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities) { + public static TestModel createRandomInstance( + TaskType taskType, + List excludedElementTypes, + List excludedSimilarities + ) { // Use a max dimension count that has a reasonable probability of being compatible with BBQ - return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2); + return createRandomInstance(taskType, excludedElementTypes, excludedSimilarities, BBQ_MIN_DIMS * 2); } - public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { - // Don't use bit vectors because of incomplete support - var elementType = taskType == TaskType.TEXT_EMBEDDING - ? randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BYTE) - : null; + public static TestModel createRandomInstance( + TaskType taskType, + List excludedElementTypes, + List excludedSimilarities, + int maxDimensions + ) { + List supportedElementTypes = new ArrayList<>( + Arrays.asList(DenseVectorFieldMapper.ElementType.values()) + ); + if (excludedElementTypes != null) { + supportedElementTypes.removeAll(excludedElementTypes); + if (supportedElementTypes.isEmpty()) { + throw new IllegalArgumentException("No supported element types with excluded element types " + excludedElementTypes); + } + } + + var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(supportedElementTypes) : null; var dimensions = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) : null; diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index d40899d384f56..c8d7d6aa0e482 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -73,8 +73,13 @@ public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase { @BeforeClass public static void beforeClass() { SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + // Exclude bit vectors because semantic text does not fully support them // Exclude dot product because we are not producing unit length vectors - DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT)); + DENSE_MODEL = TestModel.createRandomInstance( + TaskType.TEXT_EMBEDDING, + List.of(DenseVectorFieldMapper.ElementType.BIT), + List.of(SimilarityMeasure.DOT_PRODUCT) + ); } public SemanticTextUpgradeIT(boolean useLegacyFormat) {