From 25075fac7f7b394834ce7a93dd1387aed45013c1 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 28 Aug 2025 08:26:50 -0400 Subject: [PATCH 01/10] [ML] Cache Inference Endpoints Maintain parsed Inference Endpoints in memory for reuse. Endpoints are cached on first access and expire after write. This removes search pressure during inference, bypassing search requests to system indices for repeated model access. When any endpoint is updated or deleted, the whole cache is invalidated and must be reloaded. Cache can be configured with three settings: - `xpack.inference.cache.enabled` enables or disables the cache (default enabled). - `xpack.inference.cache.weight` controls how many endpoints can live in the cache (default 25). - `xpack.inference.cache.expiry_time` controls how long endpoints live in the cache, measured from when they are first accessed (default 15 minutes, minimum 1 minute, maximum 1 hour). Resolve #133135 --- .../core/inference/SerializableStats.java | 6 +- .../action/GetInferenceDiagnosticsAction.java | 27 +- ...nceDiagnosticsActionNodeResponseTests.java | 58 ++++- ...ferenceDiagnosticsActionResponseTests.java | 16 +- .../InferenceNamedWriteablesProvider.java | 31 +++ .../xpack/inference/InferencePlugin.java | 46 +++- .../action/BaseTransportInferenceAction.java | 50 ++-- ...ransportGetInferenceDiagnosticsAction.java | 12 +- .../action/TransportInferenceAction.java | 11 +- ...sportUnifiedCompletionInferenceAction.java | 16 +- .../ClearInferenceEndpointCacheAction.java | 237 ++++++++++++++++++ .../registry/InferenceEndpointRegistry.java | 170 +++++++++++++ .../inference/registry/ModelRegistry.java | 17 +- .../BaseTransportInferenceActionTestCase.java | 50 ++-- .../action/TransportInferenceActionTests.java | 6 +- ...TransportUnifiedCompletionActionTests.java | 12 +- ...learInferenceEndpointCacheActionTests.java | 127 ++++++++++ .../InferenceEndpointRegistryTests.java | 104 ++++++++ 18 files changed, 893 insertions(+), 103 deletions(-) create mode 
100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java index 7704304b11365..3aed4962dce11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java @@ -7,9 +7,7 @@ package org.elasticsearch.xpack.core.inference; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; -public interface SerializableStats extends ToXContentObject, Writeable { - -} +public interface SerializableStats extends ToXContentObject, NamedWriteable {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index afb59f8d4c843..f4e807f4f8430 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import 
org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.nodes.BaseNodeResponse; @@ -18,10 +19,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.transport.AbstractTransportRequest; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.SerializableStats; import java.io.IOException; import java.util.List; @@ -116,29 +119,42 @@ public int hashCode() { public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment { static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats"; + static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = "inference_endpoint_registry"; private final ConnectionPoolStats connectionPoolStats; + @Nullable + private final SerializableStats inferenceEndpointRegistryStats; - public NodeResponse(DiscoveryNode node, PoolStats poolStats) { + public NodeResponse(DiscoveryNode node, PoolStats poolStats, SerializableStats inferenceEndpointRegistryStats) { super(node); connectionPoolStats = ConnectionPoolStats.of(poolStats); + this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats; } public NodeResponse(StreamInput in) throws IOException { super(in); connectionPoolStats = new ConnectionPoolStats(in); + inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(TransportVersion.current()) + ? 
in.readOptionalNamedWriteable(SerializableStats.class) + : null; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); connectionPoolStats.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersion.current())) { + out.writeOptionalNamedWriteable(inferenceEndpointRegistryStats); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params); + if (inferenceEndpointRegistryStats != null) { + builder.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params); + } return builder; } @@ -147,18 +163,23 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; NodeResponse response = (NodeResponse) o; - return Objects.equals(connectionPoolStats, response.connectionPoolStats); + return Objects.equals(connectionPoolStats, response.connectionPoolStats) + && Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats); } @Override public int hashCode() { - return Objects.hash(connectionPoolStats); + return Objects.hash(connectionPoolStats, inferenceEndpointRegistryStats); } ConnectionPoolStats getConnectionPoolStats() { return connectionPoolStats; } + public SerializableStats getInferenceEndpointRegistryStats() { + return inferenceEndpointRegistryStats; + } + static class ConnectionPoolStats implements ToXContentObject, Writeable { static final String LEASED_CONNECTIONS = "leased_connections"; static final String PENDING_CONNECTIONS = "pending_connections"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index a21354eb5a73d..6c2f5539634f5 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -11,11 +11,17 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.SerializableStats; import java.io.IOException; import java.io.UnsupportedEncodingException; +import java.util.List; public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase< GetInferenceDiagnosticsAction.NodeResponse> { @@ -23,7 +29,18 @@ public static GetInferenceDiagnosticsAction.NodeResponse createRandom() { DiscoveryNode node = DiscoveryNodeUtils.create("id"); var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); - return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats); + return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats, new TestStats(randomInt())); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return registryWithTestStats(); + } + + public static NamedWriteableRegistry registryWithTestStats() { + return new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(SerializableStats.class, TestStats.NAME, TestStats::new)) + ); } @Override @@ -50,7 +67,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference 
connPoolStats.getPendingConnections(), connPoolStats.getAvailableConnections(), connPoolStats.getMaxConnections() - ) + ), + randomTestStats() ); case 1 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -59,7 +77,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference randomInt(), connPoolStats.getAvailableConnections(), connPoolStats.getMaxConnections() - ) + ), + randomTestStats() ); case 2 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -68,7 +87,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference connPoolStats.getPendingConnections(), randomInt(), connPoolStats.getMaxConnections() - ) + ), + randomTestStats() ); case 3 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -77,9 +97,37 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference connPoolStats.getPendingConnections(), connPoolStats.getAvailableConnections(), randomInt() - ) + ), + randomTestStats() ); default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select)); }; } + + public static SerializableStats randomTestStats() { + return new TestStats(randomInt()); + } + + public record TestStats(int count) implements SerializableStats { + public static final String NAME = "test_stats"; + + public TestStats(StreamInput in) throws IOException { + this(in.readInt()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(count); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("count", count).endObject(); + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java index e3eb42efdc791..a9f6df035af28 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java @@ -10,6 +10,7 @@ import org.apache.http.pool.PoolStats; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; @@ -36,7 +37,13 @@ public void testToXContent() throws IOException { var poolStats = new PoolStats(1, 2, 3, 4); var entity = new GetInferenceDiagnosticsAction.Response( ClusterName.DEFAULT, - List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)), + List.of( + new GetInferenceDiagnosticsAction.NodeResponse( + node, + poolStats, + new GetInferenceDiagnosticsActionNodeResponseTests.TestStats(5) + ) + ), List.of() ); @@ -46,7 +53,7 @@ public void testToXContent() throws IOException { assertThat(xContentResult, CoreMatchers.is(""" {"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """ - "max_connections":4}}}""")); + "max_connections":4},"inference_endpoint_registry":{"count":5}}}""")); } @Override @@ -67,4 +74,9 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag List.of() ); } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return GetInferenceDiagnosticsActionNodeResponseTests.registryWithTestStats(); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f8fb375022abb..d2cdf611dde8e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.cluster.AbstractNamedDiffable; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -17,6 +20,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.inference.SerializableStats; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -31,6 +35,8 @@ import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings; @@ -600,6 +606,31 @@ private static void addInternalNamedWriteables(List AbstractNamedDiffable.readDiffFrom( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in + ) + ) + ); } private static void addChunkingSettingsNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2f3bb8dbb5136..5b4bb001dffa9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -104,6 +104,8 @@ import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; @@ -199,6 +201,27 @@ public class InferencePlugin extends Plugin License.OperationMode.ENTERPRISE ); + public static final Setting INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting( + "xpack.inference.cache.enabled", + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting INFERENCE_ENDPOINT_CACHE_WEIGHT = Setting.intSetting( + "xpack.inference.cache.weight", + 25, + Setting.Property.NodeScope + ); + + public 
static final Setting INFERENCE_ENDPOINT_CACHE_EXPIRY = Setting.timeSetting( + "xpack.inference.cache.expiry_time", + TimeValue.timeValueMinutes(15), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueHours(1), + Setting.Property.NodeScope + ); + public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case"; public static final String NAME = "inference"; @@ -237,7 +260,8 @@ public List getActions() { new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class), new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), - new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class) + new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class), + new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class) ); } @@ -389,6 +413,16 @@ public Collection createComponents(PluginServices services) { // Add binding for interface -> implementation components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator)); + components.add( + new InferenceEndpointRegistry( + services.clusterService(), + settings, + modelRegistry.get(), + serviceRegistry, + services.projectResolver() + ) + ); + return components; } @@ -443,6 +477,13 @@ public List getNamedXContent() { ModelRegistryMetadata::fromXContent ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent + ) + ); return namedXContent; } @@ -527,6 +568,9 @@ public static Set> getInferenceSettings() { settings.add(SKIP_VALIDATE_AND_START); 
settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.add(INFERENCE_QUERY_TIMEOUT); + settings.add(INFERENCE_ENDPOINT_CACHE_ENABLED); + settings.add(INFERENCE_ENDPOINT_CACHE_EXPIRY); + settings.add(INFERENCE_ENDPOINT_CACHE_WEIGHT); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); return Collections.unmodifiableSet(settings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 269e0f27fd461..8e34cafa3e878 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -25,7 +25,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -42,7 +41,7 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; @@ -79,7 +78,7 @@ public abstract class BaseTransportInferenceAction { - var serviceName = unparsedModel.service(); + var getModelListener = ActionListener.wrap((Model model) -> { + var serviceName = 
model.getConfigurations().getService(); try { - validateRequest(request, unparsedModel); + validateRequest(request, model); } catch (Exception e) { - recordRequestDurationMetrics(unparsedModel, timer, e); + recordRequestDurationMetrics(model, timer, e); listener.onFailure(e); return; } @@ -162,15 +161,9 @@ protected void doExecute(Task task, Request request, ActionListener unknownServiceException(serviceName, request.getInferenceEntityId())); validationHelper( - () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, - () -> requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType()) - ); - validationHelper( - () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), - () -> createInvalidTaskTypeException(request, unparsedModel) + () -> request.getTaskType().isAnyOrSame(model.getTaskType()) == false, + () -> requestModelTaskTypeMismatchException(requestTaskType, model.getTaskType()) ); + validationHelper(() -> isInvalidTaskTypeForInferenceEndpoint(request, model), () -> createInvalidTaskTypeException(request, model)); } - private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel, String localNodeId) { - var modelTaskType = unparsedModel.taskType(); - + private NodeRoutingDecision determineRouting(String serviceName, Request request, TaskType modelTaskType, String localNodeId) { // Rerouting not supported or request was already rerouted if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false || request.hasBeenRerouted()) { @@ -274,7 +262,7 @@ public InferenceAction.Response read(StreamInput in) throws IOException { ); } - private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { Map metricAttributes = new HashMap<>(); 
metricAttributes.putAll(modelAttributes(model)); metricAttributes.putAll(responseAttributes(unwrapCause(t))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index cdd322cfe74f3..9f08440fb44d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.io.IOException; import java.util.List; @@ -32,6 +33,7 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction Void> { private final HttpClientManager httpClientManager; + private final InferenceEndpointRegistry inferenceEndpointRegistry; @Inject public TransportGetInferenceDiagnosticsAction( @@ -39,7 +41,8 @@ public TransportGetInferenceDiagnosticsAction( ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, - HttpClientManager httpClientManager + HttpClientManager httpClientManager, + InferenceEndpointRegistry inferenceEndpointRegistry ) { super( GetInferenceDiagnosticsAction.NAME, @@ -51,6 +54,7 @@ public TransportGetInferenceDiagnosticsAction( ); this.httpClientManager = Objects.requireNonNull(httpClientManager); + this.inferenceEndpointRegistry = Objects.requireNonNull(inferenceEndpointRegistry); } @Override @@ -74,6 +78,10 @@ protected GetInferenceDiagnosticsAction.NodeResponse newNodeResponse(StreamInput @Override 
protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceDiagnosticsAction.NodeRequest request, Task task) { - return new GetInferenceDiagnosticsAction.NodeResponse(transportService.getLocalNode(), httpClientManager.getPoolStats()); + return new GetInferenceDiagnosticsAction.NodeResponse( + transportService.getLocalNode(), + httpClientManager.getPoolStats(), + inferenceEndpointRegistry.cacheEnabled() ? inferenceEndpointRegistry.stats() : null + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index f14d679ba7d26..f0fb0ec82757a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; @@ -24,7 +23,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; public class TransportInferenceAction extends BaseTransportInferenceAction { @@ -33,7 +32,7 @@ public TransportInferenceAction( TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, - 
ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -46,7 +45,7 @@ public TransportInferenceAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -58,12 +57,12 @@ public TransportInferenceAction( } @Override - protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, Model model) { return false; } @Override - protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) { + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, Model model) { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index d0eef677ca1d3..4fe5dd3a55a12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; @@ -29,7 +28,7 @@ import 
org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.concurrent.Flow; @@ -40,7 +39,7 @@ public TransportUnifiedCompletionInferenceAction( TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -53,7 +52,7 @@ public TransportUnifiedCompletionInferenceAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -65,15 +64,12 @@ public TransportUnifiedCompletionInferenceAction( } @Override - protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { - return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || unparsedModel.taskType() != TaskType.CHAT_COMPLETION; + protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, Model model) { + return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || model.getTaskType() != TaskType.CHAT_COMPLETION; } @Override - protected ElasticsearchStatusException createInvalidTaskTypeException( - UnifiedCompletionAction.Request request, - UnparsedModel unparsedModel - ) { + protected ElasticsearchStatusException createInvalidTaskTypeException(UnifiedCompletionAction.Request request, Model model) { return new ElasticsearchStatusException( "Incompatible task_type for unified API, the requested type [{}] must be one 
of [{}]", RestStatus.BAD_REQUEST, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java new file mode 100644 index 0000000000000..ed32b66d3f728 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.registry; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AbstractNamedDiffable; +import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateAckListener; +import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; +import 
org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.Objects; + +/** + * Clears the cache in {@link InferenceEndpointRegistry}. This uses a master node transport action, even though most requests will originate + * from the master node (when updating and deleting inference endpoints via REST), because there are some edge cases where deletes can come + * from other nodes. This uses the cluster state to broadcast the message to all nodes to clear their cache, which has guaranteed delivery. 
+ */ +public class ClearInferenceEndpointCacheAction extends AcknowledgedTransportMasterNodeAction { + private static final Logger log = LogManager.getLogger(ClearInferenceEndpointCacheAction.class); + private static final String NAME = "cluster:admin/xpack/inference/clear_inference_endpoint_cache"; + public static final ActionType INSTANCE = new ActionType<>(NAME); + private static final String TASK_QUEUE_NAME = "inference-endpoint-cache-management"; + + private final ProjectResolver projectResolver; + private final InferenceEndpointRegistry inferenceEndpointRegistry; + private final MasterServiceTaskQueue taskQueue; + + @Inject + public ClearInferenceEndpointCacheAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + ProjectResolver projectResolver, + InferenceEndpointRegistry inferenceEndpointRegistry + ) { + super( + NAME, + transportService, + clusterService, + threadPool, + actionFilters, + ClearInferenceEndpointCacheAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.projectResolver = projectResolver; + this.inferenceEndpointRegistry = inferenceEndpointRegistry; + this.taskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, new CacheMetadataUpdateTaskExecutor()); + clusterService.addListener( + event -> event.state() + .metadata() + .projects() + .values() + .stream() + .map(ProjectMetadata::id) + .filter(id -> event.customMetadataChanged(id, InvalidateCacheMetadata.NAME)) + .peek(id -> log.trace("Trained model cache invalidated on node [{}]", () -> event.state().nodes().getLocalNodeId())) + .forEach(inferenceEndpointRegistry::invalidateAll) + ); + } + + @Override + protected void masterOperation( + Task task, + ClearInferenceEndpointCacheAction.Request request, + ClusterState state, + ActionListener listener + ) { + if (inferenceEndpointRegistry.cacheEnabled()) { + taskQueue.submitTask("invalidateAll", new 
RefreshCacheMetadataVersionTask(projectResolver.getProjectId(), listener), null); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } + } + + @Override + protected ClusterBlockException checkBlock(ClearInferenceEndpointCacheAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + public static class Request extends AcknowledgedRequest { + protected Request() { + super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + } + + protected Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public int hashCode() { + return Objects.hashCode(ackTimeout()); + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof ClearInferenceEndpointCacheAction.Request that && Objects.equals(that.ackTimeout(), ackTimeout()); + } + } + + public static class InvalidateCacheMetadata extends AbstractNamedDiffable implements Metadata.ProjectCustom { + public static final String NAME = "inference-endpoint-cache-metadata"; + private static final InvalidateCacheMetadata EMPTY = new InvalidateCacheMetadata(0L); + private static final ParseField VERSION_FIELD = new ParseField("version"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> new InvalidateCacheMetadata((long) args[0]) + ); + + static { + PARSER.declareLong(ConstructingObjectParser.constructorArg(), VERSION_FIELD); + } + + public static InvalidateCacheMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static InvalidateCacheMetadata fromMetadata(ProjectMetadata projectMetadata) { + InvalidateCacheMetadata metadata = projectMetadata.custom(NAME); + return metadata == null ? 
EMPTY : metadata; + } + + private final long version; + + private InvalidateCacheMetadata(long version) { + this.version = version; + } + + public InvalidateCacheMetadata(StreamInput in) throws IOException { + this(in.readVLong()); + } + + public InvalidateCacheMetadata bumpVersion() { + return new InvalidateCacheMetadata(version < Long.MAX_VALUE ? version + 1 : 1L); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(version); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params ignored) { + return Iterators.single(((builder, params) -> builder.field(VERSION_FIELD.getPreferredName(), version))); + } + + @Override + public int hashCode() { + return Objects.hashCode(version); + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof InvalidateCacheMetadata that && that.version == this.version; + } + } + + private static class RefreshCacheMetadataVersionTask extends AckedBatchedClusterStateUpdateTask { + private final ProjectId projectId; + + private RefreshCacheMetadataVersionTask(ProjectId projectId, ActionListener listener) { + super(TimeValue.THIRTY_SECONDS, listener); + this.projectId = projectId; + } + } + + private static class CacheMetadataUpdateTaskExecutor extends SimpleBatchedAckListenerTaskExecutor { + @Override + public Tuple executeTask(RefreshCacheMetadataVersionTask task, ClusterState clusterState) { + var projectMetadata = clusterState.metadata().getProject(task.projectId); + var currentMetadata = InvalidateCacheMetadata.fromMetadata(projectMetadata); + var updatedMetadata = currentMetadata.bumpVersion(); + var newProjectMetadata = 
ProjectMetadata.builder(projectMetadata).putCustom(InvalidateCacheMetadata.NAME, updatedMetadata); + return new Tuple<>(ClusterState.builder(clusterState).putProjectMetadata(newProjectMetadata).build(), task); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java new file mode 100644 index 0000000000000..8d4ce68b42b78 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.registry; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.SerializableStats; + +import java.io.IOException; +import java.util.stream.StreamSupport; + +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_ENABLED; 
+import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_EXPIRY; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_WEIGHT; + +/** + * A registry that assembles and caches Inference Endpoints, {@link Model}, for reuse. + * Models are high read and minimally written, where changes only occur during updates and deletes. + * The cache is invalidated via the {@link ClearInferenceEndpointCacheAction} so that every node gets the invalidation + * message. + */ +public class InferenceEndpointRegistry { + + private static final Logger log = LogManager.getLogger(InferenceEndpointRegistry.class); + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final ProjectResolver projectResolver; + private final Cache cache; + private volatile boolean cacheEnabled; + + public InferenceEndpointRegistry( + ClusterService clusterService, + Settings settings, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + ProjectResolver projectResolver + ) { + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.projectResolver = projectResolver; + this.cache = CacheBuilder.builder() + .setMaximumWeight(INFERENCE_ENDPOINT_CACHE_WEIGHT.get(settings)) + .setExpireAfterWrite(INFERENCE_ENDPOINT_CACHE_EXPIRY.get(settings)) + .build(); + this.cacheEnabled = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings); + + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabled = enabled); + } + + public void getEndpoint(String inferenceEntityId, ActionListener listener) { + var key = new InferenceIdAndProject(inferenceEntityId, projectResolver.getProjectId()); + var cachedModel = cacheEnabled ? 
cache.get(key) : null; + if (cachedModel != null) { + log.debug("Retrieved [{}] from cache.", inferenceEntityId); + listener.onResponse(cachedModel); + } else { + loadFromIndex(key, listener); + } + } + + void invalidateAll(ProjectId projectId) { + // copy to an interim list because cache.keys() does not allow inline mutations + if (cacheEnabled) { + StreamSupport.stream(cache.keys().spliterator(), false) + .filter(key -> key.projectId.equals(projectId)) + .toList() + .forEach(cache::invalidate); + } + } + + private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener listener) { + modelRegistry.getModelWithSecrets(idAndProject.inferenceEntityId(), listener.delegateFailureAndWrap((l, unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()) + .orElseThrow( + () -> new ResourceNotFoundException( + "Unknown service [{}] for model [{}]", + unparsedModel.service(), + idAndProject.inferenceEntityId() + ) + ); + + var model = service.parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + + if (cacheEnabled) { + cache.put(idAndProject, model); + } + l.onResponse(model); + })); + } + + public Stats stats() { + return cacheEnabled ? 
new Stats(cache.stats()) : Stats.EMPTY; + } + + public boolean cacheEnabled() { + return cacheEnabled; + } + + private record InferenceIdAndProject(String inferenceEntityId, ProjectId projectId) {} + + public record Stats(Cache.Stats stats) implements SerializableStats { + public static final String NAME = "inference_endpoint_registry_stats"; + private static final String CACHE_HITS = "cache_hits"; + private static final String CACHE_MISSES = "cache_misses"; + private static final String CACHE_EVICTIONS = "cache_evictions"; + + private static final Stats EMPTY = new Stats(new Cache.Stats(0, 0, 0)); + + public Stats(StreamInput in) throws IOException { + this(new Cache.Stats(in.readLong(), in.readLong(), in.readLong())); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeLong(stats.getHits()); + out.writeLong(stats.getMisses()); + out.writeLong(stats.getEvictions()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject() + .field(CACHE_HITS, stats.getHits()) + .field(CACHE_MISSES, stats.getMisses()) + .field(CACHE_EVICTIONS, stats.getEvictions()) + .endObject(); + } + + public long hits() { + return stats.getHits(); + } + + public long misses() { + return stats.getMisses(); + } + + public long evictions() { + return stats.getEvictions(); + } + + @Override + public String getWriteableName() { + return NAME; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index fe7c4a9395cd1..7cd1cf5999d11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -614,6 +614,7 @@ public void updateModelTransaction(Model 
newModel, Model existingModel, ActionLi } else { // since updating the secrets was successful, we can remove the lock and respond to the final listener preventDeletionLock.remove(inferenceEntityId); + refreshInferenceEndpointCache(); finalListener.onResponse(true); } }).andThen((subListener, configResponse) -> { @@ -844,7 +845,10 @@ private void deleteModels(Set inferenceEntityIds, boolean updateClusterS client.execute( DeleteByQueryAction.INSTANCE, request, - getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener) + ActionListener.runAfter( + getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener), + this::refreshInferenceEndpointCache + ) ); } @@ -899,6 +903,17 @@ public void onFailure(Exception exc) { }; } + private void refreshInferenceEndpointCache() { + client.execute( + ClearInferenceEndpointCacheAction.INSTANCE, + new ClearInferenceEndpointCacheAction.Request(), + ActionListener.wrap( + ignored -> logger.debug("Successfully refreshed inference endpoint cache."), + e -> logger.atDebug().withThrowable(e).log("Failed to refresh inference endpoint cache.") + ) + ); + } + private static DeleteByQueryRequest createDeleteRequest(Set inferenceEntityIds) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 812cd1e3c6d7f..47053b7cbe5eb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -20,7 +20,6 @@ import 
org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; @@ -34,11 +33,10 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.junit.Before; import org.mockito.ArgumentCaptor; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Flow; @@ -59,7 +57,7 @@ public abstract class BaseTransportInferenceActionTestCase extends ESTestCase { private MockLicenseState licenseState; - private ModelRegistry modelRegistry; + private InferenceEndpointRegistry inferenceEndpointRegistry; private StreamingTaskManager streamingTaskManager; private BaseTransportInferenceAction action; private ThreadPool threadPool; @@ -87,7 +85,7 @@ public void setUp() throws Exception { transportService = mock(); inferenceServiceRateLimitCalculator = mock(); licenseState = mock(); - modelRegistry = mock(); + inferenceEndpointRegistry = mock(); serviceRegistry = mock(); inferenceStats = InferenceStatsTests.mockInferenceStats(); streamingTaskManager = mock(); @@ -96,7 +94,7 @@ public void setUp() throws Exception { transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -113,7 +111,7 @@ protected abstract BaseTransportInferenceAction createAction( TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - 
ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -132,7 +130,7 @@ public void testMetricsAfterModelRegistryError() { ActionListener listener = ans.getArgument(1); listener.onFailure(expectedException); return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); + }).when(inferenceEndpointRegistry).getEndpoint(any(), any()); doExecute(taskType); @@ -168,7 +166,7 @@ public void onFailure(Exception e) {} } public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); + mockInferenceEndpointRegistry(taskType); when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); @@ -190,19 +188,18 @@ public void testMetricsAfterMissingService() { })); } - protected void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + protected void mockInferenceEndpointRegistry(TaskType expectedTaskType) { doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); + ActionListener listener = ans.getArgument(1); + listener.onResponse(mockModel(expectedTaskType)); return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); + }).when(inferenceEndpointRegistry).getEndpoint(any(), any()); } public void testMetricsAfterUnknownTaskType() { var modelTaskType = TaskType.RERANK; var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -363,7 +360,7 @@ public void testProductUseCaseHeaderPresentInThreadContextIfPresent() { when(threadPool.getThreadContext()).thenReturn(threadContext); - mockModelRegistry(taskType); + mockInferenceEndpointRegistry(taskType); 
mockService(listener -> listener.onResponse(mock())); Request request = createRequest(); @@ -431,30 +428,25 @@ protected void mockService( listenerAction.accept(ans.getArgument(3)); return null; }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); - mockModelAndServiceRegistry(service); + mockInferenceEndpointRegistry(taskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); } protected Model mockModel() { + return mockModel(taskType); + } + + protected Model mockModel(TaskType expectedTaskType) { Model model = mock(); ModelConfigurations modelConfigurations = mock(); when(modelConfigurations.getService()).thenReturn(serviceId); when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); + when(model.getTaskType()).thenReturn(expectedTaskType); when(model.getServiceSettings()).thenReturn(mock()); + when(model.getInferenceEntityId()).thenReturn(inferenceId); return model; } - protected void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); - } - protected void mockValidLicenseState() { when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 547078d93acc4..dd0a1b952233b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimitAssignment; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.List; @@ -49,7 +49,7 @@ protected BaseTransportInferenceAction createAction( TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -61,7 +61,7 @@ protected BaseTransportInferenceAction createAction( transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 9e6f4a6260936..0b05509acaf8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; -import 
org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import java.util.Optional; @@ -45,7 +45,7 @@ protected BaseTransportInferenceAction createAc TransportService transportService, ActionFilters actionFilters, MockLicenseState licenseState, - ModelRegistry modelRegistry, + InferenceEndpointRegistry inferenceEndpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, @@ -57,7 +57,7 @@ protected BaseTransportInferenceAction createAc transportService, actionFilters, licenseState, - modelRegistry, + inferenceEndpointRegistry, serviceRegistry, inferenceStats, streamingTaskManager, @@ -75,7 +75,7 @@ protected UnifiedCompletionAction.Request createRequest() { public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { var modelTaskType = TaskType.TEXT_EMBEDDING; var requestTaskType = TaskType.TEXT_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -100,7 +100,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInfe public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_ModelIsTextEmbedding() { var modelTaskType = TaskType.ANY; var requestTaskType = TaskType.TEXT_EMBEDDING; - mockModelRegistry(modelTaskType); + mockInferenceEndpointRegistry(modelTaskType); when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); var listener = doExecute(requestTaskType); @@ -123,7 +123,7 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_Model } public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { - mockModelRegistry(TaskType.COMPLETION); + mockInferenceEndpointRegistry(TaskType.COMPLETION); mockService(listener -> 
listener.onResponse(mock())); var listener = doExecute(TaskType.ANY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java new file mode 100644 index 0000000000000..b3e3e0adad5aa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.util.Collection; +import 
java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class ClearInferenceEndpointCacheActionTests extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private static final String INFERENCE_ENDPOINT_ID = "1"; + + @Override + protected Collection> getPlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + public void testCacheEviction() throws Exception { + storeGoodEndpoint(); + invokeEndpoint(); + + var stats = cacheStats(); + assertThat(stats.hits(), equalTo(0L)); + assertThat(stats.misses(), equalTo(1L)); + assertThat(stats.evictions(), equalTo(0L)); + + var listener = new PlainActionFuture(); + clusterAdmin().execute(ClearInferenceEndpointCacheAction.INSTANCE, new ClearInferenceEndpointCacheAction.Request(), listener); + assertTrue(listener.actionGet(TIMEOUT).isAcknowledged()); + + assertBusy(() -> { + var nextStats = cacheStats(); + assertThat(nextStats.hits(), equalTo(0L)); + assertThat(nextStats.misses(), equalTo(1L)); + assertThat(nextStats.evictions(), equalTo(1L)); + }, 10, TimeUnit.SECONDS); + + invokeEndpoint(); + stats = cacheStats(); + assertThat(stats.hits(), equalTo(0L)); + assertThat(stats.misses(), equalTo(2L)); + assertThat(stats.evictions(), equalTo(1L)); + } + + private void storeGoodEndpoint() throws IOException { + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", TestSparseInferenceServiceExtension.TestInferenceService.NAME); + builder.field("service_settings", Map.of("model", "model", "api_key", "1234")); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + var 
request = new PutInferenceModelAction.Request( + TaskType.SPARSE_EMBEDDING, + INFERENCE_ENDPOINT_ID, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + client().execute(PutInferenceModelAction.INSTANCE, request).actionGet(TIMEOUT); + } + + private void invokeEndpoint() { + client().execute( + InferenceAction.INSTANCE, + new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + INFERENCE_ENDPOINT_ID, + null, + null, + null, + List.of("hello"), + null, + InputType.INTERNAL_SEARCH, + TIMEOUT, + false + ) + ).actionGet(TIMEOUT); + } + + private InferenceEndpointRegistry.Stats cacheStats() { + var diagnostics = client().execute(GetInferenceDiagnosticsAction.INSTANCE, new GetInferenceDiagnosticsAction.Request()) + .actionGet(TIMEOUT); + + assertThat(diagnostics.getNodes(), hasSize(1)); + return diagnostics.getNodes().getFirst().getInferenceEndpointRegistryStats() instanceof InferenceEndpointRegistry.Stats stats + ? stats + : null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java new file mode 100644 index 0000000000000..72687768f5474 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.junit.Before; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class InferenceEndpointRegistryTests extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + + private InferenceEndpointRegistry inferenceEndpointRegistry; + private ModelRegistry registry; + + @Override + protected Collection> getPlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Before + public void createComponents() { + inferenceEndpointRegistry = node().injector().getInstance(InferenceEndpointRegistry.class); + registry = node().injector().getInstance(ModelRegistry.class); + } + + public void testGetThrowsResourceNotFoundWhenNoHitsReturned() { + assertThat( + getEndpointException("this is not found", ResourceNotFoundException.class).getMessage(), + is("Inference endpoint not found [this is not found]") + ); + } + + private Exception getEndpointException(String id, Class expectedExceptionClass) { + var listener = new PlainActionFuture(); + inferenceEndpointRegistry.getEndpoint(id, listener); + return 
expectThrows(expectedExceptionClass, () -> listener.actionGet(TIMEOUT)); + } + + public void testGetModel() { + var expectedEndpoint = storeGoodEndpoint("1"); + var actualEndpoint = getEndpoint("1"); + assertThat(actualEndpoint, equalTo(expectedEndpoint)); + assertThat(getEndpoint("1"), sameInstance(actualEndpoint)); + } + + private Model storeGoodEndpoint(String id) { + var expectedEndpoint = new AbstractTestInferenceService.TestServiceModel( + id, + TaskType.SPARSE_EMBEDDING, + "test_service", + new TestSparseInferenceServiceExtension.TestServiceSettings("model", null, false), + new AbstractTestInferenceService.TestTaskSettings(randomInt(3)), + new AbstractTestInferenceService.TestSecretSettings("secret") + ); + assertStoreModel(registry, expectedEndpoint); + return expectedEndpoint; + } + + private Model getEndpoint(String id) { + var listener = new PlainActionFuture(); + inferenceEndpointRegistry.getEndpoint(id, listener); + return listener.actionGet(TIMEOUT); + } + + public void testGetModelWithUnknownService() { + var id = "ahhhh"; + var expectedEndpoint = new AbstractTestInferenceService.TestServiceModel( + id, + TaskType.SPARSE_EMBEDDING, + "hello", + new TestSparseInferenceServiceExtension.TestServiceSettings("model", null, false), + new AbstractTestInferenceService.TestTaskSettings(randomInt(3)), + new AbstractTestInferenceService.TestSecretSettings("secret") + ); + assertStoreModel(registry, expectedEndpoint); + + assertThat( + getEndpointException(id, ResourceNotFoundException.class).getMessage(), + equalTo("Unknown service [hello] for model [ahhhh]") + ); + } +} From 6e38b323e2359c09fab2ec2a4dc9c17887d8ac06 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 29 Aug 2025 16:50:54 -0400 Subject: [PATCH 02/10] Update docs/changelog/133860.yaml --- docs/changelog/133860.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/133860.yaml diff --git a/docs/changelog/133860.yaml b/docs/changelog/133860.yaml new file mode 100644 
index 0000000000000..f036902602579 --- /dev/null +++ b/docs/changelog/133860.yaml @@ -0,0 +1,6 @@ +pr: 133860 +summary: Cache Inference Endpoints +area: Machine Learning +type: enhancement +issues: + - 133135 From 85ba05a0a5f0cd49311a93da3e89cb08d15ad611 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 29 Aug 2025 16:55:37 -0400 Subject: [PATCH 03/10] Update transport version --- .../src/main/java/org/elasticsearch/TransportVersions.java | 1 + .../inference/action/GetInferenceDiagnosticsAction.java | 7 ++++--- .../registry/ClearInferenceEndpointCacheAction.java | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2c2aaaf322844..e44c3cd864d88 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -358,6 +358,7 @@ static TransportVersion def(int id) { public static final TransportVersion ALLOCATION_DECISION_NOT_PREFERRED = def(9_145_0_00); public static final TransportVersion ESQL_QUALIFIERS_IN_ATTRIBUTES = def(9_146_0_00); public static final TransportVersion PROJECT_RESERVED_STATE_MOVE_TO_REGISTRY = def(9_147_0_00); + public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_148_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index f4e807f4f8430..289cd287dfb31 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; -import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.nodes.BaseNodeResponse; @@ -30,6 +29,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE; + public class GetInferenceDiagnosticsAction extends ActionType { public static final GetInferenceDiagnosticsAction INSTANCE = new GetInferenceDiagnosticsAction(); @@ -135,7 +136,7 @@ public NodeResponse(StreamInput in) throws IOException { super(in); connectionPoolStats = new ConnectionPoolStats(in); - inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(TransportVersion.current()) + inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE) ? 
in.readOptionalNamedWriteable(SerializableStats.class) : null; } @@ -144,7 +145,7 @@ public NodeResponse(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); connectionPoolStats.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersion.current())) { + if (out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) { out.writeOptionalNamedWriteable(inferenceEndpointRegistryStats); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java index ed32b66d3f728..8faf8aadceefc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ -50,6 +50,8 @@ import java.util.Iterator; import java.util.Objects; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE; + /** * Clears the cache in {@link InferenceEndpointRegistry}. 
This uses a master node transport action, even though most requests will originate * from the master node (when updating and deleting inference endpoints via REST), because there are some edge cases where deletes can come @@ -185,7 +187,7 @@ public EnumSet context() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); + return ML_INFERENCE_ENDPOINT_CACHE; } @Override From 99a3b14e9bd81bd3407c0b1d8aa7643f4fc9426a Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 2 Sep 2025 09:00:36 -0400 Subject: [PATCH 04/10] fix tests; add permissions --- .../inference/CreateFromDeploymentIT.java | 27 +++++-------------- .../xpack/security/operator/Constants.java | 1 + 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index 9701fff2a5789..48e1a74b7b4a3 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -110,6 +110,8 @@ public void testAttachWithModelId() throws IOException { var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); + deleteModel(inferenceId); + forceStopMlNodeDeployment(deploymentId); } @@ -225,6 +227,7 @@ public void testNumAllocationsIsUpdated() throws IOException { ) ); + deleteModel(inferenceId); forceStopMlNodeDeployment(deploymentId); } @@ -266,6 +269,7 @@ public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOExceptio is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId)) ); + deleteModel(inferenceId); 
forceStopMlNodeDeployment(deploymentId); } @@ -309,6 +313,8 @@ public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws ) ); + deleteModel(inferenceId); + deleteModel(secondInferenceId); forceStopMlNodeDeployment(deploymentId); } @@ -331,6 +337,7 @@ public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOExcepti ) ); + deleteModel(inferenceId); // Force stop will stop the deployment forceStopMlNodeDeployment(deploymentId); } @@ -358,16 +365,6 @@ private String endpointConfig(String modelId, String deploymentId) { """, modelId, deploymentId); } - private String updatedEndpointConfig(int numAllocations) { - return Strings.format(""" - { - "service_settings": { - "num_allocations": %d - } - } - """, numAllocations); - } - private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException { String endPoint = "/_ml/trained_models/" + modelId @@ -413,16 +410,6 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations return client().performRequest(request); } - private Map updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException { - String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; - - Request request = new Request("POST", endPoint); - request.setJsonEntity(body); - var response = client().performRequest(request); - assertStatusOkOrCreated(response); - return entityAsMap(response); - } - protected void stopMlNodeDeployment(String deploymentId) throws IOException { String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 485b6989aca58..bcfd54dc06908 100644 --- 
a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,6 +174,7 @@ public class Constants { "cluster:admin/xpack/enrich/get", "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", + "cluster:admin/xpack/inference/clear_inference_endpoint_cache", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", "cluster:admin/xpack/inference/update", From 6d405eacb5c5c1d86ab7533bc0b72175553dac97 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 2 Sep 2025 10:36:45 -0400 Subject: [PATCH 05/10] use BWC --- ...nceDiagnosticsActionNodeResponseTests.java | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index 6c2f5539634f5..9333dfaebc3b3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Strings; @@ -15,15 +17,15 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import 
org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.SerializableStats; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.util.List; -public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase< +public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase< GetInferenceDiagnosticsAction.NodeResponse> { public static GetInferenceDiagnosticsAction.NodeResponse createRandom() { DiscoveryNode node = DiscoveryNodeUtils.create("id"); @@ -130,4 +132,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.startObject().field("count", count).endObject(); } } + + @Override + protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion( + GetInferenceDiagnosticsAction.NodeResponse instance, + TransportVersion version + ) { + if (version.before(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) { + return new GetInferenceDiagnosticsAction.NodeResponse( + instance.getNode(), + new PoolStats( + instance.getConnectionPoolStats().getLeasedConnections(), + instance.getConnectionPoolStats().getPendingConnections(), + instance.getConnectionPoolStats().getAvailableConnections(), + instance.getConnectionPoolStats().getMaxConnections() + ), + null + ); + } else { + return instance; + } + } } From a1c3cdccf94a31e1ce4d4dc6450b048a2e4fbddd Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 2 Sep 2025 13:15:30 -0400 Subject: [PATCH 06/10] Add writeable entry to ML tests --- .../ml/integration/MlNativeIntegTestCase.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git 
a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index 86aadcb0ec1d4..b33fdf13a3cb8 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.AbstractNamedDiffable; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.NamedDiff; @@ -59,6 +60,7 @@ import org.elasticsearch.test.XContentTestUtils; import org.elasticsearch.transport.netty4.Netty4Plugin; import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.autoscaling.Autoscaling; import org.elasticsearch.xpack.autoscaling.AutoscalingMetadata; import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult; @@ -99,6 +101,7 @@ import org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.ilm.IndexLifecycle; +import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason; @@ -436,6 +439,24 @@ protected void assertClusterRoundTrip() throws IOException { new 
NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new) ); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); + entries.add( + new NamedWriteableRegistry.Entry( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::new + ) + ); + entries.add( + new NamedWriteableRegistry.Entry( + NamedDiff.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in -> AbstractNamedDiffable.readDiffFrom( + Metadata.ProjectCustom.class, + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME, + in + ) + ) + ); // Retrieve the cluster state from a random node, and serialize and deserialize it. final ClusterStateResponse clusterStateResponse = client().admin() From e0d442d4703b38ad02eafaae4ec54c4e09907d3f Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 2 Sep 2025 17:22:45 +0000 Subject: [PATCH 07/10] [CI] Auto commit changes from spotless --- .../xpack/ml/integration/MlNativeIntegTestCase.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index b33fdf13a3cb8..bfce552b094f6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -60,7 +60,6 @@ import org.elasticsearch.test.XContentTestUtils; import org.elasticsearch.transport.netty4.Netty4Plugin; import org.elasticsearch.xcontent.NamedXContentRegistry; -import 
org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.autoscaling.Autoscaling; import org.elasticsearch.xpack.autoscaling.AutoscalingMetadata; import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult; From 97da39677edec70c903b825816d81a757b13dd85 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 9 Sep 2025 10:34:48 -0400 Subject: [PATCH 08/10] address comments --- .../core/inference/SerializableStats.java | 13 --- .../action/GetInferenceDiagnosticsAction.java | 45 ++++++-- ...nceDiagnosticsActionNodeResponseTests.java | 62 +++-------- ...ferenceDiagnosticsActionResponseTests.java | 32 +++--- .../InferenceNamedWriteablesProvider.java | 9 -- .../xpack/inference/InferencePlugin.java | 25 +---- .../ClearInferenceEndpointCacheAction.java | 4 +- .../registry/InferenceEndpointRegistry.java | 103 +++++++----------- ...learInferenceEndpointCacheActionTests.java | 6 +- .../InferenceEndpointRegistryTests.java | 4 +- 10 files changed, 118 insertions(+), 185 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java deleted file mode 100644 index 3aed4962dce11..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.core.inference; - -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.xcontent.ToXContentObject; - -public interface SerializableStats extends ToXContentObject, NamedWriteable {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index 289cd287dfb31..8a2f4e657e786 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.nodes.BaseNodesResponse; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -23,7 +24,6 @@ import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.SerializableStats; import java.io.IOException; import java.util.List; @@ -124,12 +124,12 @@ public static class NodeResponse extends BaseNodeResponse implements ToXContentF private final ConnectionPoolStats connectionPoolStats; @Nullable - private final SerializableStats inferenceEndpointRegistryStats; + private final Stats inferenceEndpointRegistryStats; - public NodeResponse(DiscoveryNode node, PoolStats poolStats, SerializableStats inferenceEndpointRegistryStats) { + public NodeResponse(DiscoveryNode node, PoolStats poolStats, @Nullable Cache.Stats inferenceEndpointRegistryStats) { super(node); 
connectionPoolStats = ConnectionPoolStats.of(poolStats); - this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats; + this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats != null ? Stats.of(inferenceEndpointRegistryStats) : null; } public NodeResponse(StreamInput in) throws IOException { @@ -137,7 +137,7 @@ public NodeResponse(StreamInput in) throws IOException { connectionPoolStats = new ConnectionPoolStats(in); inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE) - ? in.readOptionalNamedWriteable(SerializableStats.class) + ? in.readOptionalWriteable(Stats::new) : null; } @@ -146,7 +146,7 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); connectionPoolStats.writeTo(out); if (out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) { - out.writeOptionalNamedWriteable(inferenceEndpointRegistryStats); + out.writeOptionalWriteable(inferenceEndpointRegistryStats); } } @@ -177,7 +177,7 @@ ConnectionPoolStats getConnectionPoolStats() { return connectionPoolStats; } - public SerializableStats getInferenceEndpointRegistryStats() { + public Stats getInferenceEndpointRegistryStats() { return inferenceEndpointRegistryStats; } @@ -262,5 +262,36 @@ int getMaxConnections() { return maxConnections; } } + + public record Stats(long hits, long misses, long evictions) implements ToXContentObject, Writeable { + + private static final String CACHE_HITS = "cache_hits"; + private static final String CACHE_MISSES = "cache_misses"; + private static final String CACHE_EVICTIONS = "cache_evictions"; + + public Stats(StreamInput in) throws IOException { + this(in.readLong(), in.readLong(), in.readLong()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeLong(hits); + out.writeLong(misses); + out.writeLong(evictions); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return 
builder.startObject() + .field(CACHE_HITS, hits) + .field(CACHE_MISSES, misses) + .field(CACHE_EVICTIONS, evictions) + .endObject(); + } + + public static Stats of(Cache.Stats cacheStats) { + return new Stats(cacheStats.getHits(), cacheStats.getMisses(), cacheStats.getEvictions()); + } + } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index 9333dfaebc3b3..ec9e917a0e503 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -13,17 +13,12 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.SerializableStats; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.List; public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase< GetInferenceDiagnosticsAction.NodeResponse> { @@ -31,18 +26,7 @@ public static GetInferenceDiagnosticsAction.NodeResponse createRandom() { DiscoveryNode node = DiscoveryNodeUtils.create("id"); var randomPoolStats = new PoolStats(randomInt(), randomInt(), 
randomInt(), randomInt()); - return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats, new TestStats(randomInt())); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return registryWithTestStats(); - } - - public static NamedWriteableRegistry registryWithTestStats() { - return new NamedWriteableRegistry( - List.of(new NamedWriteableRegistry.Entry(SerializableStats.class, TestStats.NAME, TestStats::new)) - ); + return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats, randomCacheStats()); } @Override @@ -70,7 +54,7 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference connPoolStats.getAvailableConnections(), connPoolStats.getMaxConnections() ), - randomTestStats() + randomCacheStats() ); case 1 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -80,7 +64,7 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference connPoolStats.getAvailableConnections(), connPoolStats.getMaxConnections() ), - randomTestStats() + randomCacheStats() ); case 2 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -90,7 +74,7 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference randomInt(), connPoolStats.getMaxConnections() ), - randomTestStats() + randomCacheStats() ); case 3 -> new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), @@ -100,43 +84,27 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference connPoolStats.getAvailableConnections(), randomInt() ), - randomTestStats() + randomCacheStats() ); default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select)); }; } - public static SerializableStats randomTestStats() { - return new TestStats(randomInt()); - } - - public record TestStats(int count) implements SerializableStats { - public static final String NAME = "test_stats"; - - public 
TestStats(StreamInput in) throws IOException { - this(in.readInt()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeInt(count); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject().field("count", count).endObject(); - } + private static Cache.Stats randomCacheStats() { + return new Cache.Stats(randomLong(), randomLong(), randomLong()); } @Override protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion( GetInferenceDiagnosticsAction.NodeResponse instance, TransportVersion version + ) { + return mutateNodeResponseForVersion(instance, version); + } + + public static GetInferenceDiagnosticsAction.NodeResponse mutateNodeResponseForVersion( + GetInferenceDiagnosticsAction.NodeResponse instance, + TransportVersion version ) { if (version.before(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) { return new GetInferenceDiagnosticsAction.NodeResponse( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java index a9f6df035af28..f7874e67a6333 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java @@ -8,20 +8,22 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import 
org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.hamcrest.CoreMatchers; import java.io.IOException; import java.util.List; -public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase { +public class GetInferenceDiagnosticsActionResponseTests extends AbstractBWCWireSerializationTestCase< + GetInferenceDiagnosticsAction.Response> { public static GetInferenceDiagnosticsAction.Response createRandom() { List responses = randomList( @@ -37,13 +39,7 @@ public void testToXContent() throws IOException { var poolStats = new PoolStats(1, 2, 3, 4); var entity = new GetInferenceDiagnosticsAction.Response( ClusterName.DEFAULT, - List.of( - new GetInferenceDiagnosticsAction.NodeResponse( - node, - poolStats, - new GetInferenceDiagnosticsActionNodeResponseTests.TestStats(5) - ) - ), + List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats, new Cache.Stats(5, 6, 7))), List.of() ); @@ -53,7 +49,7 @@ public void testToXContent() throws IOException { assertThat(xContentResult, CoreMatchers.is(""" {"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """ - "max_connections":4},"inference_endpoint_registry":{"count":5}}}""")); + "max_connections":4},"inference_endpoint_registry":{"cache_hits":5,"cache_misses":6,"cache_evictions":7}}}""")); } @Override @@ -76,7 +72,17 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag } @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return GetInferenceDiagnosticsActionNodeResponseTests.registryWithTestStats(); + protected 
GetInferenceDiagnosticsAction.Response mutateInstanceForVersion( + GetInferenceDiagnosticsAction.Response instance, + TransportVersion version + ) { + return new GetInferenceDiagnosticsAction.Response( + instance.getClusterName(), + instance.getNodes() + .stream() + .map(nodeResponse -> GetInferenceDiagnosticsActionNodeResponseTests.mutateNodeResponseForVersion(nodeResponse, version)) + .toList(), + instance.failures() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index d2cdf611dde8e..1f4b5117a394e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -20,7 +20,6 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.xpack.core.inference.SerializableStats; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -36,7 +35,6 @@ import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction; -import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings; @@ -606,13 +604,6 @@ private static void addInternalNamedWriteables(List INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting( - "xpack.inference.cache.enabled", - true, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - - public static final Setting INFERENCE_ENDPOINT_CACHE_WEIGHT = Setting.intSetting( - "xpack.inference.cache.weight", - 25, - Setting.Property.NodeScope - ); - - public static final Setting INFERENCE_ENDPOINT_CACHE_EXPIRY = Setting.timeSetting( - "xpack.inference.cache.expiry_time", - TimeValue.timeValueMinutes(15), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueHours(1), - Setting.Property.NodeScope - ); - public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case"; public static final String NAME = "inference"; @@ -568,9 +547,7 @@ public static Set> getInferenceSettings() { settings.add(SKIP_VALIDATE_AND_START); settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.add(INFERENCE_QUERY_TIMEOUT); - settings.add(INFERENCE_ENDPOINT_CACHE_ENABLED); - settings.add(INFERENCE_ENDPOINT_CACHE_EXPIRY); - settings.add(INFERENCE_ENDPOINT_CACHE_WEIGHT); + settings.addAll(InferenceEndpointRegistry.getSettingsDefinitions()); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); return Collections.unmodifiableSet(settings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java index 8faf8aadceefc..be444973b2bd8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ 
-96,7 +96,7 @@ public ClearInferenceEndpointCacheAction( .stream() .map(ProjectMetadata::id) .filter(id -> event.customMetadataChanged(id, InvalidateCacheMetadata.NAME)) - .peek(id -> log.trace("Trained model cache invalidated on node [{}]", () -> event.state().nodes().getLocalNodeId())) + .peek(id -> log.trace("Inference endpoint cache on node [{}]", () -> event.state().nodes().getLocalNodeId())) .forEach(inferenceEndpointRegistry::invalidateAll) ); } @@ -117,7 +117,7 @@ protected void masterOperation( @Override protected ClusterBlockException checkBlock(ClearInferenceEndpointCacheAction.Request request, ClusterState state) { - return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + return state.blocks().globalBlockedException(projectResolver.getProjectId(), ClusterBlockLevel.METADATA_WRITE); } public static class Request extends AcknowledgedRequest { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java index 8d4ce68b42b78..fca179ee9b216 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -16,20 +16,14 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; -import org.elasticsearch.xcontent.XContentBuilder; 
-import org.elasticsearch.xpack.core.inference.SerializableStats; -import java.io.IOException; -import java.util.stream.StreamSupport; - -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_ENABLED; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_EXPIRY; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_ENDPOINT_CACHE_WEIGHT; +import java.util.Collection; +import java.util.List; /** * A registry that assembles and caches Inference Endpoints, {@link Model}, for reuse. @@ -39,7 +33,33 @@ */ public class InferenceEndpointRegistry { + private static final Setting INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting( + "xpack.inference.endpoint.cache.enabled", + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final Setting INFERENCE_ENDPOINT_CACHE_WEIGHT = Setting.intSetting( + "xpack.inference.endpoint.cache.weight", + 25, + Setting.Property.NodeScope + ); + + private static final Setting INFERENCE_ENDPOINT_CACHE_EXPIRY = Setting.timeSetting( + "xpack.inference.endpoint.cache.expiry_time", + TimeValue.timeValueMinutes(15), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueHours(1), + Setting.Property.NodeScope + ); + + public static Collection> getSettingsDefinitions() { + return List.of(INFERENCE_ENDPOINT_CACHE_ENABLED, INFERENCE_ENDPOINT_CACHE_WEIGHT, INFERENCE_ENDPOINT_CACHE_EXPIRY); + } + private static final Logger log = LogManager.getLogger(InferenceEndpointRegistry.class); + private static final Cache.Stats EMPTY = new Cache.Stats(0, 0, 0); private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private final ProjectResolver projectResolver; @@ -70,7 +90,7 @@ public void getEndpoint(String inferenceEntityId, ActionListener listener var key = new InferenceIdAndProject(inferenceEntityId, projectResolver.getProjectId()); var cachedModel = cacheEnabled ? 
cache.get(key) : null; if (cachedModel != null) { - log.debug("Retrieved [{}] from cache.", inferenceEntityId); + log.trace("Retrieved [{}] from cache.", inferenceEntityId); listener.onResponse(cachedModel); } else { loadFromIndex(key, listener); @@ -78,12 +98,13 @@ public void getEndpoint(String inferenceEntityId, ActionListener listener } void invalidateAll(ProjectId projectId) { - // copy to an interim list because cache.keys() does not allow inline mutations if (cacheEnabled) { - StreamSupport.stream(cache.keys().spliterator(), false) - .filter(key -> key.projectId.equals(projectId)) - .toList() - .forEach(cache::invalidate); + var cacheKeys = cache.keys().iterator(); + while (cacheKeys.hasNext()) { + if (cacheKeys.next().projectId.equals(projectId)) { + cacheKeys.remove(); + } + } } } @@ -112,8 +133,8 @@ private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener Exception getEndpointException(String id, Class } public void testGetModel() { - var expectedEndpoint = storeGoodEndpoint("1"); + var expectedEndpoint = storeWorkingEndpoint("1"); var actualEndpoint = getEndpoint("1"); assertThat(actualEndpoint, equalTo(expectedEndpoint)); assertThat(getEndpoint("1"), sameInstance(actualEndpoint)); } - private Model storeGoodEndpoint(String id) { + private Model storeWorkingEndpoint(String id) { var expectedEndpoint = new AbstractTestInferenceService.TestServiceModel( id, TaskType.SPARSE_EMBEDDING, From 2b1ca15a29c3f5fbc935f35b646c9dd41adf9560 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Wed, 10 Sep 2025 10:32:22 -0400 Subject: [PATCH 09/10] Address comments --- .../action/GetInferenceDiagnosticsAction.java | 22 +++++++++---------- ...nceDiagnosticsActionNodeResponseTests.java | 10 ++++++--- ...ferenceDiagnosticsActionResponseTests.java | 17 +++++++++----- ...ransportGetInferenceDiagnosticsAction.java | 16 +++++++++++++- .../ClearInferenceEndpointCacheAction.java | 2 +- .../registry/InferenceEndpointRegistry.java | 4 ++++ 
...learInferenceEndpointCacheActionTests.java | 3 +++ .../xpack/security/operator/Constants.java | 2 +- 8 files changed, 53 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index 765b49c6cf745..c027dab700672 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.action.support.nodes.BaseNodesResponse; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -134,12 +133,12 @@ public NodeResponse( DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats, - @Nullable Cache.Stats inferenceEndpointRegistryStats + @Nullable Stats inferenceEndpointRegistryStats ) { super(node); externalConnectionPoolStats = ConnectionPoolStats.of(poolStats); eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats); - this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats != null ? 
Stats.of(inferenceEndpointRegistryStats) : null; + this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats; } public NodeResponse(StreamInput in) throws IOException { @@ -298,35 +297,34 @@ int getMaxConnections() { } } - public record Stats(long hits, long misses, long evictions) implements ToXContentObject, Writeable { + public record Stats(int entryCount, long hits, long misses, long evictions) implements ToXContentObject, Writeable { + private static final String NUM_OF_CACHE_ENTRIES = "cache_count"; private static final String CACHE_HITS = "cache_hits"; private static final String CACHE_MISSES = "cache_misses"; private static final String CACHE_EVICTIONS = "cache_evictions"; public Stats(StreamInput in) throws IOException { - this(in.readLong(), in.readLong(), in.readLong()); + this(in.readVInt(), in.readVLong(), in.readVLong(), in.readVLong()); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeLong(hits); - out.writeLong(misses); - out.writeLong(evictions); + out.writeVInt(entryCount); + out.writeVLong(hits); + out.writeVLong(misses); + out.writeVLong(evictions); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { return builder.startObject() + .field(NUM_OF_CACHE_ENTRIES, entryCount) .field(CACHE_HITS, hits) .field(CACHE_MISSES, misses) .field(CACHE_EVICTIONS, evictions) .endObject(); } - - public static Stats of(Cache.Stats cacheStats) { - return new Stats(cacheStats.getHits(), cacheStats.getMisses(), cacheStats.getEvictions()); - } } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index abcb1979d93c7..020eeff459502 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; @@ -85,8 +84,13 @@ private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.Conne ); } - private static Cache.Stats randomCacheStats() { - return new Cache.Stats(randomLong(), randomLong(), randomLong()); + private static GetInferenceDiagnosticsAction.NodeResponse.Stats randomCacheStats() { + return new GetInferenceDiagnosticsAction.NodeResponse.Stats( + randomInt(), + randomLongBetween(0, Long.MAX_VALUE), + randomLongBetween(0, Long.MAX_VALUE), + randomLongBetween(0, Long.MAX_VALUE) + ); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java index 9001952965fe1..d0f608d55fc36 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.io.stream.Writeable; import 
org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.XContentBuilder; @@ -42,7 +41,14 @@ public void testToXContent() throws IOException { var eisPoolStats = new PoolStats(5, 6, 7, 8); var entity = new GetInferenceDiagnosticsAction.Response( ClusterName.DEFAULT, - List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats, new Cache.Stats(5, 6, 7))), + List.of( + new GetInferenceDiagnosticsAction.NodeResponse( + node, + externalPoolStats, + eisPoolStats, + new GetInferenceDiagnosticsAction.NodeResponse.Stats(5, 6, 7, 8) + ) + ), List.of() ); @@ -70,9 +76,10 @@ public void testToXContent() throws IOException { } }, "inference_endpoint_registry":{ - "cache_hits": 5, - "cache_misses": 6, - "cache_evictions": 7 + "cache_count": 5, + "cache_hits": 6, + "cache_misses": 7, + "cache_evictions": 8 } } }"""))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index 60d541fbd17d8..0b11b4c1c69a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -84,7 +84,21 @@ protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceD transportService.getLocalNode(), managers.externalHttpClientManager().getPoolStats(), managers.eisMtlsHttpClientManager().getPoolStats(), - inferenceEndpointRegistry.cacheEnabled() ? 
inferenceEndpointRegistry.stats() : null + cacheStats() ); } + + private GetInferenceDiagnosticsAction.NodeResponse.Stats cacheStats() { + if (inferenceEndpointRegistry.cacheEnabled()) { + var stats = inferenceEndpointRegistry.stats(); + return new GetInferenceDiagnosticsAction.NodeResponse.Stats( + inferenceEndpointRegistry.cacheCount(), + stats.getHits(), + stats.getMisses(), + stats.getEvictions() + ); + } else { + return null; + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java index be444973b2bd8..69a210df2bbde 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ -59,7 +59,7 @@ */ public class ClearInferenceEndpointCacheAction extends AcknowledgedTransportMasterNodeAction { private static final Logger log = LogManager.getLogger(ClearInferenceEndpointCacheAction.class); - private static final String NAME = "cluster:admin/xpack/inference/clear_inference_endpoint_cache"; + private static final String NAME = "cluster:internal/xpack/inference/clear_inference_endpoint_cache"; public static final ActionType INSTANCE = new ActionType<>(NAME); private static final String TASK_QUEUE_NAME = "inference-endpoint-cache-management"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java index fca179ee9b216..46d93e8d404b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -137,6 +137,10 @@ public Cache.Stats stats() { return cacheEnabled ? cache.stats() : EMPTY; } + public int cacheCount() { + return cacheEnabled ? cache.count() : 0; + } + public boolean cacheEnabled() { return cacheEnabled; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java index 82c51d174a8bd..f5e134c7089cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheActionTests.java @@ -54,6 +54,7 @@ public void testCacheEviction() throws Exception { invokeEndpoint(); var stats = cacheStats(); + assertThat(stats.entryCount(), equalTo(1)); assertThat(stats.hits(), equalTo(0L)); assertThat(stats.misses(), equalTo(1L)); assertThat(stats.evictions(), equalTo(0L)); @@ -64,6 +65,7 @@ public void testCacheEviction() throws Exception { assertBusy(() -> { var nextStats = cacheStats(); + assertThat(nextStats.entryCount(), equalTo(0)); assertThat(nextStats.hits(), equalTo(0L)); assertThat(nextStats.misses(), equalTo(1L)); assertThat(nextStats.evictions(), equalTo(1L)); @@ -71,6 +73,7 @@ public void testCacheEviction() throws Exception { invokeEndpoint(); stats = cacheStats(); + assertThat(stats.entryCount(), equalTo(1)); assertThat(stats.hits(), equalTo(0L)); assertThat(stats.misses(), equalTo(2L)); assertThat(stats.evictions(), equalTo(1L)); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java 
b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 55cbb37e7e109..f2634a19d1068 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,7 +174,7 @@ public class Constants { "cluster:admin/xpack/enrich/get", "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", - "cluster:admin/xpack/inference/clear_inference_endpoint_cache", + "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", "cluster:admin/xpack/inference/update", From 86c366df5fc94b0bba4d7a182660180006e2bc3e Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Wed, 10 Sep 2025 11:28:41 -0400 Subject: [PATCH 10/10] Update javadoc with edge cases --- .../registry/ClearInferenceEndpointCacheAction.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java index 69a210df2bbde..2ca6a8312dbae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ClearInferenceEndpointCacheAction.java @@ -53,9 +53,12 @@ import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE; /** - * Clears the cache in {@link InferenceEndpointRegistry}. 
This uses a master node transport action, even though most requests will originate - * from the master node (when updating and deleting inference endpoints via REST), because there are some edge cases where deletes can come - * from other nodes. This uses the cluster state to broadcast the message to all nodes to clear their cache, which has guaranteed delivery. + * Clears the cache in {@link InferenceEndpointRegistry}. + * This uses the cluster state to broadcast the message to all nodes to clear their cache, which has guaranteed delivery. + * There are some edge cases where deletes can come from any node; for example, ElasticInferenceServiceAuthorizationHandler and + * SemanticTextIndexOptionsIT will delete endpoints on whatever node is handling the request. This must therefore be a master node + * transport action so that the cluster state update can invalidate the cache, even though most requests will originate from the + * master node (e.g. when updating and deleting inference endpoints via REST). */ public class ClearInferenceEndpointCacheAction extends AcknowledgedTransportMasterNodeAction { private static final Logger log = LogManager.getLogger(ClearInferenceEndpointCacheAction.class);