Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/133860.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 133860
summary: Cache Inference Endpoints
area: Machine Learning
type: enhancement
issues:
- 133135
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ static TransportVersion def(int id) {
public static final TransportVersion PROJECT_RESERVED_STATE_MOVE_TO_REGISTRY = def(9_147_0_00);
public static final TransportVersion STREAMS_ENDPOINT_PARAM_RESTRICTIONS = def(9_148_0_00);
public static final TransportVersion RESOLVE_INDEX_MODE_FILTER = def(9_149_0_00);
public static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = def(9_150_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

public interface SerializableStats extends ToXContentObject, Writeable {

}
public interface SerializableStats extends ToXContentObject, NamedWriteable {}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.transport.AbstractTransportRequest;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.SerializableStats;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.ML_INFERENCE_ENDPOINT_CACHE;

public class GetInferenceDiagnosticsAction extends ActionType<GetInferenceDiagnosticsAction.Response> {

public static final GetInferenceDiagnosticsAction INSTANCE = new GetInferenceDiagnosticsAction();
Expand Down Expand Up @@ -116,29 +120,42 @@ public int hashCode() {

public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment {
static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
static final String INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME = "inference_endpoint_registry";

private final ConnectionPoolStats connectionPoolStats;
@Nullable
private final SerializableStats inferenceEndpointRegistryStats;

public NodeResponse(DiscoveryNode node, PoolStats poolStats) {
public NodeResponse(DiscoveryNode node, PoolStats poolStats, SerializableStats inferenceEndpointRegistryStats) {
super(node);
connectionPoolStats = ConnectionPoolStats.of(poolStats);
this.inferenceEndpointRegistryStats = inferenceEndpointRegistryStats;
}

/**
 * Deserializes a node response from the wire.
 * <p>
 * The registry stats field is only present on streams written by nodes at or
 * after {@code ML_INFERENCE_ENDPOINT_CACHE}; responses from older nodes leave
 * it {@code null}.
 */
public NodeResponse(StreamInput in) throws IOException {
    super(in);

    connectionPoolStats = new ConnectionPoolStats(in);
    if (in.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE)) {
        inferenceEndpointRegistryStats = in.readOptionalNamedWriteable(SerializableStats.class);
    } else {
        inferenceEndpointRegistryStats = null;
    }
}

/**
 * Serializes this node response.
 * <p>
 * The registry stats are written only to peers whose transport version is at
 * least {@code ML_INFERENCE_ENDPOINT_CACHE}; older peers never see the field,
 * which keeps the wire format backward-compatible.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    connectionPoolStats.writeTo(out);
    boolean peerSupportsRegistryStats = out.getTransportVersion().onOrAfter(ML_INFERENCE_ENDPOINT_CACHE);
    if (peerSupportsRegistryStats) {
        out.writeOptionalNamedWriteable(inferenceEndpointRegistryStats);
    }
}

/**
 * Renders this node's diagnostics as XContent fields.
 * <p>
 * Connection pool stats are always emitted; the registry stats field is
 * skipped when absent (it is {@code null} for responses deserialized from
 * nodes that predate {@code ML_INFERENCE_ENDPOINT_CACHE}).
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    XContentBuilder result = builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params);
    if (inferenceEndpointRegistryStats != null) {
        result = result.field(INFERENCE_ENDPOINT_REGISTRY_STATS_FIELD_NAME, inferenceEndpointRegistryStats, params);
    }
    return result;
}

Expand All @@ -147,18 +164,23 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NodeResponse response = (NodeResponse) o;
return Objects.equals(connectionPoolStats, response.connectionPoolStats);
return Objects.equals(connectionPoolStats, response.connectionPoolStats)
&& Objects.equals(inferenceEndpointRegistryStats, response.inferenceEndpointRegistryStats);
}

@Override
public int hashCode() {
return Objects.hash(connectionPoolStats);
return Objects.hash(connectionPoolStats, inferenceEndpointRegistryStats);
}

// Package-private accessor; used by same-package serialization tests to rebuild instances.
ConnectionPoolStats getConnectionPoolStats() {
return connectionPoolStats;
}

// Returns the inference endpoint registry stats, or null when the response came
// from a node whose transport version predates ML_INFERENCE_ENDPOINT_CACHE.
public SerializableStats getInferenceEndpointRegistryStats() {
return inferenceEndpointRegistryStats;
}

static class ConnectionPoolStats implements ToXContentObject, Writeable {
static final String LEASED_CONNECTIONS = "leased_connections";
static final String PENDING_CONNECTIONS = "pending_connections";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,41 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.SerializableStats;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.List;

public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase<
public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase<
GetInferenceDiagnosticsAction.NodeResponse> {
public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
DiscoveryNode node = DiscoveryNodeUtils.create("id");
var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());

return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats, new TestStats(randomInt()));
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
// TestStats must be registered so readOptionalNamedWriteable can resolve it by name.
return registryWithTestStats();
}

/**
 * Builds a registry that can deserialize {@code TestStats} as a
 * {@code SerializableStats} named writeable. Shared with the response-level
 * serialization tests.
 */
public static NamedWriteableRegistry registryWithTestStats() {
    var testStatsEntry = new NamedWriteableRegistry.Entry(SerializableStats.class, TestStats.NAME, TestStats::new);
    return new NamedWriteableRegistry(List.of(testStatsEntry));
}

@Override
Expand All @@ -50,7 +69,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
)
),
randomTestStats()
);
case 1 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
Expand All @@ -59,7 +79,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
randomInt(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
)
),
randomTestStats()
);
case 2 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
Expand All @@ -68,7 +89,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
randomInt(),
connPoolStats.getMaxConnections()
)
),
randomTestStats()
);
case 3 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
Expand All @@ -77,9 +99,58 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
randomInt()
)
),
randomTestStats()
);
default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select));
};
}

