Skip to content

Commit 422db0d

Browse files
authored
[ML] Cache Inference Endpoints (#133860)
Maintain parsed Inference Endpoints in memory for reuse. Endpoints are cached on first access and expire after write. This removes search pressure during inference by bypassing search requests to system indices for repeated model access. When any endpoint is updated or deleted, the whole cache is invalidated and must be reloaded.
The cache can be configured with three settings:
- `xpack.inference.endpoint.cache.enabled` enables or disables the cache (default enabled).
- `xpack.inference.endpoint.cache.weight` controls how many endpoints can live in the cache (default 25).
- `xpack.inference.endpoint.cache.expiry_time` controls how long endpoints live in the cache, measured from when they are first accessed (default 15 minutes, minimum 1 minute, maximum 1 hour).
Resolves #133135
1 parent 500b68a commit 422db0d

File tree

23 files changed

+938
-146
lines changed

23 files changed

+938
-146
lines changed

docs/changelog/133860.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 133860
2+
summary: Cache Inference Endpoints
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 133135

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ static TransportVersion def(int id) {
335335
public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00);
336336
public static final TransportVersion TIME_SERIES_TELEMETRY = def(9_155_0_00);
337337
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);
338+
public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00);
338339

339340
/*
340341
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/SerializableStats.java

Lines changed: 0 additions & 15 deletions
This file was deleted.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.common.io.stream.StreamInput;
2020
import org.elasticsearch.common.io.stream.StreamOutput;
2121
import org.elasticsearch.common.io.stream.Writeable;
22+
import org.elasticsearch.core.Nullable;
2223
import org.elasticsearch.transport.AbstractTransportRequest;
2324
import org.elasticsearch.xcontent.ToXContentFragment;
2425
import org.elasticsearch.xcontent.ToXContentObject;
@@ -28,6 +29,8 @@
2829
import java.util.List;
2930
import java.util.Objects;
3031

32+
import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE;
33+
3134
public class GetInferenceDiagnosticsAction extends ActionType<GetInferenceDiagnosticsAction.Response> {
3235

3336
public static final GetInferenceDiagnosticsAction INSTANCE = new GetInferenceDiagnosticsAction();
@@ -119,14 +122,23 @@ public static class NodeResponse extends BaseNodeResponse implements ToXContentF
119122
private static final String EXTERNAL_FIELD = "external";
120123
private static final String EIS_FIELD = "eis_mtls";
121124
private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
125+
static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = "inference_endpoint_registry";
122126

123127
private final ConnectionPoolStats externalConnectionPoolStats;
124128
private final ConnectionPoolStats eisMtlsConnectionPoolStats;
125-
126-
public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) {
129+
@Nullable
130+
private final Stats inferenceEndpointRegistryStats;
131+
132+
public NodeResponse(
133+
DiscoveryNode node,
134+
PoolStats poolStats,
135+
PoolStats eisPoolStats,
136+
@Nullable Stats inferenceEndpointRegistryStats
137+
) {
127138
super(node);
128139
externalConnectionPoolStats = ConnectionPoolStats.of(poolStats);
129140
eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats);
141+
this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats;
130142
}
131143

132144
public NodeResponse(StreamInput in) throws IOException {
@@ -138,6 +150,9 @@ public NodeResponse(StreamInput in) throws IOException {
138150
} else {
139151
eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY;
140152
}
153+
inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)
154+
? in.readOptionalWriteable(Stats::new)
155+
: null;
141156
}
142157

143158
@Override
@@ -148,6 +163,9 @@ public void writeTo(StreamOutput out) throws IOException {
148163
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
149164
eisMtlsConnectionPoolStats.writeTo(out);
150165
}
166+
if (out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) {
167+
out.writeOptionalWriteable(inferenceEndpointRegistryStats);
168+
}
151169
}
152170

153171
@Override
@@ -163,6 +181,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
163181
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params);
164182
}
165183
builder.endObject();
184+
if (inferenceEndpointRegistryStats != null) {
185+
builder.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params);
186+
}
166187
return builder;
167188
}
168189

@@ -172,12 +193,13 @@ public boolean equals(Object o) {
172193
if (o == null || getClass() != o.getClass()) return false;
173194
NodeResponse response = (NodeResponse) o;
174195
return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats)
175-
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats);
196+
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats)
197+
&& Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats);
176198
}
177199

178200
@Override
179201
public int hashCode() {
180-
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats);
202+
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats, inferenceEndpointRegistryStats);
181203
}
182204

183205
ConnectionPoolStats getExternalConnectionPoolStats() {
@@ -188,6 +210,10 @@ ConnectionPoolStats getEisMtlsConnectionPoolStats() {
188210
return eisMtlsConnectionPoolStats;
189211
}
190212

213+
/**
 * Returns the inference endpoint registry stats carried by this node response.
 *
 * @return the stats, or {@code null} when none were supplied; the backing field is declared {@code @Nullable}
 */
