Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133861.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133861
summary: Implementing latency improvements for EIS integration
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00);
public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00);
public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00);
public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_153_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
Expand Down Expand Up @@ -115,30 +116,53 @@ public int hashCode() {
}

public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment {
static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
private static final String EXTERNAL_FIELD = "external";
private static final String EIS_FIELD = "eis_mtls";
private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";

private final ConnectionPoolStats connectionPoolStats;
private final ConnectionPoolStats externalConnectionPoolStats;
private final ConnectionPoolStats eisMtlsConnectionPoolStats;

public NodeResponse(DiscoveryNode node, PoolStats poolStats) {
public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) {
super(node);
connectionPoolStats = ConnectionPoolStats.of(poolStats);
externalConnectionPoolStats = ConnectionPoolStats.of(poolStats);
eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats);
}

public NodeResponse(StreamInput in) throws IOException {
super(in);

connectionPoolStats = new ConnectionPoolStats(in);
externalConnectionPoolStats = new ConnectionPoolStats(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
eisMtlsConnectionPoolStats = new ConnectionPoolStats(in);
} else {
eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY;
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
connectionPoolStats.writeTo(out);
externalConnectionPoolStats.writeTo(out);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) {
eisMtlsConnectionPoolStats.writeTo(out);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params);
builder.startObject(EXTERNAL_FIELD);
{
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, externalConnectionPoolStats, params);
}
builder.endObject();

builder.startObject(EIS_FIELD);
{
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params);
}
builder.endObject();
return builder;
}

Expand All @@ -147,23 +171,29 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NodeResponse response = (NodeResponse) o;
return Objects.equals(connectionPoolStats, response.connectionPoolStats);
return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats)
&& Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats);
}

@Override
public int hashCode() {
return Objects.hash(connectionPoolStats);
return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats);
}

/**
 * Returns the connection pool statistics held in {@code externalConnectionPoolStats},
 * i.e. the stats this response serializes under the {@code "external"} object.
 *
 * @return the external connection pool stats of this node response
 */
ConnectionPoolStats getExternalConnectionPoolStats() {
return externalConnectionPoolStats;
}

ConnectionPoolStats getConnectionPoolStats() {
return connectionPoolStats;
ConnectionPoolStats getEisMtlsConnectionPoolStats() {
return eisMtlsConnectionPoolStats;
}

static class ConnectionPoolStats implements ToXContentObject, Writeable {
static final String LEASED_CONNECTIONS = "leased_connections";
static final String PENDING_CONNECTIONS = "pending_connections";
static final String AVAILABLE_CONNECTIONS = "available_connections";
static final String MAX_CONNECTIONS = "max_connections";
private static final String LEASED_CONNECTIONS = "leased_connections";
private static final String PENDING_CONNECTIONS = "pending_connections";
private static final String AVAILABLE_CONNECTIONS = "available_connections";
private static final String MAX_CONNECTIONS = "max_connections";
private static final ConnectionPoolStats EMPTY = new ConnectionPoolStats(0, 0, 0, 0);

static ConnectionPoolStats of(PoolStats poolStats) {
return new ConnectionPoolStats(poolStats.getLeased(), poolStats.getPending(), poolStats.getAvailable(), poolStats.getMax());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.io.UnsupportedEncodingException;

public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase<
public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase<
GetInferenceDiagnosticsAction.NodeResponse> {
public static GetInferenceDiagnosticsAction.NodeResponse createRandom() {
DiscoveryNode node = DiscoveryNodeUtils.create("id");
var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());
var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt());

return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats);
return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats);
}

