Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133861.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133861
summary: Implementing latency improvements for EIS integration
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ static TransportVersion def(int id) {
public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00);
public static final TransportVersion VISIT_PERCENTAGE = def(9_154_0_00);
public static final TransportVersion TIME_SERIES_TELEMETRY = def(9_155_0_00);
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
Expand Down Expand Up @@ -115,30 +116,53 @@ public int hashCode() {
}

public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment {
static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
private static final String EXTERNAL_FIELD = "external";
private static final String EIS_FIELD = "eis_mtls";
private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";

private final ConnectionPoolStats connectionPoolStats;
private final ConnectionPoolStats externalConnectionPoolStats;
private final ConnectionPoolStats eisMtlsConnectionPoolStats;

public NodeResponse(DiscoveryNode node, PoolStats poolStats) {
public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) {
super(node);
connectionPoolStats = ConnectionPoolStats.of(poolStats);
externalConnectionPoolStats = ConnectionPoolStats.of(poolStats);
eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats);
}

/**
 * Deserializes a node response from the given stream.
 * <p>
 * The EIS mTLS pool stats are read only when the stream's transport version is on or after
 * {@code TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS}; otherwise the field is set to
 * {@code ConnectionPoolStats.EMPTY}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public NodeResponse(StreamInput in) throws IOException {
super(in);

connectionPoolStats = new ConnectionPoolStats(in);
externalConnectionPoolStats = new ConnectionPoolStats(in);
// Version gate: only consume the EIS stats from the stream when its version includes them.
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
eisMtlsConnectionPoolStats = new ConnectionPoolStats(in);
} else {
// Nothing to read for this version; use the all-zero placeholder instead.
eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY;
}
}

/**
 * Serializes this node response to the given stream.
 * <p>
 * The EIS mTLS pool stats are written only when the stream's transport version is on or after
 * {@code TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS}, mirroring the version check in the
 * {@code StreamInput} constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
connectionPoolStats.writeTo(out);
externalConnectionPoolStats.writeTo(out);

// Version gate: skip the EIS stats for streams whose version predates them.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
eisMtlsConnectionPoolStats.writeTo(out);
}
}

/**
 * Renders this node's connection pool stats as fields on the given builder.
 *
 * @param builder the builder to write to
 * @param params  passed through to the nested stats objects
 * @return the same builder that was passed in
 * @throws IOException if writing to the builder fails
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params);
// The external pool stats are wrapped in their own object keyed by EXTERNAL_FIELD.
builder.startObject(EXTERNAL_FIELD);
{
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, externalConnectionPoolStats, params);
}
builder.endObject();

// The EIS mTLS pool stats are wrapped in their own object keyed by EIS_FIELD.
builder.startObject(EIS_FIELD);
{
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params);
}
builder.endObject();
return builder;
}

Expand All @@ -147,23 +171,29 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NodeResponse response = (NodeResponse) o;
return Objects.equals(connectionPoolStats, response.connectionPoolStats);
return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats)
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats);
}

@Override
public int hashCode() {
return Objects.hash(connectionPoolStats);
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats);
}

/**
 * Returns the external connection pool stats held by this node response.
 */
ConnectionPoolStats getExternalConnectionPoolStats() {
    return this.externalConnectionPoolStats;
}

