Skip to content

Commit 25075fa

Browse files
committed
[ML] Cache Inference Endpoints
Maintain parsed Inference Endpoints in memory for reuse. Endpoints are cached on first access and expire after write. This removes search pressure during inference, bypassing search requests to system indices for repeated model access. When any endpoint is updated or deleted, the whole cache is invalidated and must be reloaded.

The cache can be configured with three settings:
- `xpack.inference.cache.enabled` enables or disables the cache (default enabled).
- `xpack.inference.cache.weight` controls how many endpoints can live in the cache (default 25).
- `xpack.inference.cache.expiry_time` controls how long endpoints live in the cache, measured from when they are first accessed (default 15 minutes, minimum 1 minute, maximum 1 hour).

Resolve #133135
1 parent ea95e9e commit 25075fa

18 files changed

+893
-103
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference;
99

10-
import org.elasticsearch.common.io.stream.Writeable;
10+
import org.elasticsearch.common.io.stream.NamedWriteable;
1111
import org.elasticsearch.xcontent.ToXContentObject;
1212

13-
public interface SerializableStats extends ToXContentObject, Writeable {
14-
15-
}
13+
public interface SerializableStats extends ToXContentObject, NamedWriteable {}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.apache.http.pool.PoolStats;
11+
import org.elasticsearch.TransportVersion;
1112
import org.elasticsearch.action.ActionType;
1213
import org.elasticsearch.action.FailedNodeException;
1314
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
@@ -18,10 +19,12 @@
1819
import org.elasticsearch.common.io.stream.StreamInput;
1920
import org.elasticsearch.common.io.stream.StreamOutput;
2021
import org.elasticsearch.common.io.stream.Writeable;
22+
import org.elasticsearch.core.Nullable;
2123
import org.elasticsearch.transport.AbstractTransportRequest;
2224
import org.elasticsearch.xcontent.ToXContentFragment;
2325
import org.elasticsearch.xcontent.ToXContentObject;
2426
import org.elasticsearch.xcontent.XContentBuilder;
27+
import org.elasticsearch.xpack.core.inference.SerializableStats;
2528

2629
import java.io.IOException;
2730
import java.util.List;
@@ -116,29 +119,42 @@ public int hashCode() {
116119

117120
public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment {
118121
static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
122+
static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = "inference_endpoint_registry";
119123

120124
private final ConnectionPoolStats connectionPoolStats;
125+
@Nullable
126+
private final SerializableStats inferenceEndpointRegistryStats;
121127

122-
public NodeResponse(DiscoveryNode node, PoolStats poolStats) {
128+
public NodeResponse(DiscoveryNode node, PoolStats poolStats, SerializableStats inferenceEndpointRegistryStats) {
123129
super(node);
124130
connectionPoolStats = ConnectionPoolStats.of(poolStats);
131+
this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats;
125132
}
126133

127134
public NodeResponse(StreamInput in) throws IOException {
128135
super(in);
129136

130137
connectionPoolStats = new ConnectionPoolStats(in);
138+
inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(TransportVersion.current())
139+
? in.readOptionalNamedWriteable(SerializableStats.class)
140+
: null;
131141
}
132142

133143
@Override
134144
public void writeTo(StreamOutput out) throws IOException {
135145
super.writeTo(out);
136146
connectionPoolStats.writeTo(out);
147+
if (out.getTransportVersion().onOrAfter(TransportVersion.current())) {
148+
out.writeOptionalNamedWriteable(inferenceEndpointRegistryStats);
149+
}
137150
}
138151

139152
@Override
140153
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
141154
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params);
155+
if (inferenceEndpointRegistryStats != null) {
156+
builder.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params);
157+
}
142158
return builder;
143159
}
144160

@@ -147,18 +163,23 @@ public boolean equals(Object o) {
147163
if (this == o) return true;
148164
if (o == null || getClass() != o.getClass()) return false;
149165
NodeResponse response = (NodeResponse) o;
150-
return Objects.equals(connectionPoolStats, response.connectionPoolStats);
166+
return Objects.equals(connectionPoolStats, response.connectionPoolStats)
167+
&& Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats);
151168
}
152169

153170
@Override
154171
public int hashCode() {
155-
return Objects.hash(connectionPoolStats);
172+
return Objects.hash(connectionPoolStats, inferenceEndpointRegistryStats);
156173
}
157174

158175
ConnectionPoolStats getConnectionPoolStats() {
159176
return connectionPoolStats;
160177
}
161178