// Produces a TestStats with a random count; used to mutate instances in serialization tests.
public static SerializableStats randomTestStats() {
return new TestStats(randomInt());
}

/**
 * Minimal {@code SerializableStats} implementation for tests: carries a single
 * int, serializes it via the stream, and renders it as {"count": N}.
 */
public record TestStats(int count) implements SerializableStats {
public static final String NAME = "test_stats";

// Stream constructor required by the NamedWriteableRegistry entry (TestStats::new).
public TestStats(StreamInput in) throws IOException {
this(in.readInt());
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(count);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.startObject().field("count", count).endObject();
}
}

/**
 * Adjusts the expected round-trip result for BWC testing: peers before
 * {@code ML_INFERENCE_ENDPOINT_CACHE} never transmit registry stats, so the
 * field is nulled out while the connection pool stats are rebuilt unchanged.
 */
@Override
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion(
    GetInferenceDiagnosticsAction.NodeResponse instance,
    TransportVersion version
) {
    if (version.onOrAfter(TransportVersions.ML_INFERENCE_ENDPOINT_CACHE)) {
        return instance;
    }
    var connPoolStats = instance.getConnectionPoolStats();
    return new GetInferenceDiagnosticsAction.NodeResponse(
        instance.getNode(),
        new PoolStats(
            connPoolStats.getLeasedConnections(),
            connPoolStats.getPendingConnections(),
            connPoolStats.getAvailableConnections(),
            connPoolStats.getMaxConnections()
        ),
        null
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.http.pool.PoolStats;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -36,7 +37,13 @@ public void testToXContent() throws IOException {
var poolStats = new PoolStats(1, 2, 3, 4);
var entity = new GetInferenceDiagnosticsAction.Response(
ClusterName.DEFAULT,
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)),
List.of(
new GetInferenceDiagnosticsAction.NodeResponse(
node,
poolStats,
new GetInferenceDiagnosticsActionNodeResponseTests.TestStats(5)
)
),
List.of()
);

Expand All @@ -46,7 +53,7 @@ public void testToXContent() throws IOException {

assertThat(xContentResult, CoreMatchers.is("""
{"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """
"max_connections":4}}}"""));
"max_connections":4},"inference_endpoint_registry":{"count":5}}}"""));
}

@Override
Expand All @@ -67,4 +74,9 @@ protected GetInferenceDiagnosticsAction.Response mutateInstance(GetInferenceDiag
List.of()
);
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
// Reuse the node-response test registry so TestStats can be resolved during deserialization.
return GetInferenceDiagnosticsActionNodeResponseTests.registryWithTestStats();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ public void testAttachWithModelId() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

deleteModel(inferenceId);

forceStopMlNodeDeployment(deploymentId);
}

Expand Down Expand Up @@ -225,6 +227,7 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
);

deleteModel(inferenceId);
forceStopMlNodeDeployment(deploymentId);
}

Expand Down Expand Up @@ -266,6 +269,7 @@ public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOExceptio
is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId))
);

deleteModel(inferenceId);
forceStopMlNodeDeployment(deploymentId);
}

Expand Down Expand Up @@ -309,6 +313,8 @@ public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws
)
);

deleteModel(inferenceId);
deleteModel(secondInferenceId);
forceStopMlNodeDeployment(deploymentId);
}

Expand All @@ -331,6 +337,7 @@ public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOExcepti
)
);

deleteModel(inferenceId);
// Force stop will stop the deployment
forceStopMlNodeDeployment(deploymentId);
}
Expand Down Expand Up @@ -358,16 +365,6 @@ private String endpointConfig(String modelId, String deploymentId) {
""", modelId, deploymentId);
}

private String updatedEndpointConfig(int numAllocations) {
return Strings.format("""
{
"service_settings": {
"num_allocations": %d
}
}
""", numAllocations);
}

private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
String endPoint = "/_ml/trained_models/"
+ modelId
Expand Down Expand Up @@ -413,16 +410,6 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
return client().performRequest(request);
}

private Map<String, Object> updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException {
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";

Request request = new Request("POST", endPoint);
request.setJsonEntity(body);
var response = client().performRequest(request);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Loading