public Stats getInferenceEndpointRegistryStats() {
214+
return inferenceEndpointRegistryStats;
215+
}
216+
191217
static class ConnectionPoolStats implements ToXContentObject, Writeable {
192218
private static final String LEASED_CONNECTIONS = "leased_connections";
193219
private static final String PENDING_CONNECTIONS = "pending_connections";
@@ -270,5 +296,35 @@ int getMaxConnections() {
270296
return maxConnections;
271297
}
272298
}
299+
300+
/**
 * Cache counters reported under the node response's inference endpoint registry section.
 * <p>
 * Serialized on the wire as one vInt followed by three vLongs, and rendered to XContent as an object with the
 * fields {@code cache_count}, {@code cache_hits}, {@code cache_misses} and {@code cache_evictions}.
 *
 * @param entryCount value emitted as {@code cache_count}
 * @param hits       value emitted as {@code cache_hits}
 * @param misses     value emitted as {@code cache_misses}
 * @param evictions  value emitted as {@code cache_evictions}
 */
public record Stats(int entryCount, long hits, long misses, long evictions) implements ToXContentObject, Writeable {
301+
302+
private static final String NUM_OF_CACHE_ENTRIES = "cache_count";
303+
private static final String CACHE_HITS = "cache_hits";
304+
private static final String CACHE_MISSES = "cache_misses";
305+
private static final String CACHE_EVICTIONS = "cache_evictions";
306+
307+
/**
 * Deserializes a {@code Stats} from the stream, reading the fields in the same order {@link #writeTo} writes them.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public Stats(StreamInput in) throws IOException {
308+
this(in.readVInt(), in.readVLong(), in.readVLong(), in.readVLong());
309+
}
310+
311+
/**
 * Writes the four counters to the stream: {@code entryCount} as a vInt, then {@code hits}, {@code misses} and
 * {@code evictions} as vLongs. The order must stay in sync with the {@link StreamInput} constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
312+
public void writeTo(StreamOutput out) throws IOException {
313+
out.writeVInt(entryCount);
314+
out.writeVLong(hits);
315+
out.writeVLong(misses);
316+
out.writeVLong(evictions);
317+
}
318+
319+
/**
 * Renders the counters as a self-contained object (this method opens and closes the object itself).
 *
 * @param builder the builder to write into
 * @param params  rendering parameters; not consulted by this implementation
 * @return the same builder, after the object has been closed
 * @throws IOException if writing to the builder fails
 */
@Override
320+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
321+
return builder.startObject()
322+
.field(NUM_OF_CACHE_ENTRIES, entryCount)
323+
.field(CACHE_HITS, hits)
324+
.field(CACHE_MISSES, misses)
325+
.field(CACHE_EVICTIONS, evictions)
326+
.endObject();
327+
}
328+
}
273329
}
274330
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
2626
var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
2727
var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
2828

29-
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats);
29+
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats, randomCacheStats());
3030
}
3131