179+
public SerializableStats getInferenceEndpointRegistryStats() {
180+
return inferenceEndpointRegistryStats;
181+
}
182+
162183
static class ConnectionPoolStats implements ToXContentObject, Writeable {
163184
static final String LEASED_CONNECTIONS = "leased_connections";
164185
static final String PENDING_CONNECTIONS = "pending_connections";

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,36 @@
1111
import org.elasticsearch.cluster.node.DiscoveryNode;
1212
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
15+
import org.elasticsearch.common.io.stream.StreamInput;
16+
import org.elasticsearch.common.io.stream.StreamOutput;
1417
import org.elasticsearch.common.io.stream.Writeable;
1518
import org.elasticsearch.test.AbstractWireSerializingTestCase;
19+
import org.elasticsearch.xcontent.XContentBuilder;
20+
import org.elasticsearch.xpack.core.inference.SerializableStats;
1621

1722
import java.io.IOException;
1823
import java.io.UnsupportedEncodingException;
24+
import java.util.List;
1925

2026
public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase<
2127
GetInferenceDiagnosticsAction.NodeResponse> {
2228
public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
2329
DiscoveryNode node = DiscoveryNodeUtils.create("id");
2430
var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
2531

26-
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats);
32+
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats, new TestStats(randomInt()));
33+
}
34+
35+
@Override
36+
protected NamedWriteableRegistry getNamedWriteableRegistry() {
37+
return registryWithTestStats();
38+
}
39+
40+
public static NamedWriteableRegistry registryWithTestStats() {
41+
return new NamedWriteableRegistry(
42+
List.of(new NamedWriteableRegistry.Entry(SerializableStats.class, TestStats.NAME, TestStats::new))
43+
);
2744
}
2845

2946
@Override
@@ -50,7 +67,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
5067
connPoolStats.getPendingConnections(),
5168
connPoolStats.getAvailableConnections(),
5269
connPoolStats.getMaxConnections()
53-
)
70+
),
71+
randomTestStats()
5472
);
5573
case 1 -> new GetInferenceDiagnosticsAction.NodeResponse(
5674
instance.getNode(),
@@ -59,7 +77,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
5977
randomInt(),
6078
connPoolStats.getAvailableConnections(),
6179
connPoolStats.getMaxConnections()
62-
)
80+
),
81+
randomTestStats()
6382
);
6483
case 2 -> new GetInferenceDiagnosticsAction.NodeResponse(
6584
instance.getNode(),
@@ -68,7 +87,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
6887
connPoolStats.getPendingConnections(),
6988
randomInt(),
7089
connPoolStats.getMaxConnections()
71-
)
90+
),
91+
randomTestStats()
7292
);
7393
case 3 -> new GetInferenceDiagnosticsAction.NodeResponse(
7494
instance.getNode(),
@@ -77,9 +97,37 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
7797
connPoolStats.getPendingConnections(),
7898
connPoolStats.getAvailableConnections(),
7999
randomInt()
80-
)
100+
),
101+
randomTestStats()
81102
);
82103
default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select));
83104
};
84105
}
106+
107+
public static SerializableStats randomTestStats() {
108+
return new TestStats(randomInt());
109+
}
110+
111+
public record TestStats(int count) implements SerializableStats {
112+
public static final String NAME = "test_stats";
113+
114+
public TestStats(StreamInput in) throws IOException {
115+
this(in.readInt());
116+
}
117+
118+
@Override
119+
public String getWriteableName() {
120+
return NAME;
121+
}
122+
123+
@Override
124+
public void writeTo(StreamOutput out) throws IOException {
125+
out.writeInt(count);
126+
}
127+
128+
@Override
129+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
130+
return builder.startObject().field("count", count).endObject();
131+
}
132+
}
85133
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.http.pool.PoolStats;
1111
import org.elasticsearch.cluster.ClusterName;
1212
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
13+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1314
import org.elasticsearch.common.io.stream.Writeable;
1415
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1516
import org.elasticsearch.xcontent.XContentBuilder;
@@ -36,7 +37,13 @@ public void testToXContent() throws IOException {
3637
var poolStats = new PoolStats(1, 2, 3, 4);
3738
var entity = new GetInferenceDiagnosticsAction.Response(
3839
ClusterName.DEFAULT,
39-
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)),
40+
List.of(
41+
new GetInferenceDiagnosticsAction.NodeResponse(
42+
node,
43+
poolStats,
44+
new GetInferenceDiagnosticsActionNodeResponseTests.TestStats(5)
45+
)
46+
),
4047
List.of()
4148
);
4249

@@ -46,7 +53,7 @@ public void testToXContent() throws IOException {
4653

4754
assertThat(xContentResult, CoreMatchers.is("""
4855
{"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """
49-
"max_connections":4}}}"""));
56+
"max_connections":4},"inference_endpoint_registry":{"count":5}}}"""));
5057
}
5158

5259
@Override
@@ -67,4 +74,9 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag
6774
List.of()
6875
);
6976
}
77+
78+
@Override
79+
protected NamedWriteableRegistry getNamedWriteableRegistry() {
80+
return GetInferenceDiagnosticsActionNodeResponseTests.registryWithTestStats();
81+
}
7082
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
package org.elasticsearch.xpack.inference;
99