@Override
Expand All @@ -40,7 +43,8 @@ protected GetInferenceDiagnosticsAction.NodeResponse createTestInstance() {
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInferenceDiagnosticsAction.NodeResponse instance)
throws IOException {
var select = randomIntBetween(0, 3);
var connPoolStats = instance.getConnectionPoolStats();
var connPoolStats = instance.getExternalConnectionPoolStats();
var eisPoolStats = instance.getEisMtlsConnectionPoolStats();

return switch (select) {
case 0 -> new GetInferenceDiagnosticsAction.NodeResponse(
Expand All @@ -50,6 +54,12 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
),
new PoolStats(
randomInt(),
eisPoolStats.getPendingConnections(),
eisPoolStats.getAvailableConnections(),
eisPoolStats.getMaxConnections()
)
);
case 1 -> new GetInferenceDiagnosticsAction.NodeResponse(
Expand All @@ -59,6 +69,12 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
randomInt(),
connPoolStats.getAvailableConnections(),
connPoolStats.getMaxConnections()
),
new PoolStats(
eisPoolStats.getLeasedConnections(),
randomInt(),
eisPoolStats.getAvailableConnections(),
eisPoolStats.getMaxConnections()
)
);
case 2 -> new GetInferenceDiagnosticsAction.NodeResponse(
Expand All @@ -68,6 +84,12 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
randomInt(),
connPoolStats.getMaxConnections()
),
new PoolStats(
eisPoolStats.getLeasedConnections(),
eisPoolStats.getPendingConnections(),
randomInt(),
eisPoolStats.getMaxConnections()
)
);
case 3 -> new GetInferenceDiagnosticsAction.NodeResponse(
Expand All @@ -77,9 +99,36 @@ protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInference
connPoolStats.getPendingConnections(),
connPoolStats.getAvailableConnections(),
randomInt()
),
new PoolStats(
eisPoolStats.getLeasedConnections(),
eisPoolStats.getPendingConnections(),
eisPoolStats.getAvailableConnections(),
randomInt()
)
);
default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select));
};
}

/**
 * Produces the instance expected after a round trip through the given transport version.
 * <p>
 * For versions before {@code INFERENCE_API_EIS_DIAGNOSTICS} the external pool stats are copied
 * unchanged and the EIS pool stats are replaced with all-zero stats; for that version and later
 * the instance is returned as is.
 *
 * @param instance the response that was serialized
 * @param version  the transport version used for serialization
 * @return the response expected on the receiving side
 */
@Override
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion(
    GetInferenceDiagnosticsAction.NodeResponse instance,
    TransportVersion version
) {
    // Guard clause: nothing to adapt once the version is at or past the EIS diagnostics version.
    if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS) == false) {
        return instance;
    }

    var node = instance.getNode();
    var externalStats = instance.getExternalConnectionPoolStats();
    var externalPoolStats = new PoolStats(
        externalStats.getLeasedConnections(),
        externalStats.getPendingConnections(),
        externalStats.getAvailableConnections(),
        externalStats.getMaxConnections()
    );
    // All-zero stats stand in for the EIS pool on versions that predate it.
    var zeroedEisPoolStats = new PoolStats(0, 0, 0, 0);

    return new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, zeroedEisPoolStats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.hamcrest.CoreMatchers;

import java.io.IOException;
import java.util.List;

import static org.hamcrest.Matchers.is;

