1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
@@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams
apply plugin: 'elasticsearch.internal-es-plugin'
apply plugin: 'elasticsearch.internal-cluster-test'
apply plugin: 'elasticsearch.internal-yaml-rest-test'
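// Publish this plugin's test classes as a test artifact so other projects (e.g. the rolling upgrade tests) can depend on them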
apply plugin: 'elasticsearch.internal-test-artifact'

restResources {
restApi {
8 changes: 7 additions & 1 deletion x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
@@ -31,6 +31,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
import static org.elasticsearch.test.ESTestCase.randomFrom;
import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -46,9 +47,14 @@ public static TestModel createRandomInstance(TaskType taskType) {
}

public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
// Use a max dimension count that has a reasonable probability of being compatible with BBQ
return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
}

public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
var dimensions = taskType == TaskType.TEXT_EMBEDDING
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
: null;

SimilarityMeasure similarity = null;
2 changes: 2 additions & 0 deletions x-pack/qa/rolling-upgrade/build.gradle
@@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test'
apply plugin: 'elasticsearch.rest-resources'

dependencies {
testImplementation testArtifact(project(':server'))
testImplementation testArtifact(project(xpackModule('core')))
testImplementation project(':x-pack:qa')
testImplementation project(':modules:reindex')
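// Test artifact exposing the inference plugin's test utilities (e.g. TestModel and the semantic text field helpers)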
testImplementation testArtifact(project(xpackModule('inference')))
}

restResources {
261 changes: 261 additions & 0 deletions x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
@@ -0,0 +1,261 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.upgrades;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.Version;
import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.rest.ObjectPath;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;

public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
private static final String INDEX_BASE_NAME = "semantic_text_test_index";
private static final String SPARSE_FIELD = "sparse_field";
private static final String DENSE_FIELD = "dense_field";
private static final Version UPGRADE_FROM_VERSION_PARSED = Version.fromString(UPGRADE_FROM_VERSION);

private static final String DOC_1_ID = "doc_1";
private static final String DOC_2_ID = "doc_2";
private static final Map<String, List<String>> DOC_VALUES = Map.of(
DOC_1_ID,
List.of("a test value", "with multiple test values"),
DOC_2_ID,
List.of("another test value")
);

private static Model SPARSE_MODEL;
private static Model DENSE_MODEL;

private final boolean useLegacyFormat;

@BeforeClass
public static void beforeClass() {
SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
// Exclude dot product because we are not producing unit length vectors
DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
}

public SemanticTextUpgradeIT(boolean useLegacyFormat) {
this.useLegacyFormat = useLegacyFormat;
}

@ParametersFactory
public static Iterable<Object[]> parameters() {
List<Object[]> parameters = new ArrayList<>();
parameters.add(new Object[] { true });
if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
// New semantic text format added in 8.18
parameters.add(new Object[] { false });
}
return parameters;
}

public void testSemanticTextOperations() throws Exception {
assumeTrue("Upgrade from version supports semantic text", UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_15_0));
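// OLD: create and seed the index; MIXED and UPGRADED: index another doc, then query and check highlights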
switch (CLUSTER_TYPE) {
case OLD -> createAndPopulateIndex();
case MIXED, UPGRADED -> performIndexQueryHighlightOps();
default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
}
}

private void createAndPopulateIndex() throws IOException {
final String indexName = getIndexName();
final String mapping = Strings.format("""
{
"properties": {
"%s": {
"type": "semantic_text",
"inference_id": "%s"
},
"%s": {
"type": "semantic_text",
"inference_id": "%s"
}
}
}
""", SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());

Settings.Builder settingsBuilder = Settings.builder();
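// The setting selecting the semantic text storage format was introduced in 8.18 along with the new format,
// so only set it when the old cluster will recognize it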
if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0)) {
settingsBuilder.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
}

CreateIndexResponse response = createIndex(indexName, settingsBuilder.build(), mapping);
assertThat(response.isAcknowledged(), equalTo(true));

indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
}

private void performIndexQueryHighlightOps() throws IOException {
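// (Re)index the second doc in each post-upgrade phase to exercise writes on partially and fully upgraded clusters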
indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));

ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);

ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
}

private String getIndexName() {
return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
}

private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
final String indexName = getIndexName();
final SemanticTextField sparseFieldValue = randomSemanticText(
useLegacyFormat,
SPARSE_FIELD,
SPARSE_MODEL,
null,
semanticTextFieldValue,
XContentType.JSON
);
final SemanticTextField denseFieldValue = randomSemanticText(
useLegacyFormat,
DENSE_FIELD,
DENSE_MODEL,
null,
semanticTextFieldValue,
XContentType.JSON
);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
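// New format: the original text is written under the field name and inference results go to the inference
// metadata field. Legacy format: the field value itself carries both the text and the inference results.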
if (useLegacyFormat == false) {
builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
}
addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
builder.endObject();

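// Refresh on index so the document is immediately visible to the searches below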
RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
Request request = new Request("POST", indexName + "/_doc/" + id);
request.setJsonEntity(Strings.toString(builder));
request.setOptions(requestOptions);

Response response = client().performRequest(request);
assertOK(response);
}

private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException {
// We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
// query
final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field);
final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) {
case SPARSE_EMBEDDING -> {
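// Weight each query token at 1.0 so a match contributes the indexed token's weight directly to the score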
List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null);
}
case TEXT_EMBEDDING -> {
DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType();
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
elementType,
fieldModel.getServiceSettings().dimensions()
);

// Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document
// vector
float[] queryVector = new float[embeddingLength];
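// Bit vectors pack 8 dimensions per byte, so the query vector is byte-valued with one entry per byte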
if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
Arrays.fill(queryVector, -128.0f);
} else {
Arrays.fill(queryVector, 1.0f);
}

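// Use the total document count as k so the kNN query can return every indexed document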
yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null);
}
default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]");
};

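// Chunk embeddings are indexed as nested documents, so wrap the inner query in a nested query over the chunks field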
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
SemanticTextField.getChunksFieldName(field),
innerQueryBuilder,
ScoreMode.Max
);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.field("query", nestedQueryBuilder);
if (numOfHighlightFragments != null) {
HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
highlightField.numOfFragments(numOfHighlightFragments);

HighlightBuilder highlightBuilder = new HighlightBuilder();
highlightBuilder.field(highlightField);

builder.field("highlight", highlightBuilder);
}
builder.endObject();

Request request = new Request("GET", getIndexName() + "/_search");
request.setJsonEntity(Strings.toString(builder));

Response response = client().performRequest(request);
return assertOKAndCreateObjectPath(response);
}

private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));

Set<String> docIds = new HashSet<>();
List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
for (Map<String, Object> hit : hits) {
String id = ObjectPath.evaluate(hit, "_id");
assertThat(id, notNullValue());
docIds.add(id);

if (UPGRADE_FROM_VERSION_PARSED.onOrAfter(Version.V_8_18_0) || CLUSTER_TYPE == ClusterType.UPGRADED) {
// Semantic highlighting only functions reliably on clusters where all nodes are 8.18.0 or later
List<String> expectedHighlight = DOC_VALUES.get(id);
assertThat(expectedHighlight, notNullValue());
assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
}
}

assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
}
}