diff --git a/docs/changelog/133860.yaml b/docs/changelog/133860.yaml new file mode 100644 index 0000000000000..f036902602579 --- /dev/null +++ b/docs/changelog/133860.yaml @@ -0,0 +1,6 @@ +pr: 133860 +summary: Cache Inference Endpoints +area: Machine Learning +type: enhancement +issues: + - 133135 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0368965d35303..d7e844e95694c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -335,6 +335,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00); public static final TransportVersion TIME_SERIES_TELEMETRY = def(9_155_0_00); public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00); + public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java deleted file mode 100644 index 7704304b11365..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xcontent.ToXContentObject; - -public interface SerializableStats extends ToXContentObject, Writeable { - -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index 025efa1689ed4..c027dab700672 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.transport.AbstractTransportRequest; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; @@ -28,6 +29,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE; + public class GetInferenceDiagnosticsAction extends ActionType { public static final GetInferenceDiagnosticsAction INSTANCE = new GetInferenceDiagnosticsAction(); @@ -119,14 +122,23 @@ public static class NodeResponse extends BaseNodeResponse implements ToXContentF private static final String EXTERNAL_FIELD = "external"; private static final String EIS_FIELD = "eis_mtls"; private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats"; + static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = 
"inference_endpoint_registry"; private final ConnectionPoolStats externalConnectionPoolStats; private final ConnectionPoolStats eisMtlsConnectionPoolStats; - - public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) { + @Nullable + private final Stats inferenceEndpointRegistryStats; + + public NodeResponse( + DiscoveryNode node, + PoolStats poolStats, + PoolStats eisPoolStats, + @Nullable Stats inferenceEndpointRegistryStats + ) { super(node); externalConnectionPoolStats = ConnectionPoolStats.of(poolStats); eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats); + this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats; } public NodeResponse(StreamInput in) throws IOException { @@ -138,6 +150,9 @@ public NodeResponse(StreamInput in) throws IOException { } else { eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY; } + inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE) + ? in.readOptionalWriteable(Stats::new) + : null; } @Override @@ -148,6 +163,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) { eisMtlsConnectionPoolStats.writeTo(out); } + if (out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) { + out.writeOptionalWriteable(inferenceEndpointRegistryStats); + } } @Override @@ -163,6 +181,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params); } builder.endObject(); + if (inferenceEndpointRegistryStats != null) { + builder.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params); + } return builder; } @@ -172,12 +193,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; NodeResponse response = (NodeResponse) o; return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats) - && Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats); + && Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats) + && Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats); } @Override public int hashCode() { - return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats); + return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats, inferenceEndpointRegistryStats); } ConnectionPoolStats getExternalConnectionPoolStats() { @@ -188,6 +210,10 @@ ConnectionPoolStats getEisMtlsConnectionPoolStats() { return eisMtlsConnectionPoolStats; } + public Stats getInferenceEndpointRegistryStats() { + return inferenceEndpointRegistryStats; + } + static class ConnectionPoolStats implements ToXContentObject, Writeable { private static final String LEASED_CONNECTIONS = "leased_connections"; private static final String PENDING_CONNECTIONS = "pending_connections"; @@ -270,5 +296,35 @@ int getMaxConnections() { return maxConnections; } } + + public record Stats(int entryCount, long hits, long misses, long evictions) implements ToXContentObject, Writeable { + + private static final String NUM_OF_CACHE_ENTRIES = "cache_count"; + private static final String CACHE_HITS = "cache_hits"; + private static final String CACHE_MISSES = "cache_misses"; + private static final String CACHE_EVICTIONS = "cache_evictions"; + + public Stats(StreamInput in) throws IOException { + this(in.readVInt(), 
in.readVLong(), in.readVLong(), in.readVLong()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(entryCount); + out.writeVLong(hits); + out.writeVLong(misses); + out.writeVLong(evictions); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject() + .field(NUM_OF_CACHE_ENTRIES, entryCount) + .field(CACHE_HITS, hits) + .field(CACHE_MISSES, misses) + .field(CACHE_EVICTIONS, evictions) + .endObject(); + } + } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index 3d1cb795e3f3e..020eeff459502 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -26,7 +26,7 @@ public static GetInferenceDiagnosticsAction.NodeResponse createRandom() { var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); - return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats); + return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats, randomCacheStats()); } @Override @@ -45,11 +45,16 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference if (randomBoolean()) { PoolStats mutatedConnPoolStats = mutatePoolStats(instance.getExternalConnectionPoolStats()); PoolStats eisPoolStats = copyPoolStats(instance.getEisMtlsConnectionPoolStats()); - return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), mutatedConnPoolStats, eisPoolStats); + return new GetInferenceDiagnosticsAction.NodeResponse( + instance.getNode(), + mutatedConnPoolStats, + eisPoolStats, + randomCacheStats() + ); } else { PoolStats connPoolStats = copyPoolStats(instance.getExternalConnectionPoolStats()); PoolStats mutatedEisPoolStats = mutatePoolStats(instance.getEisMtlsConnectionPoolStats()); - return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats); + return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats, null); } } @@ -79,24 +84,50 @@ private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.Conne ); } + private static GetInferenceDiagnosticsAction.NodeResponse.Stats randomCacheStats() { + return new GetInferenceDiagnosticsAction.NodeResponse.Stats( + randomInt(), + randomLongBetween(0, Long.MAX_VALUE), + randomLongBetween(0, Long.MAX_VALUE), + randomLongBetween(0, Long.MAX_VALUE) + ); + } + @Override protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion( GetInferenceDiagnosticsAction.NodeResponse instance, TransportVersion version ) { - if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) { - return new GetInferenceDiagnosticsAction.NodeResponse( - instance.getNode(), - new PoolStats( - instance.getExternalConnectionPoolStats().getLeasedConnections(), - instance.getExternalConnectionPoolStats().getPendingConnections(), - 
instance.getExternalConnectionPoolStats().getAvailableConnections(), - instance.getExternalConnectionPoolStats().getMaxConnections() - ), - new PoolStats(0, 0, 0, 0) - ); - } else { + return mutateNodeResponseForVersion(instance, version); + } + + public static GetInferenceDiagnosticsAction.NodeResponse mutateNodeResponseForVersion( + GetInferenceDiagnosticsAction.NodeResponse instance, + TransportVersion version + ) { + if (version.onOrAfter(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) { return instance; } + + var eisMltsConnectionPoolStats = version.onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS) + ? new PoolStats( + instance.getEisMtlsConnectionPoolStats().getLeasedConnections(), + instance.getEisMtlsConnectionPoolStats().getPendingConnections(), + instance.getEisMtlsConnectionPoolStats().getAvailableConnections(), + instance.getEisMtlsConnectionPoolStats().getMaxConnections() + ) + : new PoolStats(0, 0, 0, 0); + + return new GetInferenceDiagnosticsAction.NodeResponse( + instance.getNode(), + new PoolStats( + instance.getExternalConnectionPoolStats().getLeasedConnections(), + instance.getExternalConnectionPoolStats().getPendingConnections(), + instance.getExternalConnectionPoolStats().getAvailableConnections(), + instance.getExternalConnectionPoolStats().getMaxConnections() + ), + eisMltsConnectionPoolStats, + null + ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java index 726015f2156ad..d0f608d55fc36 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java @@ -8,21 +8,23 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.List; import static org.hamcrest.Matchers.is; -public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase { +public class GetInferenceDiagnosticsActionResponseTests extends AbstractBWCWireSerializationTestCase< + GetInferenceDiagnosticsAction.Response> { public static GetInferenceDiagnosticsAction.Response createRandom() { List responses = randomList( @@ -39,7 +41,14 @@ public void testToXContent() throws IOException { var eisPoolStats = new PoolStats(5, 6, 7, 8); var entity = new GetInferenceDiagnosticsAction.Response( ClusterName.DEFAULT, - List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)), + List.of( + new GetInferenceDiagnosticsAction.NodeResponse( + node, + externalPoolStats, + eisPoolStats, + new GetInferenceDiagnosticsAction.NodeResponse.Stats(5, 6, 7, 8) + ) + ), List.of() ); @@ -65,6 +74,12 @@ public void 
testToXContent() throws IOException { "available_connections":7, "max_connections":8 } + }, + "inference_endpoint_registry":{ + "cache_count": 5, + "cache_hits": 6, + "cache_misses": 7, + "cache_evictions": 8 } } }"""))); @@ -88,4 +103,19 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag List.of() ); } + + @Override + protected GetInferenceDiagnosticsAction.Response mutateInstanceForVersion( + GetInferenceDiagnosticsAction.Response instance, + TransportVersion version + ) { + return new GetInferenceDiagnosticsAction.Response( + instance.getClusterName(), + instance.getNodes() + .stream() + .map(nodeResponse -> GetInferenceDiagnosticsActionNodeResponseTests.mutateNodeResponseForVersion(nodeResponse, version)) + .toList(), + instance.failures() + ); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index 9701fff2a5789..48e1a74b7b4a3 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -110,6 +110,8 @@ public void testAttachWithModelId() throws IOException { var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); + deleteModel(inferenceId); + forceStopMlNodeDeployment(deploymentId); } @@ -225,6 +227,7 @@ public void testNumAllocationsIsUpdated() throws IOException { ) ); + deleteModel(inferenceId); forceStopMlNodeDeployment(deploymentId); } @@ -266,6 +269,7 @@ public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOExceptio is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId)) ); + deleteModel(inferenceId); forceStopMlNodeDeployment(deploymentId); } @@ -309,6 +313,8 @@ public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws ) ); + deleteModel(inferenceId); + deleteModel(secondInferenceId); forceStopMlNodeDeployment(deploymentId); } @@ -331,6 +337,7 @@ public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOExcepti ) ); + deleteModel(inferenceId); // Force stop will stop the deployment forceStopMlNodeDeployment(deploymentId); } @@ -358,16 +365,6 @@ private String endpointConfig(String modelId, String deploymentId) { """, modelId, deploymentId); } - private String updatedEndpointConfig(int numAllocations) { - return Strings.format(""" - { - "service_settings": { - "num_allocations": %d - } - } - """, numAllocations); - } - private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException { String endPoint = "/_ml/trained_models/" + modelId @@ -413,16 +410,6 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations return client().performRequest(request); } - private Map updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException { - String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; - - Request request = new Request("POST", endPoint); - request.setJsonEntity(body); - var response = client().performRequest(request); - assertStatusOkOrCreated(response); - return entityAsMap(response); - } - protected void stopMlNodeDeployment(String deploymentId) throws IOException 
{ String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 35b3977b7049c..e7008c2292def 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.cluster.AbstractNamedDiffable; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -31,6 +34,7 @@ import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings; @@ -609,6 +613,24 @@ private static void addInternalNamedWriteables(List AbstractNamedDiffable.readDiffFrom( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in + ) + ) + ); } private static void addChunkingSettingsNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 5e7198d75f4bb..8380ac1d87c37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -104,6 +104,8 @@ import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; @@ -238,7 +240,8 @@ public List getActions() { new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class), new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), - new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, 
TransportGetRerankerWindowSizeAction.class) + new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class), + new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class) ); } @@ -392,6 +395,16 @@ public Collection createComponents(PluginServices services) { // Add binding for interface -> implementation components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator)); + components.add( + new InferenceEndpointRegistry( + services.clusterService(), + settings, + modelRegistry.get(), + serviceRegistry, + services.projectResolver() + ) + ); + return components; } @@ -446,6 +459,13 @@ public List getNamedXContent() { ModelRegistryMetadata::fromXContent ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent + ) + ); return namedXContent; } @@ -541,6 +561,7 @@ public static Set> getInferenceSettings() { settings.add(SKIP_VALIDATE_AND_START); settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.add(INFERENCE_QUERY_TIMEOUT); + settings.addAll(InferenceEndpointRegistry.getSettingsDefinitions()); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); return Collections.unmodifiableSet(settings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 269e0f27fd461..8e34cafa3e878 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -25,7 +25,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -42,7 +41,7 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; @@ -79,7 +78,7 @@ public abstract class BaseTransportInferenceAction { - var serviceName = unparsedModel.service(); + var getModelListener = ActionListener.wrap((Model model) -> { + var serviceName = model.getConfigurations().getService(); try { - validateRequest(request, unparsedModel); + validateRequest(request, model); } catch (Exception e) { - recordRequestDurationMetrics(unparsedModel, timer, e); + recordRequestDurationMetrics(model, timer, e); listener.onFailure(e); return; } @@ -162,15 +161,9 @@ protected void doExecute(Task task, Request request, ActionListener unknownServiceException(serviceName, request.getInferenceEntityId())); validationHelper( - () -> 
request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, - () -> requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType()) - ); - validationHelper( - () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), - () -> createInvalidTaskTypeException(request, unparsedModel) + () -> request.getTaskType().isAnyOrSame(model.getTaskType()) == false, + () -> requestModelTaskTypeMismatchException(requestTaskType, model.getTaskType()) ); + validationHelper(() -> isInvalidTaskTypeForInferenceEndpoint(request, model), () -> createInvalidTaskTypeException(request, model)); } - private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel, String localNodeId) { - var modelTaskType = unparsedModel.taskType(); - + private NodeRoutingDecision determineRouting(String serviceName, Request request, TaskType modelTaskType, String localNodeId) { // Rerouting not supported or request was already rerouted if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false || request.hasBeenRerouted()) { @@ -274,7 +262,7 @@ public InferenceAction.Response read(StreamInput in) throws IOException { ); } - private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model)); metricAttributes.putAll(responseAttributes(unwrapCause(t))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index 1ddfd784676f5..0b11b4c1c69a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.io.IOException; import java.util.List; @@ -34,6 +35,7 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction public record ClientManagers(HttpClientManager externalHttpClientManager, HttpClientManager eisMtlsHttpClientManager) {} private final ClientManagers managers; + private final InferenceEndpointRegistry inferenceEndpointRegistry; @Inject public TransportGetInferenceDiagnosticsAction( @@ -41,7 +43,8 @@ public TransportGetInferenceDiagnosticsAction( ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, - ClientManagers managers + ClientManagers managers, + InferenceEndpointRegistry inferenceEndpointRegistry ) { super( GetInferenceDiagnosticsAction.NAME, @@ -53,6 +56,7 @@ public TransportGetInferenceDiagnosticsAction( ); this.managers = Objects.requireNonNull(managers); + this.inferenceEndpointRegistry = Objects.requireNonNull(inferenceEndpointRegistry); } @Override @@ -79,7 +83,22 @@ protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceD return new 
GetInferenceDiagnosticsAction.NodeResponse( transportService.getLocalNode(), managers.externalHttpClientManager().getPoolStats(), - managers.eisMtlsHttpClientManager().getPoolStats() + managers.eisMtlsHttpClientManager().getPoolStats(), + cacheStats() ); } + + private GetInferenceDiagnosticsAction.NodeResponse.Stats cacheStats() { + if (inferenceEndpointRegistry.cacheEnabled()) { + var stats = inferenceEndpointRegistry.stats(); + return new GetInferenceDiagnosticsAction.NodeResponse.Stats( + inferenceEndpointRegistry.cacheCount(), + stats.getHits(), + stats.getMisses(), + stats.getEvictions() + ); + } else { + return null; + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index f14d679ba7d26..f0fb0ec82757a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; @@ -24,7 +23,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; public class TransportInferenceAction extends BaseTransportInferenceAction { @@ -33,7 +32,7 @@ public TransportInferenceAction( TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -46,7 +45,7 @@ public TransportInferenceAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -58,12 +57,12 @@ public TransportInferenceAction( } @Override - protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, Model model) { return false; } @Override - protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) { + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, Model model) { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index d0eef677ca1d3..4fe5dd3a55a12 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; @@ -29,7 +28,7 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.concurrent.Flow; @@ -40,7 +39,7 @@ public TransportUnifiedCompletionInferenceAction( TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -53,7 +52,7 @@ public TransportUnifiedCompletionInferenceAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -65,15 +64,12 @@ public TransportUnifiedCompletionInferenceAction( } @Override - protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { - return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || unparsedModel.taskType() != TaskType.CHAT_COMPLETION; + protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, Model model) { + return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || model.getTaskType() != TaskType.CHAT_COMPLETION; } @Override - protected ElasticsearchStatusException createInvalidTaskTypeException( - UnifiedCompletionAction.Request request, - UnparsedModel unparsedModel - ) { + protected ElasticsearchStatusException createInvalidTaskTypeException(UnifiedCompletionAction.Request request, Model model) { return new ElasticsearchStatusException( "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", RestStatus.BAD_REQUEST, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java new file mode 100644 index 0000000000000..2ca6a8312dbae --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ -0,0 +1,242 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.registry; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AbstractNamedDiffable; +import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateAckListener; +import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE; + +/** + * Clears the cache in {@link InferenceEndpointRegistry}. + * This uses the cluster state to broadcast the message to all nodes to clear their cache, which has guaranteed delivery. + * There are some edge cases where deletes can come from any node, for example ElasticInferenceServiceAuthorizationHandler and + * SemanticTextIndexOptionsIT will delete endpoints on whatever node is handling the request. So this must use a master node transport + * action so that the cluster updates can invalidate the cache, even though most requests will originate from the master node + * (e.g. when updating and deleting inference endpoints via REST). 
+ */ +public class ClearInferenceEndpointCacheAction extends AcknowledgedTransportMasterNodeAction { + private static final Logger log = LogManager.getLogger(ClearInferenceEndpointCacheAction.class); + private static final String NAME = "cluster:internal/xpack/inference/clear_inference_endpoint_cache"; + public static final ActionType INSTANCE = new ActionType<>(NAME); + private static final String TASK_QUEUE_NAME = "inference-endpoint-cache-management"; + + private final ProjectResolver projectResolver; + private final InferenceEndpointRegistry inferenceEndpointRegistry; + private final MasterServiceTaskQueue taskQueue; + + @Inject + public ClearInferenceEndpointCacheAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + ProjectResolver projectResolver, + InferenceEndpointRegistry inferenceEndpointRegistry + ) { + super( + NAME, + transportService, + clusterService, + threadPool, + actionFilters, + ClearInferenceEndpointCacheAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.projectResolver = projectResolver; + this.inferenceEndpointRegistry = inferenceEndpointRegistry; + this.taskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, new CacheMetadataUpdateTaskExecutor()); + clusterService.addListener( + event -> event.state() + .metadata() + .projects() + .values() + .stream() + .map(ProjectMetadata::id) + .filter(id -> event.customMetadataChanged(id, InvalidateCacheMetadata.NAME)) + .peek(id -> log.trace("Inference endpoint cache on node [{}]", () -> event.state().nodes().getLocalNodeId())) + .forEach(inferenceEndpointRegistry::invalidateAll) + ); + } + + @Override + protected void masterOperation( + Task task, + ClearInferenceEndpointCacheAction.Request request, + ClusterState state, + ActionListener listener + ) { + if (inferenceEndpointRegistry.cacheEnabled()) { + taskQueue.submitTask("invalidateAll", new RefreshCacheMetadataVersionTask(projectResolver.getProjectId(), listener), null); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } + } + + @Override + protected ClusterBlockException checkBlock(ClearInferenceEndpointCacheAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(projectResolver.getProjectId(), ClusterBlockLevel.METADATA_WRITE); + } + + public static class Request extends AcknowledgedRequest { + protected Request() { + super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + } + + protected Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public int hashCode() { + return Objects.hashCode(ackTimeout()); + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof ClearInferenceEndpointCacheAction.Request that && Objects.equals(that.ackTimeout(), ackTimeout()); + } + } + + public static class InvalidateCacheMetadata extends AbstractNamedDiffable implements Metadata.ProjectCustom { + public static final String NAME = "inference-endpoint-cache-metadata"; + private static final InvalidateCacheMetadata EMPTY = new InvalidateCacheMetadata(0L); + private static final ParseField VERSION_FIELD = new ParseField("version"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> new InvalidateCacheMetadata((long) args[0]) + ); + + static { + PARSER.declareLong(ConstructingObjectParser.constructorArg(), 
VERSION_FIELD); + } + + public static InvalidateCacheMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static InvalidateCacheMetadata fromMetadata(ProjectMetadata projectMetadata) { + InvalidateCacheMetadata metadata = projectMetadata.custom(NAME); + return metadata == null ? EMPTY : metadata; + } + + private final long version; + + private InvalidateCacheMetadata(long version) { + this.version = version; + } + + public InvalidateCacheMetadata(StreamInput in) throws IOException { + this(in.readVLong()); + } + + public InvalidateCacheMetadata bumpVersion() { + return new InvalidateCacheMetadata(version < Long.MAX_VALUE ? version + 1 : 1L); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_ENDPOINT_CACHE; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(version); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params ignored) { + return Iterators.single(((builder, params) -> builder.field(VERSION_FIELD.getPreferredName(), version))); + } + + @Override + public int hashCode() { + return Objects.hashCode(version); + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof InvalidateCacheMetadata that && that.version == this.version; + } + } + + private static class RefreshCacheMetadataVersionTask extends AckedBatchedClusterStateUpdateTask { + private final ProjectId projectId; + + private RefreshCacheMetadataVersionTask(ProjectId projectId, ActionListener listener) { + super(TimeValue.THIRTY_SECONDS, listener); + this.projectId = projectId; + } + } + + private static class CacheMetadataUpdateTaskExecutor extends SimpleBatchedAckListenerTaskExecutor { + @Override + public Tuple executeTask(RefreshCacheMetadataVersionTask task, ClusterState clusterState) { + var projectMetadata = clusterState.metadata().getProject(task.projectId); + var currentMetadata = InvalidateCacheMetadata.fromMetadata(projectMetadata); + var updatedMetadata = currentMetadata.bumpVersion(); + var newProjectMetadata = ProjectMetadata.builder(projectMetadata).putCustom(InvalidateCacheMetadata.NAME, updatedMetadata); + return new Tuple<>(ClusterState.builder(clusterState).putProjectMetadata(newProjectMetadata).build(), task); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java new file mode 100644 index 0000000000000..46d93e8d404b7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -0,0 +1,149 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.registry; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; + +import java.util.Collection; +import java.util.List; + +/** + * A registry that assembles and caches Inference Endpoints, {@link Model}, for reuse. + * Models are high read and minimally written, where changes only occur during updates and deletes. + * The cache is invalidated via the {@link ClearInferenceEndpointCacheAction} so that every node gets the invalidation + * message. + */ +public class InferenceEndpointRegistry { + + private static final Setting INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting( + "xpack.inference.endpoint.cache.enabled", + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final Setting INFERENCE_ENDPOINT_CACHE_WEIGHT = Setting.intSetting( + "xpack.inference.endpoint.cache.weight", + 25, + Setting.Property.NodeScope + ); + + private static final Setting INFERENCE_ENDPOINT_CACHE_EXPIRY = Setting.timeSetting( + "xpack.inference.endpoint.cache.expiry_time", + TimeValue.timeValueMinutes(15), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueHours(1), + Setting.Property.NodeScope + ); + + public static Collection> getSettingsDefinitions() { + return List.of(INFERENCE_ENDPOINT_CACHE_ENABLED, INFERENCE_ENDPOINT_CACHE_WEIGHT, INFERENCE_ENDPOINT_CACHE_EXPIRY); + } + + private static final Logger log = LogManager.getLogger(InferenceEndpointRegistry.class); + private static final Cache.Stats EMPTY = new Cache.Stats(0, 0, 0); + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final ProjectResolver projectResolver; + private final Cache cache; + private volatile boolean cacheEnabled; + + public InferenceEndpointRegistry( + ClusterService clusterService, + Settings settings, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + ProjectResolver projectResolver + ) { + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.projectResolver = projectResolver; + this.cache = CacheBuilder.builder() + .setMaximumWeight(INFERENCE_ENDPOINT_CACHE_WEIGHT.get(settings)) + .setExpireAfterWrite(INFERENCE_ENDPOINT_CACHE_EXPIRY.get(settings)) + .build(); + this.cacheEnabled = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings); + + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabled = enabled); + } + + public void getEndpoint(String inferenceEntityId, ActionListener listener) { + var key = new InferenceIdAndProject(inferenceEntityId, projectResolver.getProjectId()); + var cachedModel = cacheEnabled ? 
cache.get(key) : null; + if (cachedModel != null) { + log.trace("Retrieved [{}] from cache.", inferenceEntityId); + listener.onResponse(cachedModel); + } else { + loadFromIndex(key, listener); + } + } + + void invalidateAll(ProjectId projectId) { + if (cacheEnabled) { + var cacheKeys = cache.keys().iterator(); + while (cacheKeys.hasNext()) { + if (cacheKeys.next().projectId.equals(projectId)) { + cacheKeys.remove(); + } + } + } + } + + private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener listener) { + modelRegistry.getModelWithSecrets(idAndProject.inferenceEntityId(), listener.delegateFailureAndWrap((l, unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()) + .orElseThrow( + () -> new ResourceNotFoundException( + "Unknown service [{}] for model [{}]", + unparsedModel.service(), + idAndProject.inferenceEntityId() + ) + ); + + var model = service.parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + + if (cacheEnabled) { + cache.put(idAndProject, model); + } + l.onResponse(model); + })); + } + + public Cache.Stats stats() { + return cacheEnabled ? cache.stats() : EMPTY; + } + + public int cacheCount() { + return cacheEnabled ? cache.count() : 0; + } + + public boolean cacheEnabled() { + return cacheEnabled; + } + + private record InferenceIdAndProject(String inferenceEntityId, ProjectId projectId) {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index fe7c4a9395cd1..7cd1cf5999d11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -614,6 +614,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi } else { // since updating the secrets was successful, we can remove the lock and respond to the final listener preventDeletionLock.remove(inferenceEntityId); + refreshInferenceEndpointCache(); finalListener.onResponse(true); } }).andThen((subListener, configResponse) -> { @@ -844,7 +845,10 @@ private void deleteModels(Set inferenceEntityIds, boolean updateClusterS client.execute( DeleteByQueryAction.INSTANCE, request, - getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener) + ActionListener.runAfter( + getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener), + this::refreshInferenceEndpointCache + ) ); } @@ -899,6 +903,17 @@ public void onFailure(Exception exc) { }; } + private void refreshInferenceEndpointCache() { + client.execute( + ClearInferenceEndpointCacheAction.INSTANCE, + new ClearInferenceEndpointCacheAction.Request(), + ActionListener.wrap( + ignored -> logger.debug("Successfully refreshed inference endpoint cache."), + e -> logger.atDebug().withThrowable(e).log("Failed to refresh inference endpoint cache.") + ) + ); + } + private static DeleteByQueryRequest createDeleteRequest(Set inferenceEntityIds) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 812cd1e3c6d7f..47053b7cbe5eb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -20,7 +20,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; @@ -34,11 +33,10 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.junit.Before; import org.mockito.ArgumentCaptor; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Flow; @@ -59,7 +57,7 @@ public abstract class BaseTransportInferenceActionTestCase extends ESTestCase { private MockLicenseState licenseState; - private ModelRegistry modelRegistry; + private InferenceEndpointRegistry inferenceEndpointRegistry; private StreamingTaskManager streamingTaskManager; private BaseTransportInferenceAction action; private ThreadPool threadPool; @@ -87,7 +85,7 @@ public void setUp() throws Exception { transportService = mock(); inferenceServiceRateLimitCalculator = mock(); licenseState = mock(); - modelRegistry = mock(); + inferenceEndpointRegistry = mock(); serviceRegistry = mock(); inferenceStats = InferenceStatsTests.mockInferenceStats(); streamingTaskManager = mock(); @@ -96,7 +94,7 @@ public void setUp() throws Exception { transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -113,7 +111,7 @@ protected abstract BaseTransportInferenceAction createAction( TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -132,7 +130,7 @@ public void testMetricsAfterModelRegistryError() { ActionListener listener = ans.getArgument(1); listener.onFailure(expectedException); return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); + }).when(inferenceEndpointRegistry).getEndpoint(any(), any()); doExecute(taskType); @@ -168,7 +166,7 @@ public void onFailure(Exception e) {} } public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); + mockInferenceEndpointRegistry(taskType); when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); @@ -190,19 +188,18 @@ public void testMetricsAfterMissingService() { })); } - protected void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, 
Map.of(), Map.of()); + protected void mockInferenceEndpointRegistry(TaskType expectedTaskType) { doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); + ActionListener listener = ans.getArgument(1); + listener.onResponse(mockModel(expectedTaskType)); return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); + }).when(inferenceEndpointRegistry).getEndpoint(any(), any()); } public void testMetricsAfterUnknownTaskType() { var modelTaskType = TaskType.RERANK; var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -363,7 +360,7 @@ public void testProductUseCaseHeaderPresentInThreadContextIfPresent() { when(threadPool.getThreadContext()).thenReturn(threadContext); - mockModelRegistry(taskType); + mockInferenceEndpointRegistry(taskType); mockService(listener -> listener.onResponse(mock())); Request request = createRequest(); @@ -431,30 +428,25 @@ protected void mockService( listenerAction.accept(ans.getArgument(3)); return null; }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); - mockModelAndServiceRegistry(service); + mockInferenceEndpointRegistry(taskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); } protected Model mockModel() { + return mockModel(taskType); + } + + protected Model mockModel(TaskType expectedTaskType) { Model model = mock(); ModelConfigurations modelConfigurations = mock(); when(modelConfigurations.getService()).thenReturn(serviceId); when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); + when(model.getTaskType()).thenReturn(expectedTaskType); when(model.getServiceSettings()).thenReturn(mock()); + when(model.getInferenceEntityId()).thenReturn(inferenceId); return model; } - protected void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); - } - protected void mockValidLicenseState() { when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 547078d93acc4..dd0a1b952233b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimitAssignment; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.List; @@ -49,7 +49,7 @@ protected BaseTransportInferenceAction 
createAction( TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -61,7 +61,7 @@ protected BaseTransportInferenceAction createAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 9e6f4a6260936..0b05509acaf8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.Optional; @@ -45,7 +45,7 @@ protected BaseTransportInferenceAction createAc TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -57,7 +57,7 @@ protected BaseTransportInferenceAction createAc transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -75,7 +75,7 @@ protected UnifiedCompletionAction.Request createRequest() { public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { var modelTaskType = TaskType.TEXT_EMBEDDING; var requestTaskType = TaskType.TEXT_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -100,7 +100,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInfe public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_ModelIsTextEmbedding() { var modelTaskType = TaskType.ANY; var requestTaskType = TaskType.TEXT_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -123,7 +123,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_Model } public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { - mockModelRegistry(TaskType.COMPLETION); + mockInferenceEndpointRegistry(TaskType.COMPLETION); mockService(listener -> listener.onResponse(mock())); var listener = doExecute(TaskType.ANY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java new file mode 100644 index 0000000000000..f5e134c7089cb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class ClearInferenceEndpointCacheActionTests extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private static final String INFERENCE_ENDPOINT_ID = "1"; + + @Override + protected Collection> getPlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + public void testCacheEviction() throws Exception { + storeGoodEndpoint(); + invokeEndpoint(); + + var stats = cacheStats(); + assertThat(stats.entryCount(), equalTo(1)); + assertThat(stats.hits(), equalTo(0L)); + assertThat(stats.misses(), equalTo(1L)); + assertThat(stats.evictions(), equalTo(0L)); + + var listener = new PlainActionFuture(); + clusterAdmin().execute(ClearInferenceEndpointCacheAction.INSTANCE, new ClearInferenceEndpointCacheAction.Request(), listener); + assertTrue(listener.actionGet(TIMEOUT).isAcknowledged()); + + assertBusy(() -> { + var nextStats = cacheStats(); + assertThat(nextStats.entryCount(), equalTo(0)); + assertThat(nextStats.hits(), equalTo(0L)); + assertThat(nextStats.misses(), equalTo(1L)); + assertThat(nextStats.evictions(), equalTo(1L)); + }, 10, TimeUnit.SECONDS); + + invokeEndpoint(); + stats = cacheStats(); + assertThat(stats.entryCount(), equalTo(1)); + assertThat(stats.hits(), equalTo(0L)); + assertThat(stats.misses(), equalTo(2L)); + assertThat(stats.evictions(), equalTo(1L)); + } + + private void storeGoodEndpoint() throws IOException { + 
final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", TestSparseInferenceServiceExtension.TestInferenceService.NAME); + builder.field("service_settings", Map.of("model", "model", "api_key", "1234")); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + var request = new PutInferenceModelAction.Request( + TaskType.SPARSE_EMBEDDING, + INFERENCE_ENDPOINT_ID, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + client().execute(PutInferenceModelAction.INSTANCE, request).actionGet(TIMEOUT); + } + + private void invokeEndpoint() { + client().execute( + InferenceAction.INSTANCE, + new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + INFERENCE_ENDPOINT_ID, + null, + null, + null, + List.of("hello"), + null, + InputType.INTERNAL_SEARCH, + TIMEOUT, + false + ) + ).actionGet(TIMEOUT); + } + + private GetInferenceDiagnosticsAction.NodeResponse.Stats cacheStats() { + var diagnostics = client().execute(GetInferenceDiagnosticsAction.INSTANCE, new GetInferenceDiagnosticsAction.Request()) + .actionGet(TIMEOUT); + + assertThat(diagnostics.getNodes(), hasSize(1)); + return diagnostics.getNodes().getFirst().getInferenceEndpointRegistryStats(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java new file mode 100644 index 0000000000000..b172f0e264c79 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.junit.Before; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class InferenceEndpointRegistryTests extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + + private InferenceEndpointRegistry inferenceEndpointRegistry; + private ModelRegistry registry; + + @Override + protected Collection> getPlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Before + public void createComponents() { + inferenceEndpointRegistry = node().injector().getInstance(InferenceEndpointRegistry.class); + registry = node().injector().getInstance(ModelRegistry.class); + } + + public void testGetThrowsResourceNotFoundWhenNoHitsReturned() { + assertThat( + getEndpointException("this is not found", ResourceNotFoundException.class).getMessage(), + is("Inference endpoint not found [this is not found]") + ); + } + + private Exception getEndpointException(String id, Class expectedExceptionClass) { + var listener = new PlainActionFuture(); + inferenceEndpointRegistry.getEndpoint(id, listener); + return expectThrows(expectedExceptionClass, () -> listener.actionGet(TIMEOUT)); + } + + public void testGetModel() { + var expectedEndpoint = storeWorkingEndpoint("1"); + var actualEndpoint = getEndpoint("1"); + assertThat(actualEndpoint, equalTo(expectedEndpoint)); + assertThat(getEndpoint("1"), sameInstance(actualEndpoint)); + } + + private Model storeWorkingEndpoint(String id) { + var expectedEndpoint = new AbstractTestInferenceService.TestServiceModel( + id, + TaskType.SPARSE_EMBEDDING, + "test_service", + new TestSparseInferenceServiceExtension.TestServiceSettings("model", null, false), + new AbstractTestInferenceService.TestTaskSettings(randomInt(3)), + new AbstractTestInferenceService.TestSecretSettings("secret") + ); + assertStoreModel(registry, expectedEndpoint); + return expectedEndpoint; + } + + private Model getEndpoint(String id) { + var listener = new PlainActionFuture(); + inferenceEndpointRegistry.getEndpoint(id, listener); + return listener.actionGet(TIMEOUT); + } + + public void testGetModelWithUnknownService() { + var id = "ahhhh"; + var expectedEndpoint = new AbstractTestInferenceService.TestServiceModel( + id, + TaskType.SPARSE_EMBEDDING, + "hello", + new TestSparseInferenceServiceExtension.TestServiceSettings("model", null, false), + new AbstractTestInferenceService.TestTaskSettings(randomInt(3)), + new AbstractTestInferenceService.TestSecretSettings("secret") + ); + assertStoreModel(registry, expectedEndpoint); + + assertThat( + getEndpointException(id, ResourceNotFoundException.class).getMessage(), + equalTo("Unknown 
service [hello] for model [ahhhh]") + ); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index 86aadcb0ec1d4..bfce552b094f6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.AbstractNamedDiffable; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.NamedDiff; @@ -99,6 +100,7 @@ import org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.ilm.IndexLifecycle; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason; @@ -436,6 +438,24 @@ protected void assertClusterRoundTrip() throws IOException { new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new) ); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); + entries.add( + new NamedWriteableRegistry.Entry( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::new + ) + ); + entries.add( + new NamedWriteableRegistry.Entry( + NamedDiff.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in -> AbstractNamedDiffable.readDiffFrom( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in + ) + ) + ); // Retrieve the cluster state from a random node, and serialize and deserialize it. final ClusterStateResponse clusterStateResponse = client().admin() diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index da9a81898de60..f2634a19d1068 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,6 +174,7 @@ public class Constants { "cluster:admin/xpack/enrich/get", "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", + "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", "cluster:admin/xpack/inference/update",
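The integration tests above pin down a specific accounting for the new endpoint cache: the first lookup of an inference id is recorded as a miss, repeated lookups return the same cached Model instance, and clearing the cache records one eviction per removed entry before the next lookup misses again. The sketch below is illustrative only and uses hypothetical names (CachingEndpointLookup, getOrLoad, clear); it is not the PR's InferenceEndpointRegistry implementation, which is not shown in this excerpt, but it models the hit/miss/eviction bookkeeping those assertions rely on.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

// Hypothetical stand-in for a cache-backed endpoint registry; it models only the
// hit/miss/eviction accounting asserted in ClearInferenceEndpointCacheActionTests.
final class CachingEndpointLookup<T> {
    private final Map<String, T> cache = new ConcurrentHashMap<>();
    private final AtomicLong hits = new AtomicLong();
    private final AtomicLong misses = new AtomicLong();
    private final AtomicLong evictions = new AtomicLong();

    // First lookup of an id loads the value and counts a miss; later lookups count hits
    // and return the same cached instance (compare sameInstance() in InferenceEndpointRegistryTests).
    T getOrLoad(String inferenceEntityId, Function<String, T> loader) {
        T cached = cache.get(inferenceEntityId);
        if (cached != null) {
            hits.incrementAndGet();
            return cached;
        }
        misses.incrementAndGet();
        return cache.computeIfAbsent(inferenceEntityId, loader);
    }

    // Clearing the cache counts one eviction per removed entry, which is what the
    // clear-cache test expects after the acknowledged cluster action.
    void clear() {
        for (String key : cache.keySet()) {
            if (cache.remove(key) != null) {
                evictions.incrementAndGet();
            }
        }
    }

    int entryCount() { return cache.size(); }
    long hits() { return hits.get(); }
    long misses() { return misses.get(); }
    long evictions() { return evictions.get(); }

    public static void main(String[] args) {
        var lookup = new CachingEndpointLookup<String>();
        lookup.getOrLoad("1", id -> "endpoint-" + id);  // miss: loads and caches the endpoint
        lookup.getOrLoad("1", id -> "endpoint-" + id);  // hit: returns the cached instance
        lookup.clear();                                 // one eviction for the single cached entry
        lookup.getOrLoad("1", id -> "endpoint-" + id);  // miss again after the clear
        System.out.printf("count=%d hits=%d misses=%d evictions=%d%n",
            lookup.entryCount(), lookup.hits(), lookup.misses(), lookup.evictions());
        // Prints: count=1 hits=1 misses=2 evictions=1
    }
}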