diff --git a/docs/changelog/136312.yaml b/docs/changelog/136312.yaml
new file mode 100644
index 0000000000000..e2da763cf13dd
--- /dev/null
+++ b/docs/changelog/136312.yaml
@@ -0,0 +1,7 @@
+pr: 136312
+summary: Fix _inference_fields handling on old indices
+area: Vector Search
+type: bug
+issues: [
+ 136130
+]
diff --git a/server/src/main/java/org/elasticsearch/index/get/ShardGetService.java b/server/src/main/java/org/elasticsearch/index/get/ShardGetService.java
index 94630d58a0ecb..f6cf928cce412 100644
--- a/server/src/main/java/org/elasticsearch/index/get/ShardGetService.java
+++ b/server/src/main/java/org/elasticsearch/index/get/ShardGetService.java
@@ -311,8 +311,7 @@ private GetResult innerGetFetch(
             fetchSourceContext = res.v1();
         }
 
-        if (mappingLookup.inferenceFields().isEmpty() == false
-            && shouldExcludeInferenceFieldsFromSource(indexSettings, fetchSourceContext) == false) {
+        if (mappingLookup.inferenceFields().isEmpty() == false && shouldExcludeInferenceFieldsFromSource(fetchSourceContext) == false) {
             storedFieldSet.add(InferenceMetadataFieldsMapper.NAME);
         }
 
@@ -424,17 +423,41 @@ private static Boolean shouldExcludeVectorsFromSourceExplicit(FetchSourceContext
         return fetchSourceContext != null ? fetchSourceContext.excludeVectors() : null;
     }
 
-    public static boolean shouldExcludeInferenceFieldsFromSource(IndexSettings indexSettings, FetchSourceContext fetchSourceContext) {
-        var explicit = shouldExcludeInferenceFieldsFromSourceExplicit(fetchSourceContext);
-        var filter = fetchSourceContext != null ? fetchSourceContext.filter() : null;
-        if (filter != null) {
-            if (filter.isPathFiltered(InferenceMetadataFieldsMapper.NAME, true)) {
+    /**
+     * Decides whether the inference metadata field should be excluded from the returned {@code _source}.
+     * <p>
+     * Checks are applied in order: a context that does not fetch source always excludes the field; a source filter
+     * that filters out {@link InferenceMetadataFieldsMapper#NAME} excludes it, while one that explicitly includes it
+     * keeps it; after that, a non-null explicit setting from the context decides. If nothing applies, including a
+     * {@code null} context, the field is excluded.
+     *
+     * @param fetchSourceContext the fetch source context to inspect, may be {@code null}
+     * @return {@code true} if the inference metadata field should be excluded from the source
+     */
+    public static boolean shouldExcludeInferenceFieldsFromSource(FetchSourceContext fetchSourceContext) {
+        if (fetchSourceContext != null) {
+            if (fetchSourceContext.fetchSource() == false) {
+                // Source is disabled
                 return true;
-            } else if (filter.isExplicitlyIncluded(InferenceMetadataFieldsMapper.NAME)) {
-                return false;
+            }
+
+            var filter = fetchSourceContext.filter();
+            if (filter != null) {
+                if (filter.isPathFiltered(InferenceMetadataFieldsMapper.NAME, true)) {
+                    return true;
+                } else if (filter.isExplicitlyIncluded(InferenceMetadataFieldsMapper.NAME)) {
+                    return false;
+                }
+            }
+
+            Boolean excludeInferenceFieldsExplicit = shouldExcludeInferenceFieldsFromSourceExplicit(fetchSourceContext);
+            if (excludeInferenceFieldsExplicit != null) {
+                return excludeInferenceFieldsExplicit;
             }
         }
-        return explicit != null ? explicit : INDEX_MAPPING_EXCLUDE_SOURCE_VECTORS_SETTING.get(indexSettings.getSettings());
+
+        // We always default to excluding the inference metadata field, unless the fetch source context says otherwise
+        return true;
     }
 
     private static Boolean shouldExcludeInferenceFieldsFromSourceExplicit(FetchSourceContext fetchSourceContext) {
diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
index 14c392b675a65..aa27e7d2f0c82 100644
--- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
@@ -134,8 +134,7 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr
             context.fetchSourceContext(res.v1());
         }
 
-        if (lookup.inferenceFields().isEmpty() == false
-            && shouldExcludeInferenceFieldsFromSource(context.indexShard().indexSettings(), context.fetchSourceContext()) == false) {
+        if (lookup.inferenceFields().isEmpty() == false && shouldExcludeInferenceFieldsFromSource(context.fetchSourceContext()) == false) {
             // Rehydrate the inference fields into the {@code _source} because they were explicitly requested.
             var oldFetchFieldsContext = context.fetchFieldsContext();
             var newFetchFieldsContext = new FetchFieldsContext(new ArrayList<>());
diff --git a/server/src/test/java/org/elasticsearch/index/shard/ShardGetServiceTests.java b/server/src/test/java/org/elasticsearch/index/shard/ShardGetServiceTests.java
index a10a23db0b838..3b059fd2d906b 100644
--- a/server/src/test/java/org/elasticsearch/index/shard/ShardGetServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/index/shard/ShardGetServiceTests.java
@@ -24,9 +24,12 @@
 import org.elasticsearch.index.engine.TranslogOperationAsserter;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
 import org.elasticsearch.index.get.GetResult;
+import org.elasticsearch.index.get.ShardGetService;
+import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.RoutingFieldMapper;
 import org.elasticsearch.index.translog.Translog;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
+import org.elasticsearch.search.lookup.SourceFilter;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
@@ -411,6 +414,16 @@ public void testGetFromTranslog() throws IOException {
         closeShards(primary);
     }
 
+    public void testShouldExcludeInferenceFieldsFromSource() {
+        for (int i = 0; i < 100; i++) {
+            ExcludeInferenceFieldsTestScenario scenario = new ExcludeInferenceFieldsTestScenario();
+            assertThat(
+                ShardGetService.shouldExcludeInferenceFieldsFromSource(scenario.fetchSourceContext),
+                equalTo(scenario.shouldExcludeInferenceFields())
+            );
+        }
+    }
+
     Translog.Index toIndexOp(String source) throws IOException {
         XContentParser parser = createParser(XContentType.JSON.xContent(), source);
         XContentBuilder builder = XContentFactory.jsonBuilder();
@@ -425,4 +438,78 @@ Translog.Index toIndexOp(String source) throws IOException {
             IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP
         );
     }
+
+    /**
+     * Pairs a randomly generated {@link FetchSourceContext} with an expectation, computed directly from the context's
+     * includes, excludes and flags, of whether the inference metadata field should be excluded.
+     */
+    private static class ExcludeInferenceFieldsTestScenario {
+        private final FetchSourceContext fetchSourceContext;
+
+        private ExcludeInferenceFieldsTestScenario() {
+            this.fetchSourceContext = generateRandomFetchSourceContext();
+        }
+
+        private boolean shouldExcludeInferenceFields() {
+            if (fetchSourceContext != null) {
+                if (fetchSourceContext.fetchSource() == false) {
+                    return true;
+                }
+
+                SourceFilter filter = fetchSourceContext.filter();
+                if (filter != null) {
+                    if (Arrays.asList(filter.getExcludes()).contains(InferenceMetadataFieldsMapper.NAME)) {
+                        return true;
+                    } else if (filter.getIncludes().length > 0) {
+                        return Arrays.asList(filter.getIncludes()).contains(InferenceMetadataFieldsMapper.NAME) == false;
+                    }
+                }
+
+                Boolean excludeInferenceFieldsExplicit = fetchSourceContext.excludeInferenceFields();
+                if (excludeInferenceFieldsExplicit != null) {
+                    return excludeInferenceFieldsExplicit;
+                }
+            }
+
+            return true;
+        }
+
+        private static FetchSourceContext generateRandomFetchSourceContext() {
+            FetchSourceContext fetchSourceContext = switch (randomIntBetween(0, 4)) {
+                case 0 -> FetchSourceContext.FETCH_SOURCE;
+                case 1 -> FetchSourceContext.FETCH_ALL_SOURCE;
+                case 2 -> FetchSourceContext.FETCH_ALL_SOURCE_EXCLUDE_INFERENCE_FIELDS;
+                case 3 -> FetchSourceContext.DO_NOT_FETCH_SOURCE;
+                case 4 -> null;
+                default -> throw new IllegalStateException("Unhandled randomized case");
+            };
+
+            if (fetchSourceContext != null && fetchSourceContext.fetchSource()) {
+                String[] includes = null;
+                String[] excludes = null;
+                if (randomBoolean()) {
+                    // Include either the inference metadata field (explicit inclusion) or a random field that implicitly filters it out
+                    String field = randomBoolean() ? InferenceMetadataFieldsMapper.NAME : randomIdentifier();
+                    includes = new String[] { field };
+                }
+                if (randomBoolean()) {
+                    // Exclude either the inference metadata field (explicit exclusion) or a random field that leaves it unaffected
+                    String field = randomBoolean() ? InferenceMetadataFieldsMapper.NAME : randomIdentifier();
+                    excludes = new String[] { field };
+                }
+
+                if (includes != null || excludes != null) {
+                    fetchSourceContext = FetchSourceContext.of(
+                        fetchSourceContext.fetchSource(),
+                        fetchSourceContext.excludeVectors(),
+                        fetchSourceContext.excludeInferenceFields(),
+                        includes,
+                        excludes
+                    );
+                }
+            }
+
+            return fetchSourceContext;
+        }
+    }
 }
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java
new file mode 100644
index 0000000000000..bb26153965499
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java
@@ -0,0 +1,368 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.integration;
+
+import org.elasticsearch.action.DocWriteResponse;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.IndexVersions;
+import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.license.LicenseSettings;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.reindex.ReindexPlugin;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
+import org.elasticsearch.search.lookup.SourceFilter;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.test.index.IndexVersionUtils;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
+import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
+import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
+import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
+import org.junit.After;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.notNullValue;
+import static org.hamcrest.CoreMatchers.nullValue;
+
+/**
+ * Integration tests checking whether the inference metadata field appears in the {@code _source} of search hits on
+ * indices with semantic text fields, using randomized fetch source contexts and index created versions.
+ */
+@ESIntegTestCase.ClusterScope(minNumDataNodes = 3, maxNumDataNodes = 5)
+public class SemanticTextInferenceFieldsIT extends ESIntegTestCase {
+    private final String indexName = randomIdentifier();
+    private final Map<String, TaskType> inferenceIds = new HashMap<>();
+
+    private static final Map<String, Object> SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key");
+    private static final Map<String, Object> TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of(
+        "model",
+        "my_model",
+        "dimensions",
+        256,
+        "similarity",
+        "cosine",
+        "api_key",
+        "my_api_key"
+    );
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, FakeMlPlugin.class);
+    }
+
+    @Override
+    protected boolean forbidPrivateIndexSettings() {
+        return false;
+    }
+
+    @After
+    public void cleanUp() {
+        deleteIndex(indexName);
+        for (var entry : inferenceIds.entrySet()) {
+            assertAcked(
+                safeGet(
+                    client().execute(
+                        DeleteInferenceEndpointAction.INSTANCE,
+                        new DeleteInferenceEndpointAction.Request(entry.getKey(), entry.getValue(), true, false)
+                    )
+                )
+            );
+        }
+    }
+
+    public void testExcludeInferenceFieldsFromSource() throws Exception {
+        excludeInferenceFieldsFromSourceTestCase(IndexVersion.current(), IndexVersion.current(), 10);
+    }
+
+    public void testExcludeInferenceFieldsFromSourceOldIndexVersions() throws Exception {
+        excludeInferenceFieldsFromSourceTestCase(
+            IndexVersions.SEMANTIC_TEXT_FIELD_TYPE,
+            IndexVersionUtils.getPreviousVersion(IndexVersion.current()),
+            40
+        );
+    }
+
+    private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersion, IndexVersion maxIndexVersion, int iterations)
+        throws Exception {
+        final String sparseEmbeddingInferenceId = randomIdentifier();
+        final String textEmbeddingInferenceId = randomIdentifier();
+        createInferenceEndpoint(TaskType.SPARSE_EMBEDDING, sparseEmbeddingInferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS);
+        createInferenceEndpoint(TaskType.TEXT_EMBEDDING, textEmbeddingInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS);
+
+        final String sparseEmbeddingField = randomIdentifier();
+        final String textEmbeddingField = randomIdentifier();
+
+        for (int i = 0; i < iterations; i++) {
+            final IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion);
+            final Settings indexSettings = generateIndexSettings(indexVersion);
+            XContentBuilder mappings = generateMapping(
+                Map.of(sparseEmbeddingField, sparseEmbeddingInferenceId, textEmbeddingField, textEmbeddingInferenceId)
+            );
+            assertAcked(prepareCreate(indexName).setSettings(indexSettings).setMapping(mappings));
+
+            final int docCount = randomIntBetween(10, 50);
+            indexDocuments(sparseEmbeddingField, docCount);
+            indexDocuments(textEmbeddingField, docCount);
+
+            QueryBuilder sparseEmbeddingFieldQuery = new SemanticQueryBuilder(sparseEmbeddingField, randomAlphaOfLength(10));
+            assertSearchResponse(sparseEmbeddingFieldQuery, indexSettings, docCount, request -> {
+                request.source().fetchSource(generateRandomFetchSourceContext()).fetchField(sparseEmbeddingField);
+            }, response -> {
+                for (SearchHit hit : response.getHits()) {
+                    Map<String, DocumentField> documentFields = hit.getDocumentFields();
+                    assertThat(documentFields.size(), is(1));
+                    assertThat(documentFields.containsKey(sparseEmbeddingField), is(true));
+                }
+            });
+
+            QueryBuilder textEmbeddingFieldQuery = new SemanticQueryBuilder(textEmbeddingField, randomAlphaOfLength(10));
+            assertSearchResponse(textEmbeddingFieldQuery, indexSettings, docCount, request -> {
+                request.source().fetchSource(generateRandomFetchSourceContext()).fetchField(textEmbeddingField);
+            }, response -> {
+                for (SearchHit hit : response.getHits()) {
+                    Map<String, DocumentField> documentFields = hit.getDocumentFields();
+                    assertThat(documentFields.size(), is(1));
+                    assertThat(documentFields.containsKey(textEmbeddingField), is(true));
+                }
+            });
+
+            deleteIndex(indexName);
+        }
+    }
+
+    private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
+        final String service = switch (taskType) {
+            case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
+            case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
+            default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
+        };
+
+        final BytesReference content;
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            builder.startObject();
+            builder.field("service", service);
+            builder.field("service_settings", serviceSettings);
+            builder.endObject();
+
+            content = BytesReference.bytes(builder);
+        }
+
+        PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
+            taskType,
+            inferenceId,
+            content,
+            XContentType.JSON,
+            TEST_REQUEST_TIMEOUT
+        );
+        var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request);
+        assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
+
+        inferenceIds.put(inferenceId, taskType);
+    }
+
+    private Settings generateIndexSettings(IndexVersion indexVersion) {
+        int numDataNodes = internalCluster().numDataNodes();
+        return Settings.builder()
+            .put(IndexMetadata.SETTING_VERSION_CREATED, indexVersion)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numDataNodes)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .build();
+    }
+
+    private void indexDocuments(String field, int count) {
+        for (int i = 0; i < count; i++) {
+            Map<String, Object> source = Map.of(field, randomAlphaOfLength(10));
+            DocWriteResponse response = client().prepareIndex(indexName).setSource(source).get(TEST_REQUEST_TIMEOUT);
+            assertThat(response.getResult(), is(DocWriteResponse.Result.CREATED));
+        }
+
+        client().admin().indices().prepareRefresh(indexName).get();
+    }
+
+    private void assertSearchResponse(
+        QueryBuilder queryBuilder,
+        Settings indexSettings,
+        int expectedHits,
+        @Nullable Consumer<SearchRequest> searchRequestModifier,
+        @Nullable Consumer<SearchResponse> searchResponseValidator
+    ) throws Exception {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder).size(expectedHits);
+        SearchRequest searchRequest = new SearchRequest(new String[] { indexName }, searchSourceBuilder);
+        if (searchRequestModifier != null) {
+            searchRequestModifier.accept(searchRequest);
+        }
+
+        ExpectedSource expectedSource = getExpectedSource(indexSettings, searchRequest.source().fetchSource());
+        assertResponse(client().search(searchRequest), response -> {
+            assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards()));
+            assertThat(response.getHits().getTotalHits().value(), equalTo((long) expectedHits));
+
+            for (SearchHit hit : response.getHits()) {
+                switch (expectedSource) {
+                    case NONE -> assertThat(hit.getSourceAsMap(), nullValue());
+                    case INFERENCE_FIELDS_EXCLUDED -> {
+                        Map<String, Object> sourceAsMap = hit.getSourceAsMap();
+                        assertThat(sourceAsMap, notNullValue());
+                        assertThat(sourceAsMap.containsKey(InferenceMetadataFieldsMapper.NAME), is(false));
+                    }
+                    case INFERENCE_FIELDS_INCLUDED -> {
+                        Map<String, Object> sourceAsMap = hit.getSourceAsMap();
+                        assertThat(sourceAsMap, notNullValue());
+                        assertThat(sourceAsMap.containsKey(InferenceMetadataFieldsMapper.NAME), is(true));
+                    }
+                }
+            }
+
+            if (searchResponseValidator != null) {
+                searchResponseValidator.accept(response);
+            }
+        });
+    }
+
+    private static ExpectedSource getExpectedSource(Settings indexSettings, FetchSourceContext fetchSourceContext) {
+        if (fetchSourceContext != null && fetchSourceContext.fetchSource() == false) {
+            return ExpectedSource.NONE;
+        } else if (InferenceMetadataFieldsMapper.isEnabled(indexSettings) == false) {
+            return ExpectedSource.INFERENCE_FIELDS_EXCLUDED;
+        }
+
+        if (fetchSourceContext != null) {
+            SourceFilter filter = fetchSourceContext.filter();
+            if (filter != null) {
+                if (Arrays.asList(filter.getExcludes()).contains(InferenceMetadataFieldsMapper.NAME)) {
+                    return ExpectedSource.INFERENCE_FIELDS_EXCLUDED;
+                } else if (filter.getIncludes().length > 0) {
+                    return Arrays.asList(filter.getIncludes()).contains(InferenceMetadataFieldsMapper.NAME)
+                        ? ExpectedSource.INFERENCE_FIELDS_INCLUDED
+                        : ExpectedSource.INFERENCE_FIELDS_EXCLUDED;
+                }
+            }
+
+            Boolean excludeInferenceFieldsExplicit = fetchSourceContext.excludeInferenceFields();
+            if (excludeInferenceFieldsExplicit != null) {
+                return excludeInferenceFieldsExplicit ? ExpectedSource.INFERENCE_FIELDS_EXCLUDED : ExpectedSource.INFERENCE_FIELDS_INCLUDED;
+            }
+        }
+
+        return ExpectedSource.INFERENCE_FIELDS_EXCLUDED;
+    }
+
+    private static FetchSourceContext generateRandomFetchSourceContext() {
+        FetchSourceContext fetchSourceContext = switch (randomIntBetween(0, 4)) {
+            case 0 -> FetchSourceContext.FETCH_SOURCE;
+            case 1 -> FetchSourceContext.FETCH_ALL_SOURCE;
+            case 2 -> FetchSourceContext.FETCH_ALL_SOURCE_EXCLUDE_INFERENCE_FIELDS;
+            case 3 -> FetchSourceContext.DO_NOT_FETCH_SOURCE;
+            case 4 -> null;
+            default -> throw new IllegalStateException("Unhandled randomized case");
+        };
+
+        if (fetchSourceContext != null && fetchSourceContext.fetchSource()) {
+            String[] includes = null;
+            String[] excludes = null;
+            if (randomBoolean()) {
+                // Include either the inference metadata field (explicit inclusion) or a random field that implicitly filters it out
+                String field = randomBoolean() ? InferenceMetadataFieldsMapper.NAME : randomIdentifier();
+                includes = new String[] { field };
+            }
+            if (randomBoolean()) {
+                // Exclude either the inference metadata field (explicit exclusion) or a random field that leaves it unaffected
+                String field = randomBoolean() ? InferenceMetadataFieldsMapper.NAME : randomIdentifier();
+                excludes = new String[] { field };
+            }
+
+            if (includes != null || excludes != null) {
+                fetchSourceContext = FetchSourceContext.of(
+                    fetchSourceContext.fetchSource(),
+                    fetchSourceContext.excludeVectors(),
+                    fetchSourceContext.excludeInferenceFields(),
+                    includes,
+                    excludes
+                );
+            }
+        }
+
+        return fetchSourceContext;
+    }
+
+    private static XContentBuilder generateMapping(Map<String, String> semanticTextFields) throws IOException {
+        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");
+        for (var entry : semanticTextFields.entrySet()) {
+            mapping.startObject(entry.getKey());
+            mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
+            mapping.field("inference_id", entry.getValue());
+            mapping.endObject();
+        }
+        mapping.endObject().endObject();
+
+        return mapping;
+    }
+
+    private static void deleteIndex(String indexName) {
+        assertAcked(
+            safeGet(
+                client().admin()
+                    .indices()
+                    .prepareDelete(indexName)
+                    .setIndicesOptions(
+                        IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build()
+                    )
+                    .execute()
+            )
+        );
+    }
+
+    private enum ExpectedSource {
+        NONE,
+        INFERENCE_FIELDS_EXCLUDED,
+        INFERENCE_FIELDS_INCLUDED
+    }
+
+    public static class FakeMlPlugin extends Plugin {
+        @Override
+        public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+            return new MlInferenceNamedXContentProvider().getNamedWriteables();
+        }
+    }
+}