ConnectionPoolStats getConnectionPoolStats() {
return connectionPoolStats;
ConnectionPoolStats getEisMtlsConnectionPoolStats() {
return eisMtlsConnectionPoolStats;
}

static class ConnectionPoolStats implements ToXContentObject, Writeable {
static final String LEASED_CONNECTIONS = "leased_connections";
static final String PENDING_CONNECTIONS = "pending_connections";
static final String AVAILABLE_CONNECTIONS = "available_connections";
static final String MAX_CONNECTIONS = "max_connections";
private static final String LEASED_CONNECTIONS = "leased_connections";
private static final String PENDING_CONNECTIONS = "pending_connections";
private static final String AVAILABLE_CONNECTIONS = "available_connections";
private static final String MAX_CONNECTIONS = "max_connections";
private static final ConnectionPoolStats EMPTY = new ConnectionPoolStats(0, 0, 0, 0);

static ConnectionPoolStats of(PoolStats poolStats) {
return new ConnectionPoolStats(poolStats.getLeased(), poolStats.getPending(), poolStats.getAvailable(), poolStats.getMax());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.io.UnsupportedEncodingException;

public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase<
public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase<
GetInferenceDiagnosticsAction.NodeResponse> {
public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
DiscoveryNode node = DiscoveryNodeUtils.create("id");
var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());

return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats);
}

@Override
Expand All @@ -39,47 +42,61 @@ protected GetInferenceDiagnosticsAction.NodeResponse createTestInstance() {
/**
 * Produces a response that differs from {@code instance} in one of its two pool stats.
 * <p>
 * A random boolean decides which side is passed through {@code mutatePoolStats}: the external
 * stats or the EIS mTLS stats. The other side is rebuilt unchanged via {@code copyPoolStats}.
 * The node is taken from {@code instance} in both cases.
 */
@Override
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInferenceDiagnosticsAction.NodeResponse instance)
throws IOException {
var select = randomIntBetween(0, 3);
var connPoolStats = instance.getConnectionPoolStats();
if (randomBoolean()) {
// Mutate the external stats, keep the EIS stats as they are.
PoolStats mutatedConnPoolStats = mutatePoolStats(instance.getExternalConnectionPoolStats());
PoolStats eisPoolStats = copyPoolStats(instance.getEisMtlsConnectionPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), mutatedConnPoolStats, eisPoolStats);
} else {
// Keep the external stats as they are, mutate the EIS stats.
PoolStats connPoolStats = copyPoolStats(instance.getExternalConnectionPoolStats());
PoolStats mutatedEisPoolStats = mutatePoolStats(instance.getEisMtlsConnectionPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats);
}
}

private PoolStats mutatePoolStats(GetInferenceDiagnosticsAction.NodeResponse.ConnectionPoolStats stats)
throws UnsupportedEncodingException {
var select = randomIntBetween(0, 3);
return switch (select) {
case 0 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
new PoolStats(
randomInt(),
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
)
);
case 1 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
new PoolStats(
connPoolStats.getLeasedConnections(),
randomInt(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
)
);
case 2 -> new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
new PoolStats(
connPoolStats.getLeasedConnections(),
connPoolStats.getPendingConnections(),
randomInt(),
connPoolStats.getMaxConnections()
)
case 0 -> new PoolStats(randomInt(), stats.getPendingConnections(), stats.getAvailableConnections(), stats.getMaxConnections());
case 1 -> new PoolStats(stats.getLeasedConnections(), randomInt(), stats.getAvailableConnections(), stats.getMaxConnections());
case 2 -> new PoolStats(stats.getLeasedConnections(), stats.getPendingConnections(), randomInt(), stats.getMaxConnections());
case 3 -> new PoolStats(
stats.getLeasedConnections(),
stats.getPendingConnections(),
stats.getAvailableConnections(),
randomInt()
);
case 3 -> new GetInferenceDiagnosticsAction.NodeResponse(
default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select));
};
}

/**
 * Builds a new {@link PoolStats} carrying the same leased, pending, available and max
 * connection counts as the given stats.
 */
private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.ConnectionPoolStats stats) {
    var leased = stats.getLeasedConnections();
    var pending = stats.getPendingConnections();
    var available = stats.getAvailableConnections();
    var max = stats.getMaxConnections();
    return new PoolStats(leased, pending, available, max);
}

/**
 * Adapts {@code instance} to what is expected after a round trip at the given transport version.
 * <p>
 * For versions before {@code TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS} this returns a
 * response with the same node and external pool stats but with all-zero EIS pool stats.
 * For all other versions the instance is returned unchanged.
 */