10+
import org.elasticsearch.cluster.AbstractNamedDiffable;
11+
import org.elasticsearch.cluster.NamedDiff;
12+
import org.elasticsearch.cluster.metadata.Metadata;
1013
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1114
import org.elasticsearch.inference.ChunkingSettings;
1215
import org.elasticsearch.inference.EmptySecretSettings;
@@ -17,6 +20,7 @@
1720
import org.elasticsearch.inference.ServiceSettings;
1821
import org.elasticsearch.inference.TaskSettings;
1922
import org.elasticsearch.inference.UnifiedCompletionRequest;
23+
import org.elasticsearch.xpack.core.inference.SerializableStats;
2024
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2125
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
2226
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
@@ -31,6 +35,8 @@
3135
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
3236
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
3337
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
38+
import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction;
39+
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
3440
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionServiceSettings;
3541
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
3642
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
@@ -600,6 +606,31 @@ private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry
600606
ElasticRerankerServiceSettings::new
601607
)
602608
);
609+
namedWriteables.add(
610+
new NamedWriteableRegistry.Entry(
611+
SerializableStats.class,
612+
InferenceEndpointRegistry.Stats.NAME,
613+
InferenceEndpointRegistry.Stats::new
614+
)
615+
);
616+
namedWriteables.add(
617+
new NamedWriteableRegistry.Entry(
618+
Metadata.ProjectCustom.class,
619+
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME,
620+
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::new
621+
)
622+
);
623+
namedWriteables.add(
624+
new NamedWriteableRegistry.Entry(
625+
NamedDiff.class,
626+
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME,
627+
in -> AbstractNamedDiffable.readDiffFrom(
628+
Metadata.ProjectCustom.class,
629+
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME,
630+
in
631+
)
632+
)
633+
);
603634
}
604635

605636
private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@
104104
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
105105
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
106106
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
107+
import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction;
108+
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
107109
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
108110
import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata;
109111
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
@@ -199,6 +201,27 @@ public class InferencePlugin extends Plugin
199201
License.OperationMode.ENTERPRISE
200202
);
201203

204+
public static final Setting<Boolean> INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting(
205+
"xpack.inference.cache.enabled",
206+
true,
207+
Setting.Property.NodeScope,
208+
Setting.Property.Dynamic
209+
);
210+
211+
public static final Setting<Integer> INFERENCE_ENDPOINT_CACHE_WEIGHT = Setting.intSetting(
212+
"xpack.inference.cache.weight",
213+
25,
214+
Setting.Property.NodeScope
215+
);
216+
217+
public static final Setting<TimeValue> INFERENCE_ENDPOINT_CACHE_EXPIRY = Setting.timeSetting(
218+
"xpack.inference.cache.expiry_time",
219+
TimeValue.timeValueMinutes(15),
220+
TimeValue.timeValueMinutes(1),
221+
TimeValue.timeValueHours(1),
222+
Setting.Property.NodeScope
223+
);
224+
202225
public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case";
203226

204227
public static final String NAME = "inference";
@@ -237,7 +260,8 @@ public List<ActionHandler> getActions() {
237260
new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
238261
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
239262
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
240-
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class)
263+
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class),
264+
new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class)
241265
);
242266
}
243267

@@ -389,6 +413,16 @@ public Collection<?> createComponents(PluginServices services) {
389413
// Add binding for interface -> implementation
390414
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
391415

416+
components.add(
417+
new InferenceEndpointRegistry(
418+
services.clusterService(),
419+
settings,
420+
modelRegistry.get(),
421+
serviceRegistry,
422+
services.projectResolver()
423+
)
424+
);
425+
392426
return components;
393427
}
394428

@@ -443,6 +477,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
443477
ModelRegistryMetadata::fromXContent
444478
)
445479
);
480+
namedXContent.add(
481+
new NamedXContentRegistry.Entry(
482+
Metadata.ProjectCustom.class,
483+
new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME),
484+
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent
485+
)
486+
);
446487
return namedXContent;
447488
}
448489

@@ -527,6 +568,9 @@ public static Set<Setting<?>> getInferenceSettings() {
527568
settings.add(SKIP_VALIDATE_AND_START);
528569
settings.add(INDICES_INFERENCE_BATCH_SIZE);
529570
settings.add(INFERENCE_QUERY_TIMEOUT);
571+
settings.add(INFERENCE_ENDPOINT_CACHE_ENABLED);
572+
settings.add(INFERENCE_ENDPOINT_CACHE_EXPIRY);
573+
settings.add(INFERENCE_ENDPOINT_CACHE_WEIGHT);
530574
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
531575
return Collections.unmodifiableSet(settings);
532576
}

0 commit comments

Comments
 (0)