Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ static TransportVersion def(int id) {
// Transport version constants; each one is created through def(int) with its own id.
public static final TransportVersion SIMULATE_INGEST_EFFECTIVE_MAPPING = def(9_140_0_00);
public static final TransportVersion RESOLVE_INDEX_MODE_ADDED = def(9_141_0_00);
public static final TransportVersion DATA_STREAM_WRITE_INDEX_ONLY_SETTINGS = def(9_142_0_00);
public static final TransportVersion INFERENCE_DIAGNOSTICS_FOR_EIS = def(9_143_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.apache.http.pool.PoolStats;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
Expand Down Expand Up @@ -116,29 +117,44 @@ public int hashCode() {

public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment {
// Field names under which toXContent emits the pool statistics.
static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats";
static final String EIS_CONNECTION_POOL_STATS_FIELD_NAME = "eis_connection_pool_stats";

// Pool statistics carried by this node response.
// NOTE(review): connectionPoolStats and externalConnectionPoolStats look like the removed and
// added sides of a rename shown together - confirm which declaration is current.
private final ConnectionPoolStats connectionPoolStats;
private final ConnectionPoolStats externalConnectionPoolStats;
// Left null by the StreamInput constructor when the stream predates INFERENCE_DIAGNOSTICS_FOR_EIS.
private final ConnectionPoolStats eisConnectionPoolStats;

/**
 * Builds the response for {@code node}, wrapping each supplied {@link PoolStats} with
 * {@code ConnectionPoolStats.of}.
 *
 * NOTE(review): this span carries two constructor signatures and an assignment to
 * {@code connectionPoolStats} next to the external/EIS assignments; it looks like removed
 * and added diff lines rendered together - confirm which set is current.
 */
public NodeResponse(DiscoveryNode node, PoolStats poolStats) {
public NodeResponse(DiscoveryNode node, PoolStats externalPoolStats, PoolStats eisPoolStats) {
super(node);
connectionPoolStats = ConnectionPoolStats.of(poolStats);
externalConnectionPoolStats = ConnectionPoolStats.of(externalPoolStats);
eisConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats);
}

/**
 * Deserializes a node response: the base node data via {@code super(in)}, then the pool stats.
 * The EIS stats are read only when the stream's transport version is on or after
 * {@code TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS}; otherwise the field is set to
 * {@code null}.
 *
 * NOTE(review): both {@code connectionPoolStats} and {@code externalConnectionPoolStats} are
 * read here; this looks like the removed and added lines of a diff shown together - confirm
 * that only one read belongs to the current code.
 *
 * @param in stream to read the response from
 * @throws IOException declared for the stream reads performed here
 */
public NodeResponse(StreamInput in) throws IOException {
super(in);

connectionPoolStats = new ConnectionPoolStats(in);
externalConnectionPoolStats = new ConnectionPoolStats(in);
// Version gate: the EIS stats exist on the wire only from INFERENCE_DIAGNOSTICS_FOR_EIS onwards.
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS)) {
eisConnectionPoolStats = new ConnectionPoolStats(in);
} else {
eisConnectionPoolStats = null; // EIS stats are not available in older versions
}
}

/**
 * Serializes the response: base node data via {@code super.writeTo(out)}, then the pool stats.
 * The EIS stats are written only when the target stream's transport version is on or after
 * {@code TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS}, mirroring the check in the
 * StreamInput constructor.
 *
 * NOTE(review): the StreamInput constructor sets {@code eisConnectionPoolStats} to null for
 * older streams; if such an instance is then written to a stream at or after that version,
 * the call below dereferences null. Confirm whether that path is reachable; a fix would have
 * to change the reader and the writer together (e.g. an optional encoding).
 * NOTE(review): both {@code connectionPoolStats} and {@code externalConnectionPoolStats} are
 * written; this looks like removed and added diff lines shown together - confirm.
 *
 * @param out stream to write the response to
 * @throws IOException declared for the stream writes performed here
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
connectionPoolStats.writeTo(out);
externalConnectionPoolStats.writeTo(out);

// Must stay in step with the version check in the StreamInput constructor.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS)) {
eisConnectionPoolStats.writeTo(out);
}
}

/**
 * Emits the pool statistics as fields on {@code builder}; no object is opened or closed here.
 * The EIS field is emitted only when {@code eisConnectionPoolStats} is non-null.
 *
 * NOTE(review): two writes to {@code CONNECTION_POOL_STATS_FIELD_NAME} appear below; this
 * looks like the removed and added lines of a diff shown together - confirm only one remains.
 *
 * @param builder builder the fields are added to
 * @param params  passed through to the nested stats objects
 * @return the same {@code builder}
 * @throws IOException declared for the builder writes
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params);
builder.field(CONNECTION_POOL_STATS_FIELD_NAME, externalConnectionPoolStats, params);
// Null means the stats were not present on the stream this response was read from.
if (eisConnectionPoolStats != null) {
builder.field(EIS_CONNECTION_POOL_STATS_FIELD_NAME, eisConnectionPoolStats, params);
}
return builder;
}

Expand All @@ -147,16 +163,21 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NodeResponse response = (NodeResponse) o;
return Objects.equals(connectionPoolStats, response.connectionPoolStats);
return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats)
&& Objects.equals(eisConnectionPoolStats, response.eisConnectionPoolStats);
}

/**
 * Hash over the pool statistics fields.
 *
 * NOTE(review): two return statements appear below (one hashing {@code connectionPoolStats},
 * one hashing the external and EIS stats); this looks like removed and added diff lines shown
 * together - confirm which one is current.
 */
@Override
public int hashCode() {
return Objects.hash(connectionPoolStats);
return Objects.hash(externalConnectionPoolStats, eisConnectionPoolStats);
}

/**
 * Package-private accessor for the statistics stored in {@code externalConnectionPoolStats}.
 *
 * @return the external connection pool statistics held by this response
 */
ConnectionPoolStats getExternalConnectionPoolStats() {
    return this.externalConnectionPoolStats;
}

// Package-private accessors for the stored pool statistics.
// NOTE(review): getConnectionPoolStats has no closing brace of its own before
// getEisConnectionPoolStats begins; this looks like a removed accessor and its replacement
// from a diff shown together - confirm which accessor is current.
ConnectionPoolStats getConnectionPoolStats() {
return connectionPoolStats;
ConnectionPoolStats getEisConnectionPoolStats() {
return eisConnectionPoolStats;
}

static class ConnectionPoolStats implements ToXContentObject, Writeable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,31 @@ public static class Response extends ActionResponse implements ChunkedToXContent
// Inference results carried by this response.
private final InferenceServiceResults results;
// True only for the constructor that takes a publisher.
private final boolean isStreaming;
// Publisher of streamed results; null for the non-streaming constructors.
private final Flow.Publisher<InferenceServiceResults.Result> publisher;
// Elapsed time in milliseconds / nanoseconds; -1 when the response was built without timing.
private final long elapsedTimeMs;
private final long elapsedTimeNanos;

/**
 * Creates a non-streaming response without timing information; both elapsed-time fields
 * are set to -1.
 *
 * @param results the inference results to carry
 */
public Response(InferenceServiceResults results) {
    // Same field values as before, expressed through the timing-aware constructor.
    this(results, -1, -1);
}

/**
 * Creates a non-streaming response that also records the caller-supplied elapsed time.
 *
 * @param results          the inference results to carry
 * @param elapsedTimeMs    elapsed time in milliseconds, stored as given
 * @param elapsedTimeNanos elapsed time in nanoseconds, stored as given
 */
public Response(InferenceServiceResults results, long elapsedTimeMs, long elapsedTimeNanos) {
    // Timing values are stored untouched.
    this.elapsedTimeNanos = elapsedTimeNanos;
    this.elapsedTimeMs = elapsedTimeMs;
    // No publisher: this constructor never produces a streaming response.
    this.publisher = null;
    this.isStreaming = false;
    this.results = results;
}

/**
 * Creates a streaming response backed by {@code publisher}; both elapsed-time fields are
 * set to -1.
 *
 * @param results   the inference results to carry
 * @param publisher publisher of the streamed result chunks
 */
public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
    // Streaming responses carry no timing, so both elapsed fields hold the -1 placeholder.
    this.elapsedTimeNanos = -1;
    this.elapsedTimeMs = -1;
    this.publisher = publisher;
    this.isStreaming = true;
    this.results = results;
}

public Response(StreamInput in) throws IOException {
Expand All @@ -511,6 +525,8 @@ public Response(StreamInput in) throws IOException {
// streaming isn't supported via Writeable yet
this.isStreaming = false;
this.publisher = null;
this.elapsedTimeMs = in.readVLong();
this.elapsedTimeNanos = in.readVLong();
}

@SuppressWarnings("deprecation")
Expand Down Expand Up @@ -586,13 +602,21 @@ public Flow.Publisher<InferenceServiceResults.Result> publisher() {
/**
 * Serializes the results as a named writeable, followed by the elapsed milliseconds and the
 * elapsed nanoseconds, in that order. The publisher and the streaming flag are not written.
 *
 * NOTE(review): the constructors that take no timing store -1 in both elapsed fields, and
 * those values reach writeVLong here. Presumably writeVLong does not accept negative values -
 * confirm; if so a signed encoding is needed, changed together with the StreamInput constructor.
 * NOTE(review): no transport-version check guards the two extra values in this method -
 * confirm that older nodes never receive this format.
 *
 * @param out stream to write the response to
 * @throws IOException declared for the stream writes performed here
 */
public void writeTo(StreamOutput out) throws IOException {
// streaming isn't supported via Writeable yet
out.writeNamedWriteable(results);
out.writeVLong(elapsedTimeMs);
out.writeVLong(elapsedTimeNanos);
}

/**
 * Renders the response as one object: the chunks produced by {@code results}, then the
 * {@code elapsed_time_ms} and {@code elapsed_time_nanos} fields.
 *
 * @param params passed to the results and to the field helpers
 * @return an iterator over the chunks, from start-object to end-object
 */
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
    return Iterators.concat(
        ChunkedToXContentHelper.startObject(),
        results.toXContentChunked(params),
        // Each elapsed-time field is a single-chunk value; the lambdas ignore their params.
        ChunkedToXContentHelper.field(
            "elapsed_time_ms",
            unusedParams -> Iterators.single((builder, ignored) -> builder.value(elapsedTimeMs)),
            params
        ),
        ChunkedToXContentHelper.field(
            "elapsed_time_nanos",
            unusedParams -> Iterators.single((builder, ignored) -> builder.value(elapsedTimeNanos)),
            params
        ),
        ChunkedToXContentHelper.endObject()
    );
}
Expand All @@ -602,12 +626,14 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response response = (Response) o;
return Objects.equals(results, response.results);
return Objects.equals(results, response.results)
&& elapsedTimeMs == response.elapsedTimeMs
&& elapsedTimeNanos == response.elapsedTimeNanos;
}

/**
 * Hash over the response's value fields.
 *
 * NOTE(review): two return statements appear below (one hashing only {@code results}, one
 * also hashing the elapsed-time fields); this looks like removed and added diff lines shown
 * together - confirm which one is current.
 */
@Override
public int hashCode() {
return Objects.hash(results);
return Objects.hash(results, elapsedTimeMs, elapsedTimeNanos);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ protected GetInferenceDiagnosticsAction.NodeResponse createTestInstance() {
protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInferenceDiagnosticsAction.NodeResponse instance)
throws IOException {
var select = randomIntBetween(0, 3);
var connPoolStats = instance.getConnectionPoolStats();
var connPoolStats = instance.getExternalConnectionPoolStats();

return switch (select) {
case 0 -> new GetInferenceDiagnosticsAction.NodeResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ public class InferencePlugin extends Plugin

// Plugin name.
public static final String NAME = "inference";
// Names of the two thread pools built by inferenceUtilityExecutor and
// inferenceResponseUtilityExecutor respectively.
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
public static final String UTILITY_RESPONSE_THREAD_POOL_NAME = "inference_utility_response";

private static final Logger log = LogManager.getLogger(InferencePlugin.class);

Expand All @@ -218,6 +219,8 @@ public class InferencePlugin extends Plugin
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;

/**
 * Bundles the two HTTP client managers so they can be handed around as a single value.
 *
 * @param externalManager the first manager of the pair
 * @param eisManager      the second manager of the pair
 */
public record ClientManagers(
    HttpClientManager externalManager,
    HttpClientManager eisManager
) {}

/**
 * Creates the plugin and stores {@code settings} in the {@code settings} field.
 *
 * @param settings settings kept by the plugin
 */
public InferencePlugin(Settings settings) {
this.settings = settings;
}
Expand Down Expand Up @@ -301,6 +304,8 @@ public Collection<?> createComponents(PluginServices services) {
inferenceServiceSettings.getConnectionTtl()
);

var clientManagers = new ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager);

var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory(
serviceComponents.get(),
elasticInferenceServiceHttpClientManager,
Expand Down Expand Up @@ -372,7 +377,7 @@ public Collection<?> createComponents(PluginServices services) {

components.add(serviceRegistry);
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(clientManagers);
components.add(inferenceStatsBinding);

// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
Expand Down Expand Up @@ -495,20 +500,31 @@ protected Settings getSecretsIndexSettings() {

/**
 * Returns the executor builders for the plugin's thread pools.
 *
 * NOTE(review): {@code settingsToUse} is not referenced; the {@code settings} field is passed
 * to the builder factories instead - confirm this is intended.
 * NOTE(review): two return statements appear below (one builder vs. two builders); this looks
 * like removed and added diff lines shown together - confirm which one is current.
 *
 * @param settingsToUse not read by this method
 * @return an immutable list of executor builders
 */
@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
return List.of(inferenceUtilityExecutor(settings));
return List.of(inferenceUtilityExecutor(settings), inferenceResponseUtilityExecutor(settings));
}

/**
 * Builds the scaling executor registered under {@code UTILITY_THREAD_POOL_NAME}, using a
 * 10 minute time value and the settings prefix {@code xpack.inference.utility_thread_pool}.
 *
 * NOTE(review): both {@code 10} and {@code 20} appear as consecutive arguments; this looks
 * like the removed and added values of a diff shown together - confirm which is current.
 *
 * @param settings not read by this method
 * @return the executor builder for the utility thread pool
 */
public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
return new ScalingExecutorBuilder(
UTILITY_THREAD_POOL_NAME,
0,
10,
20,
TimeValue.timeValueMinutes(10),
false,
"xpack.inference.utility_thread_pool"
);
}

/**
 * Builds the scaling executor registered under {@code UTILITY_RESPONSE_THREAD_POOL_NAME}.
 *
 * @param settings not read by this method
 * @return the executor builder for the response utility thread pool
 */
public static ExecutorBuilder<?> inferenceResponseUtilityExecutor(Settings settings) {
    // Local names follow the presumed ScalingExecutorBuilder parameter order
    // (name, core, max, keep-alive, reject-after-shutdown, settings prefix) - confirm.
    final int coreThreads = 0;
    final int maxThreads = 20;
    final TimeValue keepAlive = TimeValue.timeValueMinutes(10);
    final boolean rejectAfterShutdown = false;
    final String settingsPrefix = "xpack.inference.utility_response_thread_pool";

    return new ScalingExecutorBuilder(
        UTILITY_RESPONSE_THREAD_POOL_NAME,
        coreThreads,
        maxThreads,
        keepAlive,
        rejectAfterShutdown,
        settingsPrefix
    );
}

@Override
public List<Setting<?>> getSettings() {
return List.copyOf(getInferenceSettings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ private void inferOnServiceWithMetrics(
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
} else {
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
listener.onResponse(new InferenceAction.Response(inferenceResults));
listener.onResponse(new InferenceAction.Response(inferenceResults, timer.elapsedMillis(), timer.elapsedNanos()));
}
}, e -> {
recordRequestDurationMetrics(model, timer, request, localNodeId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.InferencePlugin;

import java.io.IOException;
import java.util.List;
Expand All @@ -31,15 +31,15 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction
GetInferenceDiagnosticsAction.NodeResponse,
Void> {

private final HttpClientManager httpClientManager;
private final InferencePlugin.ClientManagers httpClientManagers;

@Inject
public TransportGetInferenceDiagnosticsAction(
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
HttpClientManager httpClientManager
InferencePlugin.ClientManagers managers
) {
super(
GetInferenceDiagnosticsAction.NAME,
Expand All @@ -50,7 +50,7 @@ public TransportGetInferenceDiagnosticsAction(
threadPool.executor(ThreadPool.Names.MANAGEMENT)
);

this.httpClientManager = Objects.requireNonNull(httpClientManager);
this.httpClientManagers = Objects.requireNonNull(managers);
}

@Override
Expand All @@ -74,6 +74,10 @@ protected GetInferenceDiagnosticsAction.NodeResponse newNodeResponse(StreamInput

/**
 * Produces this node's diagnostics response from the local node and the current pool
 * statistics of the HTTP client manager(s).
 *
 * NOTE(review): two return statements appear below (one reading {@code httpClientManager},
 * one reading both managers from {@code httpClientManagers}); this looks like removed and
 * added diff lines shown together - confirm which one is current.
 *
 * @param request not read by this method
 * @param task    not read by this method
 * @return the node-level diagnostics response
 */
@Override
protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceDiagnosticsAction.NodeRequest request, Task task) {
return new GetInferenceDiagnosticsAction.NodeResponse(transportService.getLocalNode(), httpClientManager.getPoolStats());
return new GetInferenceDiagnosticsAction.NodeResponse(
transportService.getLocalNode(),
httpClientManagers.externalManager().getPoolStats(),
httpClientManagers.eisManager().getPoolStats()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
package org.elasticsearch.xpack.inference.external.http;

import org.apache.http.HttpResponse;
import org.apache.http.client.UserTokenHandler;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.impl.DefaultConnectionReuseStrategy;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand All @@ -30,12 +33,14 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_RESPONSE_THREAD_POOL_NAME;

/**
* Provides a wrapper around a {@link CloseableHttpAsyncClient} to move the responses to a separate thread for processing.
*/
public class HttpClient implements Closeable {
// Fixed value returned by the UserTokenHandler that is installed on the client builder.
public static final String USER_TOKEN = "token";

// Class logger.
private static final Logger logger = LogManager.getLogger(HttpClient.class);

enum Status {
Expand Down Expand Up @@ -71,6 +76,15 @@ private static CloseableHttpAsyncClient createAsyncClient(
// The apache client will be shared across all connections because it can be expensive to create it
// so we don't want to support cookies to avoid accidental authentication for unauthorized users
clientBuilder.disableCookieManagement();
var userTokenHandler = new UserTokenHandler() {
public Object getUserToken(HttpContext context) {
return USER_TOKEN;
}

};
clientBuilder.setUserTokenHandler(userTokenHandler);
clientBuilder.setConnectionReuseStrategy(DefaultConnectionReuseStrategy.INSTANCE);
// clientBuilder.disableConnectionState();

/*
By default, if a keep-alive header is not returned by the server then the connection will be kept alive
Expand Down Expand Up @@ -135,7 +149,7 @@ public void cancelled() {
}

private void respondUsingUtilityThread(HttpResponse response, HttpRequest request, ActionListener<HttpResult> listener) {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> {
threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> {
try {
listener.onResponse(HttpResult.create(settings.getMaxResponseSize(), response));
} catch (Exception e) {
Expand All @@ -145,12 +159,14 @@ private void respondUsingUtilityThread(HttpResponse response, HttpRequest reques
e
);
listener.onFailure(e);
} finally {
EntityUtils.consumeQuietly(response.getEntity());
}
});
}

/**
 * Hands {@code exception} to {@code listener.onFailure} on a thread pool executor rather than
 * on the calling thread.
 *
 * NOTE(review): two dispatch statements appear below, one per thread pool name; this looks
 * like removed and added diff lines shown together - confirm which pool is current.
 *
 * @param exception the failure to report
 * @param listener  the listener to notify
 */
private void failUsingUtilityThread(Exception exception, ActionListener<?> listener) {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
}

public void stream(HttpRequest request, HttpContext context, ActionListener<StreamingHttpResult> listener) throws IOException {
Expand All @@ -167,12 +183,12 @@ public void completed(Void response) {

// Failure callback: forwards the exception to streamingProcessor.failed on a thread pool executor.
// NOTE(review): two dispatch statements appear below, one per thread pool name; this looks
// like removed and added diff lines shown together - confirm which pool is current.
@Override
public void failed(Exception ex) {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex));
threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex));
}

@Override
public void cancelled() {
threadPool.executor(UTILITY_THREAD_POOL_NAME)
threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME)
.execute(
() -> streamingProcessor.failed(
new CancellationException(
Expand Down
Loading