@Override
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion(
GetInferenceDiagnosticsAction.NodeResponse instance,
TransportVersion version
) {
if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
return new GetInferenceDiagnosticsAction.NodeResponse(
instance.getNode(),
new PoolStats(
connPoolStats.getLeasedConnections(),
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
randomInt()
)
instance.getExternalConnectionPoolStats().getLeasedConnections(),
instance.getExternalConnectionPoolStats().getPendingConnections(),
instance.getExternalConnectionPoolStats().getAvailableConnections(),
instance.getExternalConnectionPoolStats().getMaxConnections()
),
new PoolStats(0, 0, 0, 0)
);
default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select));
};
} else {
return instance;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.hamcrest.CoreMatchers;

import java.io.IOException;
import java.util.List;

import static org.hamcrest.Matchers.is;

public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase<GetInferenceDiagnosticsAction.Response> {

public static GetInferenceDiagnosticsAction.Response createRandom() {
Expand All @@ -33,20 +35,39 @@ public static GetInferenceDiagnosticsAction.Response createRandom() {

public void testToXContent() throws IOException {
var node = DiscoveryNodeUtils.create("id");
var poolStats = new PoolStats(1, 2, 3, 4);
var externalPoolStats = new PoolStats(1, 2, 3, 4);
var eisPoolStats = new PoolStats(5, 6, 7, 8);
var entity = new GetInferenceDiagnosticsAction.Response(
ClusterName.DEFAULT,
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)),
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)),
List.of()
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = org.elasticsearch.common.Strings.toString(builder);

assertThat(xContentResult, CoreMatchers.is("""
{"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """
"max_connections":4}}}"""));
assertThat(xContentResult, is(XContentHelper.stripWhitespace("""
{
"id":{
"external": {
"connection_pool_stats":{
"leased_connections":1,
"pending_connections":2,
"available_connections":3,
"max_connections":4
}
},
"eis_mtls": {
"connection_pool_stats":{
"leased_connections":5,
"pending_connections":6,
"available_connections":7,
"max_connections":8
}
}
}
}""")));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ public class InferencePlugin extends Plugin

public static final String NAME = "inference";
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
public static final String INFERENCE_RESPONSE_THREAD_POOL_NAME = "inference_response";

private static final Logger log = LogManager.getLogger(InferencePlugin.class);

Expand Down Expand Up @@ -374,7 +375,9 @@ public Collection<?> createComponents(PluginServices services) {

components.add(serviceRegistry);
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
);
components.add(inferenceStatsBinding);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
Expand Down Expand Up @@ -497,10 +500,10 @@ protected Settings getSecretsIndexSettings() {

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
return List.of(inferenceUtilityExecutor(settings));
return List.of(inferenceUtilityExecutor(), inferenceResponseExecutor());
}

public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
private static ExecutorBuilder<?> inferenceUtilityExecutor() {
return new ScalingExecutorBuilder(
UTILITY_THREAD_POOL_NAME,
0,
Expand All @@ -511,6 +514,17 @@ public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
);
}

/**
 * Creates the executor builder for the thread pool named by
 * {@code INFERENCE_RESPONSE_THREAD_POOL_NAME}, using the settings prefix
 * {@code xpack.inference.inference_response_thread_pool}.
 */
private static ExecutorBuilder<?> inferenceResponseExecutor() {
    // NOTE(review): 0 and 10 are presumably the core and max thread counts, and the ten-minute
    // value the idle keep-alive — confirm against the ScalingExecutorBuilder constructor.
    final TimeValue tenMinutes = TimeValue.timeValueMinutes(10);
    final String settingsPrefix = "xpack.inference.inference_response_thread_pool";
    return new ScalingExecutorBuilder(INFERENCE_RESPONSE_THREAD_POOL_NAME, 0, 10, tenMinutes, false, settingsPrefix);
}

@Override
public List<Setting<?>> getSettings() {
return List.copyOf(getInferenceSettings());
Expand Down
Loading