Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/133860.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 133860
summary: Cache Inference Endpoints
area: Machine Learning
type: enhancement
issues:
- 133135
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_SAMPLE_OPERATOR_STATUS = def(9_127_0_00);
public static final TransportVersion TIME_SERIES_TELEMETRY = def(9_155_0_00);
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);
public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_157_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import org.elasticsearch.action.support.nodes.BaseNodesResponse;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.transport.AbstractTransportRequest;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -28,6 +30,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE;

public class GetInferenceDiagnosticsAction extends ActionType<GetInferenceDiagnosticsAction.Response> {

public static final GetInferenceDiagnosticsAction INSTANCE = new GetInferenceDiagnosticsAction();
Expand Down Expand Up @@ -119,14 +123,23 @@ public static class NodeResponse extends BaseNodeResponse implements ToXContentF
private static final String EXTERNAL_FIELD = "external";
private static final String EIS_FIELD = "eis_mtls";
private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = "inference_endpoint_registry";

private final ConnectionPoolStats externalConnectionPoolStats;
private final ConnectionPoolStats eisMtlsConnectionPoolStats;

public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) {
@Nullable
private final Stats inferenceEndpointRegistryStats;

/**
 * Creates the diagnostics response for a single node.
 *
 * @param node the node this response belongs to
 * @param poolStats pool statistics stored as the external connection pool stats
 * @param eisPoolStats pool statistics stored as the EIS mTLS connection pool stats
 * @param inferenceEndpointRegistryStats cache statistics for the inference endpoint registry; when {@code null}
 *                                       the response carries no registry stats
 */
public NodeResponse(
    DiscoveryNode node,
    PoolStats poolStats,
    PoolStats eisPoolStats,
    @Nullable Cache.Stats inferenceEndpointRegistryStats
) {
    super(node);
    this.externalConnectionPoolStats = ConnectionPoolStats.of(poolStats);
    this.eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats);
    if (inferenceEndpointRegistryStats == null) {
        // Nothing to convert: keep the field null so it is omitted from the rendered output.
        this.inferenceEndpointRegistryStats = null;
    } else {
        this.inferenceEndpointRegistryStats = Stats.of(inferenceEndpointRegistryStats);
    }
}

public NodeResponse(StreamInput in) throws IOException {
Expand All @@ -138,6 +151,9 @@ public NodeResponse(StreamInput in) throws IOException {
} else {
eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY;
}
inferenceEndpointRegistryStats = in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)
? in.readOptionalWriteable(Stats::new)
: null;
}

@Override
Expand All @@ -148,6 +164,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
eisMtlsConnectionPoolStats.writeTo(out);
}
if (out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) {
out.writeOptionalWriteable(inferenceEndpointRegistryStats);
}
}

@Override
Expand All @@ -163,6 +182,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params);
}
builder.endObject();
if (inferenceEndpointRegistryStats != null) {
builder.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params);
}
return builder;
}

Expand All @@ -172,12 +194,13 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
NodeResponse response = (NodeResponse) o;
return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats)
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats);
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats)
&& Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats);
}

@Override
public int hashCode() {
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats);
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats, inferenceEndpointRegistryStats);
}

