diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index b2094e695a112..9a6cfb2bb3d8e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; @@ -49,6 +50,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.search.QueryParserHelper; import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.index.shard.ShardId; @@ -87,6 +89,7 @@ import static org.elasticsearch.cluster.node.DiscoveryNodeFilters.OpType.OR; import static org.elasticsearch.cluster.node.DiscoveryNodeFilters.validateIpValue; import static org.elasticsearch.common.settings.Settings.readSettingsFromStream; +import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOT_PARTIAL_SETTING_KEY; public class IndexMetadata implements Diffable, ToXContentFragment { @@ -1355,6 +1358,32 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } + /** + * Get the inference fields that match the provided field pattern map. 
The matches are returned as a map where the key is the + * {@link InferenceFieldMetadata} for the matching inference field and the value is the effective weight of the field. + * If {@code useDefaultFields} is true and {@code fields} is empty, then the field pattern map will be derived from the value of + * {@link IndexSettings#DEFAULT_FIELD_SETTING} for the index. + * + * @param fields The field pattern map, where the key is the field pattern and the value is the pattern weight. + * @param resolveWildcards If {@code true}, wildcards in field patterns will be resolved. Otherwise, only explicit matches will be + * returned. + * @param useDefaultFields If {@code true}, default fields will be used if {@code fields} is empty. + * @return A map of inference field matches + */ + public Map getMatchingInferenceFields( + Map fields, + boolean resolveWildcards, + boolean useDefaultFields + ) { + Map effectiveFields = fields; + if (effectiveFields.isEmpty() && useDefaultFields) { + List defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); + effectiveFields = QueryParserHelper.parseFieldsAndWeights(defaultFields); + } + + return getMatchingInferenceFields(inferenceFields, effectiveFields, resolveWildcards); + } + public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -3255,4 +3284,51 @@ public static int parseIndexNameCounter(String indexName) { throw new IllegalArgumentException("unable to parse the index name [" + indexName + "] to extract the counter", e); } } + + /** + * An overload of {@link #getMatchingInferenceFields(Map, boolean, boolean)}}, where the inference field metadata map to match against + * is provided by the caller. 
{@code useDefaultFields} is unavailable because the index's {@link IndexSettings#DEFAULT_FIELD_SETTING} is + * out of scope. + * + * @param inferenceFieldMetadataMap The inference field metadata map to match against. + * @param fieldMap The field pattern map, where the key is the field pattern and the value is the pattern weight. + * @param resolveWildcards If {@code true}, wildcards in field patterns will be resolved. Otherwise, only explicit matches will be + * returned. + * @return A map of inference field matches + */ + public static Map getMatchingInferenceFields( + Map inferenceFieldMetadataMap, + Map fieldMap, + boolean resolveWildcards + ) { + Map matches = new HashMap<>(); + for (var entry : fieldMap.entrySet()) { + String field = entry.getKey(); + Float weight = entry.getValue(); + + if (inferenceFieldMetadataMap.containsKey(field)) { + // No wildcards in field name + addToMatchingInferenceFieldsMap(matches, inferenceFieldMetadataMap.get(field), weight); + } else if (resolveWildcards) { + if (Regex.isMatchAllPattern(field)) { + inferenceFieldMetadataMap.values().forEach(ifm -> addToMatchingInferenceFieldsMap(matches, ifm, weight)); + } else if (Regex.isSimpleMatchPattern(field)) { + inferenceFieldMetadataMap.values() + .stream() + .filter(ifm -> Regex.simpleMatch(field, ifm.getName())) + .forEach(ifm -> addToMatchingInferenceFieldsMap(matches, ifm, weight)); + } + } + } + + return matches; + } + + private static void addToMatchingInferenceFieldsMap( + Map matches, + InferenceFieldMetadata inferenceFieldMetadata, + Float weight + ) { + matches.compute(inferenceFieldMetadata, (k, v) -> v == null ? 
weight : v * weight); + } } diff --git a/server/src/main/resources/transport/definitions/referable/get_inference_fields_action.csv b/server/src/main/resources/transport/definitions/referable/get_inference_fields_action.csv new file mode 100644 index 0000000000000..711a96c6acee7 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/get_inference_fields_action.csv @@ -0,0 +1 @@ +9220000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index afc3bb444e49d..28aaf9df2e51d 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -search_project_routing,9219000 +get_inference_fields_action,9220000 diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index d0aae463dd193..c543a44e4475c 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -57,6 +57,7 @@ import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_HIDDEN_SETTING; import static org.elasticsearch.cluster.metadata.IndexMetadata.parseIndexNameCounter; import static org.elasticsearch.index.IndexModule.INDEX_STORE_TYPE_SETTING; +import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SNAPSHOT_PARTIAL_SETTING; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -681,6 +682,104 @@ public void testInferenceFieldMetadata() { assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields)); } + public void testGetMatchingInferenceFields() { + final String inferenceField1 = "inference-field-1"; + final String inferenceField2 = 
"inference-field-2"; + final String inferenceField3 = "inference-field-3"; + final Map inferenceFields = Map.of( + inferenceField1, + randomInferenceFieldMetadata(inferenceField1), + inferenceField2, + randomInferenceFieldMetadata(inferenceField2), + inferenceField3, + randomInferenceFieldMetadata(inferenceField3) + ); + + Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata indexMetadata = IndexMetadata.builder("test").settings(settings).putInferenceFields(inferenceFields).build(); + + final Map fieldPatternMap = Map.of(inferenceField1, 1.5f, "inference-field-*", 2.0f, "*-field-3", 1.75f); + + // Explicit matches only + assertThat( + indexMetadata.getMatchingInferenceFields(fieldPatternMap, false, false), + equalTo(Map.of(inferenceFields.get(inferenceField1), 1.5f)) + ); + + // Resolve wildcards + assertThat( + indexMetadata.getMatchingInferenceFields(fieldPatternMap, true, false), + equalTo( + Map.of( + inferenceFields.get(inferenceField1), + 3.0f, + inferenceFields.get(inferenceField2), + 2.0f, + inferenceFields.get(inferenceField3), + 3.5f + ) + ) + ); + } + + public void testGetMatchingInferenceFieldsUsingDefaultFields() { + final String inferenceField1 = "inference-field-1"; + final String inferenceField2 = "inference-field-2"; + final String inferenceField3 = "inference-field-3"; + final Map inferenceFields = Map.of( + inferenceField1, + randomInferenceFieldMetadata(inferenceField1), + inferenceField2, + randomInferenceFieldMetadata(inferenceField2), + inferenceField3, + randomInferenceFieldMetadata(inferenceField3) + ); + + Settings.Builder index1Settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0).putList( + DEFAULT_FIELD_SETTING.getKey(), + List.of(inferenceField1 + "^1.5", "inference-field-*^2.0", "*-field-3^1.75") + ); + IndexMetadata index1Metadata = IndexMetadata.builder("test1").settings(index1Settings).putInferenceFields(inferenceFields).build(); + + Settings.Builder 
index2Settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata index2Metadata = IndexMetadata.builder("test2").settings(index2Settings).putInferenceFields(inferenceFields).build(); + + // Explicit matches only + assertThat( + index1Metadata.getMatchingInferenceFields(Map.of(), false, true), + equalTo(Map.of(inferenceFields.get(inferenceField1), 1.5f)) + ); + assertThat(index2Metadata.getMatchingInferenceFields(Map.of(), false, true), equalTo(Map.of())); + + // Resolve wildcards + assertThat( + index1Metadata.getMatchingInferenceFields(Map.of(), true, true), + equalTo( + Map.of( + inferenceFields.get(inferenceField1), + 3.0f, + inferenceFields.get(inferenceField2), + 2.0f, + inferenceFields.get(inferenceField3), + 3.5f + ) + ) + ); + assertThat( + index2Metadata.getMatchingInferenceFields(Map.of(), true, true), + equalTo( + Map.of( + inferenceFields.get(inferenceField1), + 1.0f, + inferenceFields.get(inferenceField2), + 1.0f, + inferenceFields.get(inferenceField3), + 1.0f + ) + ) + ); + } + public void testReshardingBWCSerialization() throws IOException { final int numShards = randomIntBetween(1, 8); final var settings = indexSettings(IndexVersion.current(), numShards, 0); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java new file mode 100644 index 0000000000000..ddd1cbce9aa6d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java @@ -0,0 +1,273 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.RemoteClusterActionType; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.action.ValidateActions.addValidationError; + +/** + *

+ * An internal action for getting inference fields for a set of indices and optionally, the inference results for those + * fields given a query. + *

+ *

+ * Note that this action is intended to be used to get inference fields for a remote cluster. Local cluster inference + * fields can be gathered more directly using {@link IndexMetadata#getMatchingInferenceFields}. + *

+ */ +public class GetInferenceFieldsAction extends ActionType { + public static final GetInferenceFieldsAction INSTANCE = new GetInferenceFieldsAction(); + public static final RemoteClusterActionType REMOTE_TYPE = new RemoteClusterActionType<>(INSTANCE.name(), Response::new); + + public static final TransportVersion GET_INFERENCE_FIELDS_ACTION_TV = TransportVersion.fromName("get_inference_fields_action"); + + public static final String NAME = "cluster:internal/xpack/inference/fields/get"; + + public GetInferenceFieldsAction() { + super(NAME); + } + + public static class Request extends ActionRequest { + private final Set indices; + private final Map fields; + private final boolean resolveWildcards; + private final boolean useDefaultFields; + private final String query; + private final IndicesOptions indicesOptions; + + /** + * An overload of {@link #Request(Set, Map, boolean, boolean, String, IndicesOptions)} that uses {@link IndicesOptions#DEFAULT} + */ + public Request( + Set indices, + Map fields, + boolean resolveWildcards, + boolean useDefaultFields, + @Nullable String query + ) { + this(indices, fields, resolveWildcards, useDefaultFields, query, null); + } + + /** + *

+ * Constructs a request to get inference fields. + *

+ *

+ * If {@code useDefaultFields} is true and {@code fields} is empty, then the field pattern map will be derived from the value of + * {@link IndexSettings#DEFAULT_FIELD_SETTING} for each index. + *

+ *

+ * If {@code query} is {@code null}, then no inference results will be generated. This can be useful in scenarios where the caller + * only needs to check for the existence of inference fields. + *

+ * + * @param indices The indices to get inference fields for. + * @param fields The field pattern map, where the key is the field pattern and the value is the pattern weight. + * @param resolveWildcards If {@code true}, wildcards in field patterns will be resolved. Otherwise, only explicit matches will be + * returned. + * @param useDefaultFields If {@code true}, default fields will be used if {@code fields} is empty. + * @param query The query to generate inference results for. + * @param indicesOptions The {@link IndicesOptions} to use when resolving indices. + */ + public Request( + Set indices, + Map fields, + boolean resolveWildcards, + boolean useDefaultFields, + @Nullable String query, + @Nullable IndicesOptions indicesOptions + ) { + this.indices = indices; + this.fields = fields; + this.resolveWildcards = resolveWildcards; + this.useDefaultFields = useDefaultFields; + this.query = query; + this.indicesOptions = indicesOptions == null ? IndicesOptions.DEFAULT : indicesOptions; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.indices = in.readCollectionAsSet(StreamInput::readString); + this.fields = in.readMap(StreamInput::readFloat); + this.resolveWildcards = in.readBoolean(); + this.useDefaultFields = in.readBoolean(); + this.query = in.readOptionalString(); + this.indicesOptions = IndicesOptions.readIndicesOptions(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeStringCollection(indices); + out.writeMap(fields, StreamOutput::writeFloat); + out.writeBoolean(resolveWildcards); + out.writeBoolean(useDefaultFields); + out.writeOptionalString(query); + indicesOptions.writeIndicesOptions(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (indices == null) { + validationException = addValidationError("indices must not be null", validationException); + } + + if (fields 
== null) { + validationException = addValidationError("fields must not be null", validationException); + } else { + for (var entry : fields.entrySet()) { + if (entry.getValue() == null) { + validationException = addValidationError( + "weight for field [" + entry.getKey() + "] must not be null", + validationException + ); + } + } + } + + return validationException; + } + + public Set getIndices() { + return Collections.unmodifiableSet(indices); + } + + public Map getFields() { + return Collections.unmodifiableMap(fields); + } + + public boolean resolveWildcards() { + return resolveWildcards; + } + + public boolean useDefaultFields() { + return useDefaultFields; + } + + public String getQuery() { + return query; + } + + public IndicesOptions getIndicesOptions() { + return indicesOptions; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(indices, request.indices) + && Objects.equals(fields, request.fields) + && resolveWildcards == request.resolveWildcards + && useDefaultFields == request.useDefaultFields + && Objects.equals(query, request.query) + && Objects.equals(indicesOptions, request.indicesOptions); + } + + @Override + public int hashCode() { + return Objects.hash(indices, fields, resolveWildcards, useDefaultFields, query, indicesOptions); + } + } + + /** + *

+ * A response containing an inference fields map and, if a query was specified in the {@link Request}, the inference + * results for those fields. + *

+ *

+ * The inference fields map key is a concrete index name. The value is a list of {@link ExtendedInferenceFieldMetadata}, + * representing the metadata for all matching inference fields in that index. + *

+ *

+ * The inference results map key is an inference ID. The value is the inference results from the inference endpoint + * that the inference ID resolves to. If no query was specified in the {@link Request}, this will be an empty map. + *

+ */ + public static class Response extends ActionResponse { + private final Map> inferenceFieldsMap; + private final Map inferenceResultsMap; + + public Response( + Map> inferenceFieldsMap, + Map inferenceResultsMap + ) { + this.inferenceFieldsMap = inferenceFieldsMap; + this.inferenceResultsMap = inferenceResultsMap; + } + + public Response(StreamInput in) throws IOException { + this.inferenceFieldsMap = in.readImmutableMap(i -> i.readCollectionAsImmutableList(ExtendedInferenceFieldMetadata::new)); + this.inferenceResultsMap = in.readImmutableMap(i -> i.readNamedWriteable(InferenceResults.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(inferenceFieldsMap, StreamOutput::writeCollection); + out.writeMap(inferenceResultsMap, StreamOutput::writeNamedWriteable); + } + + public Map> getInferenceFieldsMap() { + return Collections.unmodifiableMap(this.inferenceFieldsMap); + } + + public Map getInferenceResultsMap() { + return Collections.unmodifiableMap(this.inferenceResultsMap); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(inferenceFieldsMap, response.inferenceFieldsMap) + && Objects.equals(inferenceResultsMap, response.inferenceResultsMap); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceFieldsMap, inferenceResultsMap); + } + } + + public record ExtendedInferenceFieldMetadata(InferenceFieldMetadata inferenceFieldMetadata, float weight) implements Writeable { + public ExtendedInferenceFieldMetadata(StreamInput in) throws IOException { + this(new InferenceFieldMetadata(in), in.readFloat()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(inferenceFieldMetadata); + out.writeFloat(weight); + } + } +} diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java index 20b2b4737c8b5..d469a241360c9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java @@ -18,6 +18,10 @@ public class ErrorInferenceResultsTests extends InferenceResultsTestCase { + public static ErrorInferenceResults createRandomResults() { + return new ErrorInferenceResults(new ElasticsearchStatusException(randomAlphaOfLength(8), randomFrom(RestStatus.values()))); + } + @Override protected Writeable.Reader instanceReader() { return ErrorInferenceResults::new; @@ -25,7 +29,7 @@ protected Writeable.Reader instanceReader() { @Override protected ErrorInferenceResults createTestInstance() { - return new ErrorInferenceResults(new ElasticsearchStatusException(randomAlphaOfLength(8), randomFrom(RestStatus.values()))); + return createRandomResults(); } @Override diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java new file mode 100644 index 0000000000000..fbb274c632c7a --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.RemoteClusterClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.transport.RemoteClusterService; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.FakeMlPlugin; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.inference.integration.GetInferenceFieldsIT.assertInferenceFieldsMap; +import static org.elasticsearch.xpack.inference.integration.GetInferenceFieldsIT.assertInferenceResultsMap; +import static org.elasticsearch.xpack.inference.integration.GetInferenceFieldsIT.generateDefaultWeightFieldMap; +import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint; +import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.generateSemanticTextMapping; +import static org.hamcrest.Matchers.containsString; + +public class GetInferenceFieldsCrossClusterIT extends AbstractMultiClustersTestCase { + 
private static final String REMOTE_CLUSTER = "cluster_a"; + private static final String INDEX_NAME = "test-index"; + private static final String INFERENCE_FIELD = "test-inference-field"; + private static final String INFERENCE_ID = "test-inference-id"; + private static final Map INFERENCE_ENDPOINT_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); + + private boolean clustersConfigured = false; + + @Override + protected List remoteClusterAlias() { + return List.of(REMOTE_CLUSTER); + } + + @Override + protected Map skipUnavailableForRemoteClusters() { + return Map.of(REMOTE_CLUSTER, DEFAULT_SKIP_UNAVAILABLE); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins(String clusterAlias) { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class); + } + + @Before + public void configureClusters() throws Exception { + if (clustersConfigured == false) { + setupTwoClusters(); + clustersConfigured = true; + } + } + + public void testRemoteIndex() { + Consumer assertFailedRequest = r -> { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client().execute(GetInferenceFieldsAction.INSTANCE, r).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat(e.getMessage(), containsString("GetInferenceFieldsAction does not support remote indices")); + }; + + var concreteIndexRequest = new GetInferenceFieldsAction.Request( + Set.of(REMOTE_CLUSTER + ":test-index"), + Map.of(), + false, + false, + "foo" + ); + assertFailedRequest.accept(concreteIndexRequest); + + var wildcardIndexRequest = new GetInferenceFieldsAction.Request(Set.of(REMOTE_CLUSTER + ":*"), Map.of(), false, false, "foo"); + assertFailedRequest.accept(wildcardIndexRequest); + + var wildcardClusterAndIndexRequest = new GetInferenceFieldsAction.Request(Set.of("*:*"), 
Map.of(), false, false, "foo"); + assertFailedRequest.accept(wildcardClusterAndIndexRequest); + } + + public void testRemoteClusterAction() { + RemoteClusterClient remoteClusterClient = client().getRemoteClusterClient( + REMOTE_CLUSTER, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + RemoteClusterService.DisconnectedStrategy.RECONNECT_IF_DISCONNECTED + ); + + var request = new GetInferenceFieldsAction.Request( + Set.of(INDEX_NAME), + generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD)), + false, + false, + "foo" + ); + PlainActionFuture future = new PlainActionFuture<>(); + remoteClusterClient.execute(GetInferenceFieldsAction.REMOTE_TYPE, request, future); + + var response = future.actionGet(TEST_REQUEST_TIMEOUT); + assertInferenceFieldsMap( + response.getInferenceFieldsMap(), + Map.of(INDEX_NAME, Set.of(new GetInferenceFieldsIT.InferenceFieldWithTestMetadata(INFERENCE_FIELD, INFERENCE_ID, 1.0f))) + ); + assertInferenceResultsMap(response.getInferenceResultsMap(), Map.of(INFERENCE_ID, TextExpansionResults.class)); + } + + private void setupTwoClusters() throws IOException { + setupCluster(LOCAL_CLUSTER); + setupCluster(REMOTE_CLUSTER); + } + + private void setupCluster(String clusterAlias) throws IOException { + final Client client = client(clusterAlias); + + createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, INFERENCE_ID, INFERENCE_ENDPOINT_SERVICE_SETTINGS); + + int dataNodeCount = cluster(clusterAlias).numDataNodes(); + XContentBuilder mappings = generateSemanticTextMapping(Map.of(INFERENCE_FIELD, INFERENCE_ID)); + Settings indexSettings = indexSettings(randomIntBetween(1, dataNodeCount), 0).build(); + assertAcked(client.admin().indices().prepareCreate(INDEX_NAME).setSettings(indexSettings).setMapping(mappings)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java new file mode 100644 index 0000000000000..f81efbef3b8e3 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java @@ -0,0 +1,561 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.FakeMlPlugin; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import 
java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

/**
 * Integration tests for {@link GetInferenceFieldsAction}: two indices are created with the same semantic_text field
 * names but (partially) swapped inference endpoints, and requests are asserted against the expected per-index
 * inference field matches and per-inference-id results.
 *
 * NOTE(review): generic type parameters in this class were erased in the reviewed diff and have been reconstructed
 * from usage — confirm against the original file.
 */
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class GetInferenceFieldsIT extends ESIntegTestCase {
    private static final Map<String, Object> SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key");
    private static final Map<String, Object> TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of(
        "model",
        "my_model",
        "dimensions",
        256,
        "similarity",
        "cosine",
        "api_key",
        "my_api_key"
    );

    private static final String SPARSE_EMBEDDING_INFERENCE_ID = "sparse-embedding-id";
    private static final String TEXT_EMBEDDING_INFERENCE_ID = "text-embedding-id";

    private static final String INDEX_1 = "index-1";
    private static final String INDEX_2 = "index-2";
    private static final Set<String> ALL_INDICES = Set.of(INDEX_1, INDEX_2);
    private static final String INDEX_ALIAS = "index-alias";

    private static final String INFERENCE_FIELD_1 = "inference-field-1";
    private static final String INFERENCE_FIELD_2 = "inference-field-2";
    private static final String INFERENCE_FIELD_3 = "inference-field-3";
    private static final String INFERENCE_FIELD_4 = "inference-field-4";
    private static final String TEXT_FIELD_1 = "text-field-1";
    private static final String TEXT_FIELD_2 = "text-field-2";
    // Every field in the test mappings, all with the default weight of 1.0f
    private static final Map<String, Float> ALL_FIELDS = Collections.unmodifiableMap(
        generateDefaultWeightFieldMap(
            Set.of(INFERENCE_FIELD_1, INFERENCE_FIELD_2, INFERENCE_FIELD_3, INFERENCE_FIELD_4, TEXT_FIELD_1, TEXT_FIELD_2)
        )
    );

    // Inference fields 1 and 2 use swapped inference IDs across the two indices; fields 3 and 4 are identical
    private static final Set<InferenceFieldWithTestMetadata> INDEX_1_EXPECTED_INFERENCE_FIELDS = Set.of(
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, SPARSE_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_2, TEXT_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_3, SPARSE_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_4, TEXT_EMBEDDING_INFERENCE_ID, 1.0f)
    );
    private static final Set<InferenceFieldWithTestMetadata> INDEX_2_EXPECTED_INFERENCE_FIELDS = Set.of(
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, TEXT_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_2, SPARSE_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_3, SPARSE_EMBEDDING_INFERENCE_ID, 1.0f),
        new InferenceFieldWithTestMetadata(INFERENCE_FIELD_4, TEXT_EMBEDDING_INFERENCE_ID, 1.0f)
    );

    // Expected coordination-format result type per inference endpoint
    private static final Map<String, Class<? extends InferenceResults>> ALL_EXPECTED_INFERENCE_RESULTS = Map.of(
        SPARSE_EMBEDDING_INFERENCE_ID,
        TextExpansionResults.class,
        TEXT_EMBEDDING_INFERENCE_ID,
        MlDenseEmbeddingResults.class
    );

    // Guards one-time cluster setup; the SUITE cluster scope reuses the cluster across test methods
    private boolean clusterConfigured = false;

    @Override
    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
        // A trial license is required for inference features
        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
    }

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class);
    }

    @Before
    public void setUpCluster() throws Exception {
        if (clusterConfigured == false) {
            createInferenceEndpoints();
            createTestIndices();
            clusterConfigured = true;
        }
    }

    public void testNullQuery() {
        explicitIndicesAndFieldsTestCase(null);
    }

    public void testNonNullQuery() {
        explicitIndicesAndFieldsTestCase("foo");
    }

    public void testBlankQuery() {
        explicitIndicesAndFieldsTestCase(" ");
    }

    public void testFieldWeight() {
        // With wildcard resolution disabled, only the explicit INFERENCE_FIELD_1 entry (weight 2.0) matches
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                ALL_INDICES,
                Map.of(INFERENCE_FIELD_1, 2.0f, "inference-*", 1.5f, TEXT_FIELD_1, 1.75f),
                false,
                false,
                "foo"
            ),
            Map.of(
                INDEX_1,
                Set.of(new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, SPARSE_EMBEDDING_INFERENCE_ID, 2.0f)),
                INDEX_2,
                Set.of(new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, TEXT_EMBEDDING_INFERENCE_ID, 2.0f))
            ),
            ALL_EXPECTED_INFERENCE_RESULTS
        );
    }

    public void testNoInferenceFields() {
        // Requesting only text fields yields empty per-index matches and no inference calls
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                ALL_INDICES,
                generateDefaultWeightFieldMap(Set.of(TEXT_FIELD_1, TEXT_FIELD_2)),
                false,
                false,
                "foo"
            ),
            Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()),
            Map.of()
        );
    }

    public void testResolveFieldWildcards() {
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), true, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        // Overlapping patterns multiply: "*-field-1" (2.0) and "*-1" (1.75) both hit inference-field-1 -> 3.5
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                ALL_INDICES,
                Map.of("*-field-1", 2.0f, "*-1", 1.75f, "inference-*-3", 2.0f),
                true,
                false,
                "foo"
            ),
            Map.of(
                INDEX_1,
                Set.of(
                    new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, SPARSE_EMBEDDING_INFERENCE_ID, 3.5f),
                    new InferenceFieldWithTestMetadata(INFERENCE_FIELD_3, SPARSE_EMBEDDING_INFERENCE_ID, 2.0f)
                ),
                INDEX_2,
                Set.of(
                    new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, TEXT_EMBEDDING_INFERENCE_ID, 3.5f),
                    new InferenceFieldWithTestMetadata(INFERENCE_FIELD_3, SPARSE_EMBEDDING_INFERENCE_ID, 2.0f)
                )
            ),
            ALL_EXPECTED_INFERENCE_RESULTS
        );
    }

    public void testUseDefaultFields() {
        // INDEX_1 is created with index.query.default_field = "*-field-1^5", so only inference-field-1 (weight 5.0) matches
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of(INDEX_1), Map.of(), true, true, "foo"),
            Map.of(INDEX_1, Set.of(new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, SPARSE_EMBEDDING_INFERENCE_ID, 5.0f))),
            filterExpectedInferenceResults(ALL_EXPECTED_INFERENCE_RESULTS, Set.of(SPARSE_EMBEDDING_INFERENCE_ID))
        );

        // INDEX_2 has no explicit default fields, so the setting default applies and all inference fields match
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of(INDEX_2), Map.of(), true, true, "foo"),
            Map.of(INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );
    }

    public void testMissingIndexName() {
        // Default (strict) indices options fail on a missing index; lenient options skip it
        Set<String> indicesWithIndex1 = Set.of(INDEX_1, "missing-index");
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(indicesWithIndex1, ALL_FIELDS, false, false, "foo"),
            IndexNotFoundException.class,
            e -> assertThat(e.getMessage(), containsString("no such index [missing-index]"))
        );
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(indicesWithIndex1, ALL_FIELDS, false, false, "foo", IndicesOptions.LENIENT_EXPAND_OPEN),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        Set<String> indicesWithoutIndex1 = Set.of("missing-index");
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(indicesWithoutIndex1, ALL_FIELDS, false, false, "foo"),
            IndexNotFoundException.class,
            e -> assertThat(e.getMessage(), containsString("no such index [missing-index]"))
        );
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(indicesWithoutIndex1, ALL_FIELDS, false, false, "foo", IndicesOptions.LENIENT_EXPAND_OPEN),
            Map.of(),
            Map.of()
        );
    }

    public void testMissingFieldName() {
        // Unknown field names (explicit or wildcard) are silently ignored rather than failing the request
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("missing-field")), false, false, "foo"),
            Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()),
            Map.of()
        );

        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("missing-*")), true, false, "foo"),
            Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()),
            Map.of()
        );
    }

    public void testNoIndices() {
        // By default, an empty index set will be interpreted as _all
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of(), ALL_FIELDS, false, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        // We can provide an IndicesOptions that changes this behavior to interpret an empty index set as no indices
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of(), ALL_FIELDS, false, false, "foo", IndicesOptions.STRICT_NO_EXPAND_FORBID_CLOSED),
            Map.of(),
            Map.of()
        );
    }

    public void testAllIndices() {
        // By default, _all expands to all indices
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of("_all"), ALL_FIELDS, false, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        // We can provide an IndicesOptions that changes this behavior to interpret it as no indices
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                Set.of("_all"),
                ALL_FIELDS,
                false,
                false,
                "foo",
                IndicesOptions.STRICT_NO_EXPAND_FORBID_CLOSED
            ),
            Map.of(),
            Map.of()
        );
    }

    public void testIndexAlias() {
        // The alias covers both indices, so results match querying both explicitly
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of(INDEX_ALIAS), ALL_FIELDS, false, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );
    }

    public void testResolveIndexWildcards() {
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of("index-*"), ALL_FIELDS, false, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(Set.of("*-1"), ALL_FIELDS, false, false, "foo"),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS),
            ALL_EXPECTED_INFERENCE_RESULTS
        );

        // With expansion disabled, the wildcard expression is treated as a literal (missing) index name
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(
                Set.of("index-*"),
                ALL_FIELDS,
                false,
                false,
                "foo",
                IndicesOptions.STRICT_NO_EXPAND_FORBID_CLOSED
            ),
            IndexNotFoundException.class,
            e -> assertThat(e.getMessage(), containsString("no such index [index-*]"))
        );
    }

    public void testNoFields() {
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, Map.of(), false, false, "foo"),
            Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()),
            Map.of()
        );
    }

    public void testInvalidRequest() {
        // Asserts that each expected validation message appears in the exception
        final BiConsumer<ActionRequestValidationException, List<String>> validator = (e, l) -> l.forEach(
            s -> assertThat(e.getMessage(), containsString(s))
        );

        assertFailedRequest(
            new GetInferenceFieldsAction.Request(null, Map.of(), false, false, null),
            ActionRequestValidationException.class,
            e -> validator.accept(e, List.of("indices must not be null"))
        );
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(Set.of(), null, false, false, null),
            ActionRequestValidationException.class,
            e -> validator.accept(e, List.of("fields must not be null"))
        );
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(null, null, false, false, null),
            ActionRequestValidationException.class,
            e -> validator.accept(e, List.of("indices must not be null", "fields must not be null"))
        );

        // Map.of rejects null values, so build the map by hand to exercise null-weight validation
        Map<String, Float> fields = new HashMap<>();
        fields.put(INFERENCE_FIELD_1, null);
        assertFailedRequest(
            new GetInferenceFieldsAction.Request(Set.of(), fields, false, false, null),
            ActionRequestValidationException.class,
            e -> validator.accept(e, List.of("weight for field [" + INFERENCE_FIELD_1 + "] must not be null"))
        );
    }

    // Shared scenario for null/blank/non-null queries: a null or blank query must skip inference entirely
    private void explicitIndicesAndFieldsTestCase(String query) {
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, ALL_FIELDS, false, false, query),
            Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS),
            query == null || query.isBlank() ? Map.of() : ALL_EXPECTED_INFERENCE_RESULTS
        );

        Map<String, Class<? extends InferenceResults>> expectedInferenceResultsSparseOnly = filterExpectedInferenceResults(
            ALL_EXPECTED_INFERENCE_RESULTS,
            Set.of(SPARSE_EMBEDDING_INFERENCE_ID)
        );
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                ALL_INDICES,
                generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD_3)),
                false,
                false,
                query
            ),
            Map.of(
                INDEX_1,
                filterExpectedInferenceFieldSet(INDEX_1_EXPECTED_INFERENCE_FIELDS, Set.of(INFERENCE_FIELD_3)),
                INDEX_2,
                filterExpectedInferenceFieldSet(INDEX_2_EXPECTED_INFERENCE_FIELDS, Set.of(INFERENCE_FIELD_3))
            ),
            query == null || query.isBlank() ? Map.of() : expectedInferenceResultsSparseOnly
        );

        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(
                Set.of(INDEX_1),
                generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD_3)),
                false,
                false,
                query
            ),
            Map.of(INDEX_1, filterExpectedInferenceFieldSet(INDEX_1_EXPECTED_INFERENCE_FIELDS, Set.of(INFERENCE_FIELD_3))),
            query == null || query.isBlank() ? Map.of() : expectedInferenceResultsSparseOnly
        );

        // "*" without wildcard resolution matches nothing
        assertSuccessfulRequest(
            new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), false, false, query),
            Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()),
            Map.of()
        );
    }

    private void createInferenceEndpoints() throws IOException {
        createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, SPARSE_EMBEDDING_INFERENCE_ID, SPARSE_EMBEDDING_SERVICE_SETTINGS);
        createInferenceEndpoint(client(), TaskType.TEXT_EMBEDDING, TEXT_EMBEDDING_INFERENCE_ID, TEXT_EMBEDDING_SERVICE_SETTINGS);
    }

    private void createTestIndices() throws IOException {
        // INDEX_1 gets an explicit default-field setting with a weight; INDEX_2 relies on the setting default
        createTestIndex(INDEX_1, List.of("*-field-1^5"));
        createTestIndex(INDEX_2, null);
        assertAcked(
            indicesAdmin().prepareAliases(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT)
                .addAlias(new String[] { INDEX_1, INDEX_2 }, INDEX_ALIAS)
        );
    }

    /**
     * Creates a test index with four semantic_text fields and two text fields.
     *
     * @param indexName     either {@link #INDEX_1} or {@link #INDEX_2}; determines which inference IDs fields 1 and 2 use
     * @param defaultFields optional value for {@code index.query.default_field}; {@code null} leaves the setting unset
     */
    private void createTestIndex(String indexName, @Nullable List<String> defaultFields) throws IOException {
        final String inferenceField1InferenceId = switch (indexName) {
            case INDEX_1 -> SPARSE_EMBEDDING_INFERENCE_ID;
            case INDEX_2 -> TEXT_EMBEDDING_INFERENCE_ID;
            default -> throw new AssertionError("Unhandled index name [" + indexName + "]");
        };
        final String inferenceField2InferenceId = switch (indexName) {
            case INDEX_1 -> TEXT_EMBEDDING_INFERENCE_ID;
            case INDEX_2 -> SPARSE_EMBEDDING_INFERENCE_ID;
            default -> throw new AssertionError("Unhandled index name [" + indexName + "]");
        };

        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");
        addSemanticTextField(INFERENCE_FIELD_1, inferenceField1InferenceId, mapping);
        addSemanticTextField(INFERENCE_FIELD_2, inferenceField2InferenceId, mapping);
        addSemanticTextField(INFERENCE_FIELD_3, SPARSE_EMBEDDING_INFERENCE_ID, mapping);
        addSemanticTextField(INFERENCE_FIELD_4, TEXT_EMBEDDING_INFERENCE_ID, mapping);
        addTextField(TEXT_FIELD_1, mapping);
        addTextField(TEXT_FIELD_2, mapping);
        mapping.endObject().endObject();

        var createIndexRequest = prepareCreate(indexName).setMapping(mapping);
        if (defaultFields != null) {
            Settings settings = Settings.builder().putList(DEFAULT_FIELD_SETTING.getKey(), defaultFields).build();
            createIndexRequest.setSettings(settings);
        }
        assertAcked(createIndexRequest);
    }

    private void addSemanticTextField(String fieldName, String inferenceId, XContentBuilder mapping) throws IOException {
        mapping.startObject(fieldName);
        mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
        mapping.field("inference_id", inferenceId);
        mapping.endObject();
    }

    private void addTextField(String fieldName, XContentBuilder mapping) throws IOException {
        mapping.startObject(fieldName);
        mapping.field("type", TextFieldMapper.CONTENT_TYPE);
        mapping.endObject();
    }

    private static GetInferenceFieldsAction.Response executeRequest(GetInferenceFieldsAction.Request request) {
        return client().execute(GetInferenceFieldsAction.INSTANCE, request).actionGet(TEST_REQUEST_TIMEOUT);
    }

    private static void assertSuccessfulRequest(
        GetInferenceFieldsAction.Request request,
        Map<String, Set<InferenceFieldWithTestMetadata>> expectedInferenceFields,
        Map<String, Class<? extends InferenceResults>> expectedInferenceResults
    ) {
        var response = executeRequest(request);
        assertInferenceFieldsMap(response.getInferenceFieldsMap(), expectedInferenceFields);
        assertInferenceResultsMap(response.getInferenceResultsMap(), expectedInferenceResults);
    }

    private static <T extends Exception> void assertFailedRequest(
        GetInferenceFieldsAction.Request request,
        Class<T> expectedException,
        Consumer<T> exceptionValidator
    ) {
        T exception = assertThrows(expectedException, () -> executeRequest(request));
        exceptionValidator.accept(exception);
    }

    // Asserts an exact (order-insensitive) match between actual and expected per-index inference fields
    static void assertInferenceFieldsMap(
        Map<String, List<GetInferenceFieldsAction.ExtendedInferenceFieldMetadata>> inferenceFieldsMap,
        Map<String, Set<InferenceFieldWithTestMetadata>> expectedInferenceFields
    ) {
        assertThat(inferenceFieldsMap.size(), equalTo(expectedInferenceFields.size()));
        for (var entry : inferenceFieldsMap.entrySet()) {
            String indexName = entry.getKey();
            List<GetInferenceFieldsAction.ExtendedInferenceFieldMetadata> indexInferenceFields = entry.getValue();

            Set<InferenceFieldWithTestMetadata> expectedIndexInferenceFields = expectedInferenceFields.get(indexName);
            assertThat(expectedIndexInferenceFields, notNullValue());

            // Remove each actual field from a copy of the expected set; both leftover and missing entries fail
            Set<InferenceFieldWithTestMetadata> remainingExpectedIndexInferenceFields = new HashSet<>(expectedIndexInferenceFields);
            for (var indexInferenceField : indexInferenceFields) {
                InferenceFieldWithTestMetadata inferenceFieldWithTestMetadata = new InferenceFieldWithTestMetadata(
                    indexInferenceField.inferenceFieldMetadata().getName(),
                    indexInferenceField.inferenceFieldMetadata().getSearchInferenceId(),
                    indexInferenceField.weight()
                );
                assertThat(remainingExpectedIndexInferenceFields.remove(inferenceFieldWithTestMetadata), is(true));
            }
            assertThat(remainingExpectedIndexInferenceFields, empty());
        }
    }

    // Asserts that each inference ID produced a result of the expected coordination-format type
    static void assertInferenceResultsMap(
        Map<String, InferenceResults> inferenceResultsMap,
        Map<String, Class<? extends InferenceResults>> expectedInferenceResults
    ) {
        assertThat(inferenceResultsMap.size(), equalTo(expectedInferenceResults.size()));
        for (var entry : inferenceResultsMap.entrySet()) {
            String inferenceId = entry.getKey();
            InferenceResults inferenceResults = entry.getValue();

            Class<? extends InferenceResults> expectedInferenceResultsClass = expectedInferenceResults.get(inferenceId);
            assertThat(expectedInferenceResultsClass, notNullValue());
            assertThat(inferenceResults, instanceOf(expectedInferenceResultsClass));
        }
    }

    static Map<String, Float> generateDefaultWeightFieldMap(Set<String> fieldList) {
        Map<String, Float> fieldMap = new HashMap<>();
        fieldList.forEach(field -> fieldMap.put(field, 1.0f));
        return fieldMap;
    }

    private static Set<InferenceFieldWithTestMetadata> filterExpectedInferenceFieldSet(
        Set<InferenceFieldWithTestMetadata> inferenceFieldSet,
        Set<String> fieldNames
    ) {
        return inferenceFieldSet.stream().filter(i -> fieldNames.contains(i.field())).collect(Collectors.toSet());
    }

    private static Map<String, Class<? extends InferenceResults>> filterExpectedInferenceResults(
        Map<String, Class<? extends InferenceResults>> expectedInferenceResults,
        Set<String> inferenceIds
    ) {
        return expectedInferenceResults.entrySet()
            .stream()
            .filter(e -> inferenceIds.contains(e.getKey()))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    // Flattened view of an expected inference field match, comparable by value in assertions
    record InferenceFieldWithTestMetadata(String field, String inferenceId, float weight) {}
}
ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class), new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class), - new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class) + new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class), + new ActionHandler(GetInferenceFieldsAction.INSTANCE, TransportGetInferenceFieldsAction.class) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java new file mode 100644 index 0000000000000..a985920924108 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java @@ -0,0 +1,225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ProjectState; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportGetInferenceFieldsAction extends 
HandledTransportAction< + GetInferenceFieldsAction.Request, + GetInferenceFieldsAction.Response> { + + private final TransportService transportService; + private final ClusterService clusterService; + private final ProjectResolver projectResolver; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final Client client; + + @Inject + public TransportGetInferenceFieldsAction( + TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + ProjectResolver projectResolver, + IndexNameExpressionResolver indexNameExpressionResolver, + Client client + ) { + super( + GetInferenceFieldsAction.NAME, + transportService, + actionFilters, + GetInferenceFieldsAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.transportService = transportService; + this.clusterService = clusterService; + this.projectResolver = projectResolver; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.client = client; + } + + @Override + protected void doExecute( + Task task, + GetInferenceFieldsAction.Request request, + ActionListener listener + ) { + final Set indices = request.getIndices(); + final Map fields = request.getFields(); + final boolean resolveWildcards = request.resolveWildcards(); + final boolean useDefaultFields = request.useDefaultFields(); + final String query = request.getQuery(); + final IndicesOptions indicesOptions = request.getIndicesOptions(); + + try { + Map groupedIndices = transportService.getRemoteClusterService() + .groupIndices(indicesOptions, indices.toArray(new String[0]), true); + OriginalIndices localIndices = groupedIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); + if (groupedIndices.isEmpty() == false) { + throw new IllegalArgumentException("GetInferenceFieldsAction does not support remote indices"); + } + + ProjectState projectState = projectResolver.getProjectState(clusterService.state()); + String[] concreteLocalIndices = 
indexNameExpressionResolver.concreteIndexNames(projectState.metadata(), localIndices); + + Map> inferenceFieldsMap = new HashMap<>( + concreteLocalIndices.length + ); + Arrays.stream(concreteLocalIndices).forEach(index -> { + List inferenceFieldMetadataList = getInferenceFieldMetadata( + index, + fields, + resolveWildcards, + useDefaultFields + ); + inferenceFieldsMap.put(index, inferenceFieldMetadataList); + }); + + if (query != null && query.isBlank() == false) { + Set inferenceIds = inferenceFieldsMap.values() + .stream() + .flatMap(List::stream) + .map(eifm -> eifm.inferenceFieldMetadata().getSearchInferenceId()) + .collect(Collectors.toSet()); + + getInferenceResults(query, inferenceIds, inferenceFieldsMap, listener); + } else { + listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of())); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + private List getInferenceFieldMetadata( + String index, + Map fields, + boolean resolveWildcards, + boolean useDefaultFields + ) { + IndexMetadata indexMetadata = projectResolver.getProjectMetadata(clusterService.state()).indices().get(index); + if (indexMetadata == null) { + throw new IndexNotFoundException(index); + } + + Map matchingInferenceFieldMap = indexMetadata.getMatchingInferenceFields( + fields, + resolveWildcards, + useDefaultFields + ); + return matchingInferenceFieldMap.entrySet() + .stream() + .map(e -> new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue())) + .toList(); + } + + private void getInferenceResults( + String query, + Set inferenceIds, + Map> inferenceFieldsMap, + ActionListener listener + ) { + if (inferenceIds.isEmpty()) { + listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of())); + return; + } + + GroupedActionListener> gal = new GroupedActionListener<>( + inferenceIds.size(), + listener.delegateFailureAndWrap((l, c) -> { + Map inferenceResultsMap = new HashMap<>(inferenceIds.size()); + 
c.forEach(t -> inferenceResultsMap.put(t.v1(), t.v2())); + + GetInferenceFieldsAction.Response response = new GetInferenceFieldsAction.Response(inferenceFieldsMap, inferenceResultsMap); + l.onResponse(response); + }) + ); + + List inferenceRequests = inferenceIds.stream() + .map( + i -> new InferenceAction.Request( + TaskType.ANY, + i, + null, + null, + null, + List.of(query), + Map.of(), + InputType.INTERNAL_SEARCH, + null, + false + ) + ) + .toList(); + + inferenceRequests.forEach( + request -> executeAsyncWithOrigin(client, ML_ORIGIN, InferenceAction.INSTANCE, request, gal.delegateFailureAndWrap((l, r) -> { + String inferenceId = request.getInferenceEntityId(); + InferenceResults inferenceResults = validateAndConvertInferenceResults(r.getResults(), inferenceId); + l.onResponse(Tuple.tuple(inferenceId, inferenceResults)); + })) + ); + } + + private static InferenceResults validateAndConvertInferenceResults( + InferenceServiceResults inferenceServiceResults, + String inferenceId + ) { + List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); + if (inferenceResultsList.isEmpty()) { + return new ErrorInferenceResults( + new IllegalArgumentException("No inference results retrieved for inference ID [" + inferenceId + "]") + ); + } else if (inferenceResultsList.size() > 1) { + // We don't chunk queries, so there should always be one inference result. + // Thus, if we receive more than one inference result, it is a server-side error. 
+ return new ErrorInferenceResults( + new IllegalStateException( + inferenceResultsList.size() + " inference results retrieved for inference ID [" + inferenceId + "]" + ) + ); + } + + return inferenceResultsList.getFirst(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 89fbf94f2f0de..8267643108bcf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -38,7 +38,9 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; +import static org.elasticsearch.cluster.metadata.IndexMetadata.getMatchingInferenceFields; import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; @@ -431,30 +433,14 @@ private static Map getInferenceFieldsMap( Map queryFields, boolean resolveWildcards ) { - Map inferenceFieldsToQuery = new HashMap<>(); Map indexInferenceFields = indexMetadataContext.getMappingLookup().inferenceFields(); - for (Map.Entry entry : queryFields.entrySet()) { - String queryField = entry.getKey(); - Float weight = entry.getValue(); - - if (indexInferenceFields.containsKey(queryField)) { - // No wildcards in field name - addToInferenceFieldsMap(inferenceFieldsToQuery, queryField, weight); - continue; - } - if (resolveWildcards) { - if (Regex.isMatchAllPattern(queryField)) { - indexInferenceFields.keySet().forEach(f -> addToInferenceFieldsMap(inferenceFieldsToQuery, f, weight)); - } else if 
(Regex.isSimpleMatchPattern(queryField)) { - indexInferenceFields.keySet() - .stream() - .filter(f -> Regex.simpleMatch(queryField, f)) - .forEach(f -> addToInferenceFieldsMap(inferenceFieldsToQuery, f, weight)); - } - } - } + Map matchingInferenceFields = getMatchingInferenceFields( + indexInferenceFields, + queryFields, + resolveWildcards + ); - return inferenceFieldsToQuery; + return matchingInferenceFields.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().getName(), Map.Entry::getValue)); } private static Map getDefaultFields(Settings settings) { @@ -462,10 +448,6 @@ private static Map getDefaultFields(Settings settings) { return QueryParserHelper.parseFieldsAndWeights(defaultFieldsList); } - private static void addToInferenceFieldsMap(Map inferenceFields, String field, Float weight) { - inferenceFields.compute(field, (k, v) -> v == null ? weight : v * weight); - } - private static void inferenceResultsErrorCheck(Map inferenceResultsMap) { for (var entry : inferenceResultsMap.entrySet()) { String inferenceId = entry.getKey().inferenceId(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java new file mode 100644 index 0000000000000..bb62eb11d6d1e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; + +public class GetInferenceFieldsActionRequestTests extends AbstractBWCWireSerializationTestCase { + @Override + protected Writeable.Reader instanceReader() { + return GetInferenceFieldsAction.Request::new; + } + + @Override + protected GetInferenceFieldsAction.Request createTestInstance() { + return new GetInferenceFieldsAction.Request( + randomIdentifierSet(), + randomFields(), + randomBoolean(), + randomBoolean(), + randomQuery(), + randomIndicesOptions() + ); + } + + @Override + protected GetInferenceFieldsAction.Request mutateInstance(GetInferenceFieldsAction.Request instance) throws IOException { + return switch (between(0, 5)) { + case 0 -> new GetInferenceFieldsAction.Request( + randomValueOtherThan(instance.getIndices(), GetInferenceFieldsActionRequestTests::randomIdentifierSet), + instance.getFields(), + instance.resolveWildcards(), + instance.useDefaultFields(), + instance.getQuery(), + instance.getIndicesOptions() + ); + case 1 -> new GetInferenceFieldsAction.Request( + instance.getIndices(), + randomValueOtherThan(instance.getFields(), GetInferenceFieldsActionRequestTests::randomFields), + instance.resolveWildcards(), + instance.useDefaultFields(), + instance.getQuery(), + instance.getIndicesOptions() + ); + case 2 -> new 
GetInferenceFieldsAction.Request( + instance.getIndices(), + instance.getFields(), + randomValueOtherThan(instance.resolveWildcards(), ESTestCase::randomBoolean), + instance.useDefaultFields(), + instance.getQuery(), + instance.getIndicesOptions() + ); + case 3 -> new GetInferenceFieldsAction.Request( + instance.getIndices(), + instance.getFields(), + instance.resolveWildcards(), + randomValueOtherThan(instance.useDefaultFields(), ESTestCase::randomBoolean), + instance.getQuery(), + instance.getIndicesOptions() + ); + case 4 -> new GetInferenceFieldsAction.Request( + instance.getIndices(), + instance.getFields(), + instance.resolveWildcards(), + instance.useDefaultFields(), + randomValueOtherThan(instance.getQuery(), GetInferenceFieldsActionRequestTests::randomQuery), + instance.getIndicesOptions() + ); + case 5 -> new GetInferenceFieldsAction.Request( + instance.getIndices(), + instance.getFields(), + instance.resolveWildcards(), + instance.useDefaultFields(), + instance.getQuery(), + randomValueOtherThan(instance.getIndicesOptions(), GetInferenceFieldsActionRequestTests::randomIndicesOptions) + ); + default -> throw new AssertionError("Invalid value"); + }; + } + + @Override + protected Collection bwcVersions() { + TransportVersion minVersion = TransportVersion.max(TransportVersion.minimumCompatible(), GET_INFERENCE_FIELDS_ACTION_TV); + return TransportVersionUtils.allReleasedVersions().tailSet(minVersion, true); + } + + @Override + protected GetInferenceFieldsAction.Request mutateInstanceForVersion( + GetInferenceFieldsAction.Request instance, + TransportVersion version + ) { + return instance; + } + + private static Set randomIdentifierSet() { + return randomSet(0, 5, ESTestCase::randomIdentifier); + } + + private static Map randomFields() { + return randomMap(0, 5, () -> Tuple.tuple(randomIdentifier(), randomFloat())); + } + + private static String randomQuery() { + return randomBoolean() ? 
randomAlphaOfLengthBetween(5, 10) : null; + } + + private static IndicesOptions randomIndicesOptions() { + // This isn't an exhaustive list of possible indices options, but there are enough for effective serialization tests. + // Omit IndicesOptions.strictExpandOpen() because it is equal to IndicesOptions#DEFAULT, which we use in the null case. + return switch (between(0, 8)) { + case 0 -> null; + case 1 -> IndicesOptions.strictExpandOpenFailureNoSelectors(); + case 2 -> IndicesOptions.strictExpandOpenAndForbidClosed(); + case 3 -> IndicesOptions.strictExpandOpenAndForbidClosedIgnoreThrottled(); + case 4 -> IndicesOptions.strictExpand(); + case 5 -> IndicesOptions.strictExpandHidden(); + case 6 -> IndicesOptions.strictExpandHiddenNoSelectors(); + case 7 -> IndicesOptions.strictExpandHiddenFailureNoSelectors(); + case 8 -> IndicesOptions.strictNoExpandForbidClosed(); + default -> throw new AssertionError("Invalid value"); + }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java new file mode 100644 index 0000000000000..9b2d4d86ef880 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResultsTests; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests.generateRandomChunkingSettings; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; + +public class GetInferenceFieldsActionResponseTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected Writeable.Reader instanceReader() { + return GetInferenceFieldsAction.Response::new; + } + + @Override + protected GetInferenceFieldsAction.Response createTestInstance() { + return new GetInferenceFieldsAction.Response(randomInferenceFieldsMap(), 
randomInferenceResultsMap()); + } + + @Override + protected GetInferenceFieldsAction.Response mutateInstance(GetInferenceFieldsAction.Response instance) throws IOException { + return switch (between(0, 1)) { + case 0 -> new GetInferenceFieldsAction.Response( + randomValueOtherThan(instance.getInferenceFieldsMap(), GetInferenceFieldsActionResponseTests::randomInferenceFieldsMap), + instance.getInferenceResultsMap() + ); + case 1 -> new GetInferenceFieldsAction.Response( + instance.getInferenceFieldsMap(), + randomValueOtherThan(instance.getInferenceResultsMap(), GetInferenceFieldsActionResponseTests::randomInferenceResultsMap) + ); + default -> throw new AssertionError("Invalid value"); + }; + } + + private static Map> randomInferenceFieldsMap() { + Map> map = new HashMap<>(); + int numIndices = randomIntBetween(0, 5); + for (int i = 0; i < numIndices; i++) { + String indexName = randomIdentifier(); + List fields = new ArrayList<>(); + int numFields = randomIntBetween(0, 5); + for (int j = 0; j < numFields; j++) { + fields.add(randomExtendedInferenceFieldMetadata()); + } + map.put(indexName, fields); + } + return map; + } + + @Override + protected Collection bwcVersions() { + TransportVersion minVersion = TransportVersion.max(TransportVersion.minimumCompatible(), GET_INFERENCE_FIELDS_ACTION_TV); + return TransportVersionUtils.allReleasedVersions().tailSet(minVersion, true); + } + + @Override + protected GetInferenceFieldsAction.Response mutateInstanceForVersion( + GetInferenceFieldsAction.Response instance, + TransportVersion version + ) { + return instance; + } + + private static GetInferenceFieldsAction.ExtendedInferenceFieldMetadata randomExtendedInferenceFieldMetadata() { + return new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(randomInferenceFieldMetadata(), randomFloat()); + } + + private static InferenceFieldMetadata randomInferenceFieldMetadata() { + return new InferenceFieldMetadata( + randomIdentifier(), + randomIdentifier(), + 
randomIdentifier(), + randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new), + generateRandomChunkingSettings() + ); + } + + private static Map randomInferenceResultsMap() { + Map map = new HashMap<>(); + int numResults = randomIntBetween(0, 5); + for (int i = 0; i < numResults; i++) { + String inferenceId = randomIdentifier(); + map.put(inferenceId, randomInferenceResults()); + } + return map; + } + + private static InferenceResults randomInferenceResults() { + return randomFrom( + MlDenseEmbeddingResultsTests.createRandomResults(), + TextExpansionResultsTests.createRandomResults(), + WarningInferenceResultsTests.createRandomResults(), + ErrorInferenceResultsTests.createRandomResults() + ); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java index 50d584320d4f3..a8f349fb1fdca 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/MultiFieldsInnerRetrieverUtils.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; @@ -34,8 +33,10 @@ import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Collectors; import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.cluster.metadata.IndexMetadata.getMatchingInferenceFields; import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; /** @@ -275,27 +276,10 @@ 
private static Map inferenceFieldsAndWeightsForIndex( fieldsAndWeightsToQuery = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator); } - Map inferenceFields = new HashMap<>(); Map indexInferenceFields = indexMetadata.getInferenceFields(); - for (Map.Entry entry : fieldsAndWeightsToQuery.entrySet()) { - String field = entry.getKey(); - Float weight = entry.getValue(); - - if (Regex.isMatchAllPattern(field)) { - indexInferenceFields.keySet().forEach(f -> addToInferenceFieldsMap(inferenceFields, f, weight)); - } else if (Regex.isSimpleMatchPattern(field)) { - indexInferenceFields.keySet() - .stream() - .filter(f -> Regex.simpleMatch(field, f)) - .forEach(f -> addToInferenceFieldsMap(inferenceFields, f, weight)); - } else { - // No wildcards in field name - if (indexInferenceFields.containsKey(field)) { - addToInferenceFieldsMap(inferenceFields, field, weight); - } - } - } - return inferenceFields; + return getMatchingInferenceFields(indexInferenceFields, fieldsAndWeightsToQuery, true).entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey().getName(), Map.Entry::getValue)); } private static Map nonInferenceFieldsAndWeightsForIndex( @@ -365,8 +349,4 @@ private static RetrieverBuilder generateLexicalRetriever( lexicalQueryBuilders.forEach(boolQueryBuilder::should); return new StandardRetrieverBuilder(boolQueryBuilder); } - - private static void addToInferenceFieldsMap(Map inferenceFields, String field, Float weight) { - inferenceFields.compute(field, (k, v) -> v == null ? 
weight : v * weight); - } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index b10112c842de7..dbde7cc6bbb90 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -331,6 +331,7 @@ public class Constants { "cluster:internal/xpack/inference/clear_inference_ccm_cache", "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:internal/xpack/inference/create_endpoints", + "cluster:internal/xpack/inference/fields/get", "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/auditor/reset",