3232
@Override
@@ -45,11 +45,16 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
4545
if (randomBoolean()) {
4646
PoolStats mutatedConnPoolStats = mutatePoolStats(instance.getExternalConnectionPoolStats());
4747
PoolStats eisPoolStats = copyPoolStats(instance.getEisMtlsConnectionPoolStats());
48-
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), mutatedConnPoolStats, eisPoolStats);
48+
return new GetInferenceDiagnosticsAction.NodeResponse(
49+
instance.getNode(),
50+
mutatedConnPoolStats,
51+
eisPoolStats,
52+
randomCacheStats()
53+
);
4954
} else {
5055
PoolStats connPoolStats = copyPoolStats(instance.getExternalConnectionPoolStats());
5156
PoolStats mutatedEisPoolStats = mutatePoolStats(instance.getEisMtlsConnectionPoolStats());
52-
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats);
57+
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats, null);
5358
}
5459
}
5560

@@ -79,24 +84,50 @@ private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.Conne
7984
);
8085
}
8186

87+
/**
 * Builds a {@code Stats} instance with a random entry count and random hit, miss and eviction counters.
 *
 * @return a freshly randomized {@code Stats}
 */
private static GetInferenceDiagnosticsAction.NodeResponse.Stats randomCacheStats() {
88+
return new GetInferenceDiagnosticsAction.NodeResponse.Stats(
89+
// NOTE(review): randomInt() is unbounded here, so the entry count may be negative — confirm that is intended
randomInt(),
90+
// the three long counters are restricted to [0, Long.MAX_VALUE];
// presumably because Stats serializes them with writeVLong — TODO confirm
randomLongBetween(0, Long.MAX_VALUE),
91+
randomLongBetween(0, Long.MAX_VALUE),
92+
randomLongBetween(0, Long.MAX_VALUE)
93+
);
94+
}
95+
8296
@Override
8397
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion(
8498
GetInferenceDiagnosticsAction.NodeResponse instance,
8599
TransportVersion version
86100
) {
87-
if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
88-
return new GetInferenceDiagnosticsAction.NodeResponse(
89-
instance.getNode(),
90-
new PoolStats(
91-
instance.getExternalConnectionPoolStats().getLeasedConnections(),
92-
instance.getExternalConnectionPoolStats().getPendingConnections(),
93-
instance.getExternalConnectionPoolStats().getAvailableConnections(),
94-
instance.getExternalConnectionPoolStats().getMaxConnections()
95-
),
96-
new PoolStats(0, 0, 0, 0)
97-
);
98-
} else {
101+
return mutateNodeResponseForVersion(instance, version);
102+
}
103+
104+
/**
 * Produces the node response a test should compare against for the given transport version.
 * <ul>
 *   <li>On or after {@code ML_INFERENCE_ENDPOINT_CACHE}: the instance is returned unchanged.</li>
 *   <li>Before it: the registry stats are replaced with {@code null}; the EIS mTLS pool stats are copied when the
 *       version is on or after {@code INFERENCE_API_EIS_DIAGNOSTICS} and replaced with all-zero stats otherwise.
 *       The external pool stats are always copied.</li>
 * </ul>
 * Public and static so it can be called from outside this test class.
 *
 * @param instance the original node response
 * @param version  the transport version being simulated
 * @return {@code instance} itself, or a new response with the fields adjusted as described above
 */