ConnectionPoolStats getExternalConnectionPoolStats() {
Expand All @@ -188,6 +211,10 @@ ConnectionPoolStats getEisMtlsConnectionPoolStats() {
return eisMtlsConnectionPoolStats;
}

/**
 * Returns the inference endpoint registry cache statistics carried by this node response.
 *
 * @return the statistics, or {@code null} when none were supplied at construction or when the response was
 *         read from a stream whose transport version predates {@code ML_INFERENCE_ENDPOINT_CACHE}
 */
@Nullable
public Stats getInferenceEndpointRegistryStats() {
    return inferenceEndpointRegistryStats;
}

static class ConnectionPoolStats implements ToXContentObject, Writeable {
private static final String LEASED_CONNECTIONS = "leased_connections";
private static final String PENDING_CONNECTIONS = "pending_connections";
Expand Down Expand Up @@ -270,5 +297,36 @@ int getMaxConnections() {
return maxConnections;
}
}

/**
 * Hit, miss and eviction counters that can be written to a stream and rendered as an XContent object.
 *
 * @param hits number of cache hits
 * @param misses number of cache misses
 * @param evictions number of cache evictions
 */
public record Stats(long hits, long misses, long evictions) implements ToXContentObject, Writeable {

    private static final String CACHE_HITS = "cache_hits";
    private static final String CACHE_MISSES = "cache_misses";
    private static final String CACHE_EVICTIONS = "cache_evictions";

    /**
     * Copies the hit, miss and eviction counts out of the given cache statistics.
     */
    public static Stats of(Cache.Stats cacheStats) {
        final long hitCount = cacheStats.getHits();
        final long missCount = cacheStats.getMisses();
        final long evictionCount = cacheStats.getEvictions();
        return new Stats(hitCount, missCount, evictionCount);
    }

    /**
     * Reads the three counters from the stream in the order hits, misses, evictions.
     */
    public Stats(StreamInput in) throws IOException {
        this(in.readLong(), in.readLong(), in.readLong());
    }

    /**
     * Writes the three counters in the same order the stream constructor reads them.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeLong(hits());
        out.writeLong(misses());
        out.writeLong(evictions());
    }

    /**
     * Renders the counters as a single object with one field per counter.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(CACHE_HITS, hits);
        builder.field(CACHE_MISSES, misses);
        builder.field(CACHE_EVICTIONS, evictions);
        return builder.endObject();
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

Expand All @@ -26,7 +27,7 @@ public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());

return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats, randomCacheStats());
}

@Override
Expand All @@ -45,11 +46,16 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
if (randomBoolean()) {
PoolStats mutatedConnPoolStats = mutatePoolStats(instance.getExternalConnectionPoolStats());
PoolStats eisPoolStats = copyPoolStats(instance.getEisMtlsConnectionPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), mutatedConnPoolStats, eisPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
mutatedConnPoolStats,
eisPoolStats,
randomCacheStats()
);
} else {
PoolStats connPoolStats = copyPoolStats(instance.getExternalConnectionPoolStats());
PoolStats mutatedEisPoolStats = mutatePoolStats(instance.getEisMtlsConnectionPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats, null);
}
}

Expand Down Expand Up @@ -79,24 +85,45 @@ private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.Conne
);
}

/**
 * Builds cache statistics from three random longs, drawn in the order hits, misses, evictions.
 */
private static Cache.Stats randomCacheStats() {
    long hits = randomLong();
    long misses = randomLong();
    long evictions = randomLong();
    return new Cache.Stats(hits, misses, evictions);
}

@Override
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion(
GetInferenceDiagnosticsAction.NodeResponse instance,
TransportVersion version
) {
if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
return new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
new PoolStats(
instance.getExternalConnectionPoolStats().getLeasedConnections(),
instance.getExternalConnectionPoolStats().getPendingConnections(),
instance.getExternalConnectionPoolStats().getAvailableConnections(),
instance.getExternalConnectionPoolStats().getMaxConnections()
),
new PoolStats(0, 0, 0, 0)
);
} else {
return mutateNodeResponseForVersion(instance, version);
}

/**
 * Adapts a node response to what is expected for the given transport version.
 * <p>
 * For versions on or after {@code ML_INFERENCE_ENDPOINT_CACHE} the instance is returned unchanged. For earlier
 * versions a new response is built without registry stats; the EIS mTLS pool stats are additionally replaced by
 * all-zero pool stats when the version is before {@code INFERENCE_API_EIS_DIAGNOSTICS}.
 *
 * @param instance the node response to adapt
 * @param version the transport version to adapt it for
 * @return {@code instance} itself, or a rebuilt response for older versions
 */
public static GetInferenceDiagnosticsAction.NodeResponse mutateNodeResponseForVersion(
    GetInferenceDiagnosticsAction.NodeResponse instance,
    TransportVersion version
) {
    if (version.onOrAfter(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) {
        return instance;
    }

    final PoolStats eisMtlsPoolStats;
    if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
        // Versions before the EIS diagnostics change get all-zero pool stats here.
        eisMtlsPoolStats = new PoolStats(0, 0, 0, 0);
    } else {
        var eisMtls = instance.getEisMtlsConnectionPoolStats();
        eisMtlsPoolStats = new PoolStats(
            eisMtls.getLeasedConnections(),
            eisMtls.getPendingConnections(),
            eisMtls.getAvailableConnections(),
            eisMtls.getMaxConnections()
        );
    }

    var external = instance.getExternalConnectionPoolStats();
    var externalPoolStats = new PoolStats(
        external.getLeasedConnections(),
        external.getPendingConnections(),
        external.getAvailableConnections(),
        external.getMaxConnections()
    );

    // Registry stats are passed as null for every version that reaches this point.
    return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), externalPoolStats, eisMtlsPoolStats, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.util.List;

import static org.hamcrest.Matchers.is;

public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase<GetInferenceDiagnosticsAction.Response> {
public class GetInferenceDiagnosticsActionResponseTests extends AbstractBWCWireSerializationTestCase<
GetInferenceDiagnosticsAction.Response> {

public static GetInferenceDiagnosticsAction.Response createRandom() {
List<GetInferenceDiagnosticsAction.NodeResponse> responses = randomList(
Expand All @@ -39,7 +42,7 @@ public void testToXContent() throws IOException {
var eisPoolStats = new PoolStats(5, 6, 7, 8);
var entity = new GetInferenceDiagnosticsAction.Response(
ClusterName.DEFAULT,
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)),
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats, new Cache.Stats(5, 6, 7))),
List.of()
);

Expand All @@ -65,6 +68,11 @@ public void testToXContent() throws IOException {
"available_connections":7,
"max_connections":8
}
},
"inference_endpoint_registry":{
"cache_hits": 5,
"cache_misses": 6,
"cache_evictions": 7
}
}
}""")));
Expand All @@ -88,4 +96,19 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag
List.of()
);
}

/**
 * Rebuilds the response with every node response adapted to the given transport version, keeping the cluster
 * name and failures of the original instance.
 */
@Override
protected GetInferenceDiagnosticsAction.Response mutateInstanceForVersion(
    GetInferenceDiagnosticsAction.Response instance,
    TransportVersion version
) {
    var clusterName = instance.getClusterName();
    var versionedNodes = instance.getNodes()
        .stream()
        .map(node -> GetInferenceDiagnosticsActionNodeResponseTests.mutateNodeResponseForVersion(node, version))
        .toList();
    return new GetInferenceDiagnosticsAction.Response(clusterName, versionedNodes, instance.failures());
}
}
Loading