public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase<GetInferenceDiagnosticsAction.Response> {

public static GetInferenceDiagnosticsAction.Response createRandom() {
Expand All @@ -33,20 +35,39 @@ public static GetInferenceDiagnosticsAction.Response createRandom() {

public void testToXContent() throws IOException {
var node = DiscoveryNodeUtils.create("id");
var poolStats = new PoolStats(1, 2, 3, 4);
var externalPoolStats = new PoolStats(1, 2, 3, 4);
var eisPoolStats = new PoolStats(5, 6, 7, 8);
var entity = new GetInferenceDiagnosticsAction.Response(
ClusterName.DEFAULT,
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)),
List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)),
List.of()
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = org.elasticsearch.common.Strings.toString(builder);

assertThat(xContentResult, CoreMatchers.is("""
{"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """
"max_connections":4}}}"""));
assertThat(xContentResult, is(XContentHelper.stripWhitespace("""
{
"id":{
"external": {
"connection_pool_stats":{
"leased_connections":1,
"pending_connections":2,
"available_connections":3,
"max_connections":4
}
},
"eis_mtls": {
"connection_pool_stats":{
"leased_connections":5,
"pending_connections":6,
"available_connections":7,
"max_connections":8
}
}
}
}""")));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ public class InferencePlugin extends Plugin

public static final String NAME = "inference";
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
public static final String INFERENCE_RESPONSE_THREAD_POOL_NAME = "inference_response";

private static final Logger log = LogManager.getLogger(InferencePlugin.class);

Expand Down Expand Up @@ -374,7 +375,9 @@ public Collection<?> createComponents(PluginServices services) {

components.add(serviceRegistry);
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
);
components.add(inferenceStatsBinding);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
Expand Down Expand Up @@ -497,10 +500,10 @@ protected Settings getSecretsIndexSettings() {

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
return List.of(inferenceUtilityExecutor(settings));
return List.of(inferenceUtilityExecutor(), inferenceResponseExecutor());
}

public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
private static ExecutorBuilder<?> inferenceUtilityExecutor() {
return new ScalingExecutorBuilder(
UTILITY_THREAD_POOL_NAME,
0,
Expand All @@ -511,6 +514,17 @@ public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
);
}

/**
 * Creates the executor builder registered under {@link #INFERENCE_RESPONSE_THREAD_POOL_NAME}.
 *
 * @return a {@link ScalingExecutorBuilder} configured with the values below
 */
private static ExecutorBuilder<?> inferenceResponseExecutor() {
    // NOTE(review): the local names below assume the ScalingExecutorBuilder constructor order
    // (name, core, max, keepAlive, rejectAfterShutdown, prefix) - confirm against that class.
    final int coreThreads = 0;
    final int maxThreads = 10;
    final TimeValue keepAlive = TimeValue.timeValueMinutes(10);
    final boolean rejectAfterShutdown = false;
    final String settingsPrefix = "xpack.inference.inference_response_thread_pool";

    return new ScalingExecutorBuilder(
        INFERENCE_RESPONSE_THREAD_POOL_NAME,
        coreThreads,
        maxThreads,
        keepAlive,
        rejectAfterShutdown,
        settingsPrefix
    );
}

@Override
public List<Setting<?>> getSettings() {
return List.copyOf(getInferenceSettings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction
GetInferenceDiagnosticsAction.NodeResponse,
Void> {

private final HttpClientManager httpClientManager;
/**
 * Bundles the two {@link HttpClientManager} instances this action reads pool stats from,
 * so they can be handed to the constructor as a single argument.
 *
 * @param externalHttpClientManager manager whose pool stats are reported as the external pool stats
 * @param eisMtlsHttpClientManager  manager whose pool stats are reported as the EIS pool stats
 */
public record ClientManagers(HttpClientManager externalHttpClientManager, HttpClientManager eisMtlsHttpClientManager) {}

private final ClientManagers managers;

@Inject
public TransportGetInferenceDiagnosticsAction(
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
HttpClientManager httpClientManager
ClientManagers managers
) {
super(
GetInferenceDiagnosticsAction.NAME,
Expand All @@ -50,7 +52,7 @@ public TransportGetInferenceDiagnosticsAction(
threadPool.executor(ThreadPool.Names.MANAGEMENT)
);

this.httpClientManager = Objects.requireNonNull(httpClientManager);
this.managers = Objects.requireNonNull(managers);
}

@Override
Expand All @@ -74,6 +76,10 @@ protected GetInferenceDiagnosticsAction.NodeResponse newNodeResponse(StreamInput

@Override
protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceDiagnosticsAction.NodeRequest request, Task task) {
return new GetInferenceDiagnosticsAction.NodeResponse(transportService.getLocalNode(), httpClientManager.getPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(
transportService.getLocalNode(),
managers.externalHttpClientManager().getPoolStats(),
managers.eisMtlsHttpClientManager().getPoolStats()
);
}
}
Loading