public static GetInferenceDiagnosticsAction.NodeResponse mutateNodeResponseForVersion(
105+
GetInferenceDiagnosticsAction.NodeResponse instance,
106+
TransportVersion version
107+
) {
108+
// newest gate first: nothing needs adjusting at this version or later
if (version.onOrAfter(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) {
99109
return instance;
100110
}
111+
112+
// NOTE(review): "Mlts" in this local name looks like a typo of "Mtls"
var eisMltsConnectionPoolStats = version.onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)
113+
? new PoolStats(
114+
instance.getEisMtlsConnectionPoolStats().getLeasedConnections(),
115+
instance.getEisMtlsConnectionPoolStats().getPendingConnections(),
116+
instance.getEisMtlsConnectionPoolStats().getAvailableConnections(),
117+
instance.getEisMtlsConnectionPoolStats().getMaxConnections()
118+
)
119+
: new PoolStats(0, 0, 0, 0);
120+
121+
return new GetInferenceDiagnosticsAction.NodeResponse(
122+
instance.getNode(),
123+
new PoolStats(
124+
instance.getExternalConnectionPoolStats().getLeasedConnections(),
125+
instance.getExternalConnectionPoolStats().getPendingConnections(),
126+
instance.getExternalConnectionPoolStats().getAvailableConnections(),
127+
instance.getExternalConnectionPoolStats().getMaxConnections()
128+
),
129+
eisMltsConnectionPoolStats,
130+
// registry stats are dropped for every version that reaches this point
null
131+
);
101132
}
102133
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,23 @@
88
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.apache.http.pool.PoolStats;
11+
import org.elasticsearch.TransportVersion;
1112
import org.elasticsearch.cluster.ClusterName;
1213
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
1314
import org.elasticsearch.common.io.stream.Writeable;
1415
import org.elasticsearch.common.xcontent.XContentHelper;
15-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1616
import org.elasticsearch.xcontent.XContentBuilder;
1717
import org.elasticsearch.xcontent.XContentFactory;
1818
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1920

2021
import java.io.IOException;
2122
import java.util.List;
2223

2324
import static org.hamcrest.Matchers.is;
2425

25-
public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase<GetInferenceDiagnosticsAction.Response> {
26+
public class GetInferenceDiagnosticsActionResponseTests extends AbstractBWCWireSerializationTestCase<
27+
GetInferenceDiagnosticsAction.Response> {
2628

2729
public static GetInferenceDiagnosticsAction.Response createRandom() {
2830
List<GetInferenceDiagnosticsAction.NodeResponse> responses = randomList(
@@ -39,7 +41,14 @@ public void testToXContent() throws IOException {
3941
var eisPoolStats = new PoolStats(5, 6, 7, 8);
4042
var entity = new GetInferenceDiagnosticsAction.Response(
4143
ClusterName.DEFAULT,
42-
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)),
44+
List.of(
45+
new GetInferenceDiagnosticsAction.NodeResponse(
46+
node,
47+
externalPoolStats,
48+
eisPoolStats,
49+
new GetInferenceDiagnosticsAction.NodeResponse.Stats(5, 6, 7, 8)
50+
)
51+
),
4352
List.of()
4453
);
4554

@@ -65,6 +74,12 @@ public void testToXContent() throws IOException {
6574
"available_connections":7,
6675
"max_connections":8
6776
}
77+
},
78+
"inference_endpoint_registry":{
79+
"cache_count": 5,
80+
"cache_hits": 6,
81+
"cache_misses": 7,
82+
"cache_evictions": 8
6883
}
6984
}
7085
}""")));
@@ -88,4 +103,19 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag
88103
List.of()
89104
);
90105
}
106+
107+
/**
 * Builds the response expected for the given transport version by applying
 * {@code GetInferenceDiagnosticsActionNodeResponseTests.mutateNodeResponseForVersion} to every node response.
 * The cluster name and the failures list are carried over unchanged.
 *
 * @param instance the original response
 * @param version  the transport version being simulated
 * @return a new response whose node responses have been adjusted for {@code version}
 */
@Override
108+
protected GetInferenceDiagnosticsAction.Response mutateInstanceForVersion(
109+
GetInferenceDiagnosticsAction.Response instance,
110+
TransportVersion version
111+
) {
112+
return new GetInferenceDiagnosticsAction.Response(
113+
instance.getClusterName(),
114+
instance.getNodes()
115+
.stream()
116+
// per-node version handling is delegated to the node-response test class
.map(nodeResponse -> GetInferenceDiagnosticsActionNodeResponseTests.mutateNodeResponseForVersion(nodeResponse, version))
117+
.toList(),
118+
instance.failures()
119+
);
120+
}
91121
}

0 commit comments

Comments
 (0)