diff --git a/docs/changelog/133861.yaml b/docs/changelog/133861.yaml new file mode 100644 index 0000000000000..3c87fe8edfe00 --- /dev/null +++ b/docs/changelog/133861.yaml @@ -0,0 +1,5 @@ +pr: 133861 +summary: Implementing latency improvements for EIS integration +area: Machine Learning +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index bbabb17549e46..a3fadd5389bb6 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -345,6 +345,7 @@ static TransportVersion def(int id) { public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00); public static final TransportVersion VISIT_PERCENTAGE = def(9_154_0_00); public static final TransportVersion TIME_SERIES_TELEMETRY = def(9_155_0_00); + public static final TransportVersion INFERENCE_API_EIS_DIAGNOSTICS = def(9_156_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index afb59f8d4c843..025efa1689ed4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.nodes.BaseNodeResponse; @@ -115,30 +116,53 @@ public int hashCode() { } public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment { - static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats"; + private static final String EXTERNAL_FIELD = "external"; + private static final String EIS_FIELD = "eis_mtls"; + private static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats"; - private final ConnectionPoolStats connectionPoolStats; + private final ConnectionPoolStats externalConnectionPoolStats; + private final ConnectionPoolStats eisMtlsConnectionPoolStats; - public NodeResponse(DiscoveryNode node, PoolStats poolStats) { + public NodeResponse(DiscoveryNode node, PoolStats poolStats, PoolStats eisPoolStats) { super(node); - connectionPoolStats = ConnectionPoolStats.of(poolStats); + externalConnectionPoolStats = ConnectionPoolStats.of(poolStats); + eisMtlsConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats); } public NodeResponse(StreamInput in) throws IOException { super(in); - connectionPoolStats = new ConnectionPoolStats(in); + externalConnectionPoolStats = new ConnectionPoolStats(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) { + eisMtlsConnectionPoolStats = new ConnectionPoolStats(in); + } else { + eisMtlsConnectionPoolStats = ConnectionPoolStats.EMPTY; + } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - connectionPoolStats.writeTo(out); + externalConnectionPoolStats.writeTo(out); + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) { + eisMtlsConnectionPoolStats.writeTo(out); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params); + builder.startObject(EXTERNAL_FIELD); + { + builder.field(CONNECTION_POOL_STATS_FIELD_NAME, externalConnectionPoolStats, params); + } + builder.endObject(); + + builder.startObject(EIS_FIELD); + { + builder.field(CONNECTION_POOL_STATS_FIELD_NAME, eisMtlsConnectionPoolStats, params); + } + builder.endObject(); return builder; } @@ -147,23 +171,29 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; NodeResponse response = (NodeResponse) o; - return Objects.equals(connectionPoolStats, response.connectionPoolStats); + return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats) + && Objects.equals(eisMtlsConnectionPoolStats, response.eisMtlsConnectionPoolStats); } @Override public int hashCode() { - return Objects.hash(connectionPoolStats); + return Objects.hash(externalConnectionPoolStats, eisMtlsConnectionPoolStats); + } + + ConnectionPoolStats getExternalConnectionPoolStats() { + return externalConnectionPoolStats; } - ConnectionPoolStats getConnectionPoolStats() { - return connectionPoolStats; + ConnectionPoolStats getEisMtlsConnectionPoolStats() { + return eisMtlsConnectionPoolStats; } static class ConnectionPoolStats implements ToXContentObject, Writeable { - static final String LEASED_CONNECTIONS = "leased_connections"; - static final String PENDING_CONNECTIONS = "pending_connections"; - static final String AVAILABLE_CONNECTIONS = "available_connections"; - static final String MAX_CONNECTIONS = "max_connections"; + private static final String LEASED_CONNECTIONS = "leased_connections"; + private static final String PENDING_CONNECTIONS = "pending_connections"; + private static final String AVAILABLE_CONNECTIONS = "available_connections"; + private static final String MAX_CONNECTIONS = "max_connections"; + private static final ConnectionPoolStats EMPTY = new ConnectionPoolStats(0, 0, 0, 0); static ConnectionPoolStats of(PoolStats poolStats) { return new ConnectionPoolStats(poolStats.getLeased(), poolStats.getPending(), poolStats.getAvailable(), poolStats.getMax()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index a21354eb5a73d..3d1cb795e3f3e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -8,22 +8,25 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; -public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractWireSerializingTestCase< +public class GetInferenceDiagnosticsActionNodeResponseTests extends AbstractBWCWireSerializationTestCase< GetInferenceDiagnosticsAction.NodeResponse> { public static GetInferenceDiagnosticsAction.NodeResponse createRandom() { DiscoveryNode node = DiscoveryNodeUtils.create("id"); - var randomPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); + var randomExternalPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); + var randomEisPoolStats = new PoolStats(randomInt(), randomInt(), randomInt(), randomInt()); - return new GetInferenceDiagnosticsAction.NodeResponse(node, randomPoolStats); + return new GetInferenceDiagnosticsAction.NodeResponse(node, randomExternalPoolStats, randomEisPoolStats); } @Override @@ -39,47 +42,61 @@ protected GetInferenceDiagnosticsAction.NodeResponse createTestInstance() { @Override protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInferenceDiagnosticsAction.NodeResponse instance) throws IOException { - var select = randomIntBetween(0, 3); - var connPoolStats = instance.getConnectionPoolStats(); + if (randomBoolean()) { + PoolStats mutatedConnPoolStats = mutatePoolStats(instance.getExternalConnectionPoolStats()); + PoolStats eisPoolStats = copyPoolStats(instance.getEisMtlsConnectionPoolStats()); + return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), mutatedConnPoolStats, eisPoolStats); + } else { + PoolStats connPoolStats = copyPoolStats(instance.getExternalConnectionPoolStats()); + PoolStats mutatedEisPoolStats = mutatePoolStats(instance.getEisMtlsConnectionPoolStats()); + return new GetInferenceDiagnosticsAction.NodeResponse(instance.getNode(), connPoolStats, mutatedEisPoolStats); + } + } + private PoolStats mutatePoolStats(GetInferenceDiagnosticsAction.NodeResponse.ConnectionPoolStats stats) + throws UnsupportedEncodingException { + var select = randomIntBetween(0, 3); return switch (select) { - case 0 -> new GetInferenceDiagnosticsAction.NodeResponse( - instance.getNode(), - new PoolStats( - randomInt(), - connPoolStats.getPendingConnections(), - connPoolStats.getAvailableConnections(), - connPoolStats.getMaxConnections() - ) - ); - case 1 -> new GetInferenceDiagnosticsAction.NodeResponse( - instance.getNode(), - new PoolStats( - connPoolStats.getLeasedConnections(), - randomInt(), - connPoolStats.getAvailableConnections(), - connPoolStats.getMaxConnections() - ) - ); - case 2 -> new GetInferenceDiagnosticsAction.NodeResponse( - instance.getNode(), - new PoolStats( - connPoolStats.getLeasedConnections(), - connPoolStats.getPendingConnections(), - randomInt(), - connPoolStats.getMaxConnections() - ) + case 0 -> new PoolStats(randomInt(), stats.getPendingConnections(), stats.getAvailableConnections(), stats.getMaxConnections()); + case 1 -> new PoolStats(stats.getLeasedConnections(), randomInt(), stats.getAvailableConnections(), stats.getMaxConnections()); + case 2 -> new PoolStats(stats.getLeasedConnections(), stats.getPendingConnections(), randomInt(), stats.getMaxConnections()); + case 3 -> new PoolStats( + stats.getLeasedConnections(), + stats.getPendingConnections(), + stats.getAvailableConnections(), + randomInt() ); - case 3 -> new GetInferenceDiagnosticsAction.NodeResponse( + default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select)); + }; + } + + private PoolStats copyPoolStats(GetInferenceDiagnosticsAction.NodeResponse.ConnectionPoolStats stats) { + return new PoolStats( + stats.getLeasedConnections(), + stats.getPendingConnections(), + stats.getAvailableConnections(), + stats.getMaxConnections() + ); + } + + @Override + protected GetInferenceDiagnosticsAction.NodeResponse mutateInstanceForVersion( + GetInferenceDiagnosticsAction.NodeResponse instance, + TransportVersion version + ) { + if (version.before(TransportVersions.INFERENCE_API_EIS_DIAGNOSTICS)) { + return new GetInferenceDiagnosticsAction.NodeResponse( instance.getNode(), new PoolStats( - connPoolStats.getLeasedConnections(), - connPoolStats.getPendingConnections(), - connPoolStats.getAvailableConnections(), - randomInt() - ) + instance.getExternalConnectionPoolStats().getLeasedConnections(), + instance.getExternalConnectionPoolStats().getPendingConnections(), + instance.getExternalConnectionPoolStats().getAvailableConnections(), + instance.getExternalConnectionPoolStats().getMaxConnections() + ), + new PoolStats(0, 0, 0, 0) ); - default -> throw new UnsupportedEncodingException(Strings.format("Encountered unsupported case %s", select)); - }; + } else { + return instance; + } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java index e3eb42efdc791..726015f2156ad 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionResponseTests.java @@ -11,15 +11,17 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.hamcrest.CoreMatchers; import java.io.IOException; import java.util.List; +import static org.hamcrest.Matchers.is; + public class GetInferenceDiagnosticsActionResponseTests extends AbstractWireSerializingTestCase { public static GetInferenceDiagnosticsAction.Response createRandom() { @@ -33,10 +35,11 @@ public static GetInferenceDiagnosticsAction.Response createRandom() { public void testToXContent() throws IOException { var node = DiscoveryNodeUtils.create("id"); - var poolStats = new PoolStats(1, 2, 3, 4); + var externalPoolStats = new PoolStats(1, 2, 3, 4); + var eisPoolStats = new PoolStats(5, 6, 7, 8); var entity = new GetInferenceDiagnosticsAction.Response( ClusterName.DEFAULT, - List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, poolStats)), + List.of(new GetInferenceDiagnosticsAction.NodeResponse(node, externalPoolStats, eisPoolStats)), List.of() ); @@ -44,9 +47,27 @@ public void testToXContent() throws IOException { entity.toXContent(builder, null); String xContentResult = org.elasticsearch.common.Strings.toString(builder); - assertThat(xContentResult, CoreMatchers.is(""" - {"id":{"connection_pool_stats":{"leased_connections":1,"pending_connections":2,"available_connections":3,""" + """ - "max_connections":4}}}""")); + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "id":{ + "external": { + "connection_pool_stats":{ + "leased_connections":1, + "pending_connections":2, + "available_connections":3, + "max_connections":4 + } + }, + "eis_mtls": { + "connection_pool_stats":{ + "leased_connections":5, + "pending_connections":6, + "available_connections":7, + "max_connections":8 + } + } + } + }"""))); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2f3bb8dbb5136..5e7198d75f4bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -203,6 +203,7 @@ public class InferencePlugin extends Plugin public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; + public static final String INFERENCE_RESPONSE_THREAD_POOL_NAME = "inference_response"; private static final Logger log = LogManager.getLogger(InferencePlugin.class); @@ -374,7 +375,9 @@ public Collection createComponents(PluginServices services) { components.add(serviceRegistry); components.add(modelRegistry.get()); - components.add(httpClientManager); + components.add( + new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager) + ); components.add(inferenceStatsBinding); // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, @@ -497,10 +500,10 @@ protected Settings getSecretsIndexSettings() { @Override public List> getExecutorBuilders(Settings settingsToUse) { - return List.of(inferenceUtilityExecutor(settings)); + return List.of(inferenceUtilityExecutor(), inferenceResponseExecutor()); } - public static ExecutorBuilder inferenceUtilityExecutor(Settings settings) { + private static ExecutorBuilder inferenceUtilityExecutor() { return new ScalingExecutorBuilder( UTILITY_THREAD_POOL_NAME, 0, @@ -511,6 +514,17 @@ public static ExecutorBuilder inferenceUtilityExecutor(Settings settings) { ); } + private static ExecutorBuilder inferenceResponseExecutor() { + return new ScalingExecutorBuilder( + INFERENCE_RESPONSE_THREAD_POOL_NAME, + 0, + 10, + TimeValue.timeValueMinutes(10), + false, + "xpack.inference.inference_response_thread_pool" + ); + } + @Override public List> getSettings() { return List.copyOf(getInferenceSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index cdd322cfe74f3..1ddfd784676f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -31,7 +31,9 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction GetInferenceDiagnosticsAction.NodeResponse, Void> { - private final HttpClientManager httpClientManager; + public record ClientManagers(HttpClientManager externalHttpClientManager, HttpClientManager eisMtlsHttpClientManager) {} + + private final ClientManagers managers; @Inject public TransportGetInferenceDiagnosticsAction( @@ -39,7 +41,7 @@ public TransportGetInferenceDiagnosticsAction( ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, - HttpClientManager httpClientManager + ClientManagers managers ) { super( GetInferenceDiagnosticsAction.NAME, @@ -50,7 +52,7 @@ public TransportGetInferenceDiagnosticsAction( threadPool.executor(ThreadPool.Names.MANAGEMENT) ); - this.httpClientManager = Objects.requireNonNull(httpClientManager); + this.managers = Objects.requireNonNull(managers); } @Override @@ -74,6 +76,10 @@ protected GetInferenceDiagnosticsAction.NodeResponse newNodeResponse(StreamInput @Override protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceDiagnosticsAction.NodeRequest request, Task task) { - return new GetInferenceDiagnosticsAction.NodeResponse(transportService.getLocalNode(), httpClientManager.getPoolStats()); + return new GetInferenceDiagnosticsAction.NodeResponse( + transportService.getLocalNode(), + managers.externalHttpClientManager().getPoolStats(), + managers.eisMtlsHttpClientManager().getPoolStats() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java index 7936e6779c8d5..edee87bf00d7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java @@ -15,6 +15,7 @@ import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager; import org.apache.http.protocol.HttpContext; +import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -30,7 +31,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_RESPONSE_THREAD_POOL_NAME; /** * Provides a wrapper around a {@link CloseableHttpAsyncClient} to move the responses to a separate thread for processing. @@ -72,6 +73,29 @@ private static CloseableHttpAsyncClient createAsyncClient( // so we don't want to support cookies to avoid accidental authentication for unauthorized users clientBuilder.disableCookieManagement(); + /* + TODO When we implement multi-project we should ensure this is ok. A cluster will be authenticated to EIS because it is one mTLS + cert per cluster. So I think we're ok to not need to track the connection state per request. We will need to pass a header + that contains the project id and organization so EIS can determine if the project is authorized or not. + + See https://stackoverflow.com/questions/13034998/httpclient-is-not-re-using-my-connections-keeps-creating-new-ones for a good + explanation of why we disable connection state. + + The relevant part is copied below: + SSL connections established by your applications are likely stateful. That is, the server requested the client to + authenticate with a private certificate, making them security context specific. HttpClient detects that and prevents + those connections from being leased to a caller with a different security context. Effectively HttpClient is playing safe + by forcing a new connection for each request rather than risking leasing persistent SSL connection to the wrong user. + + You can do two things here + + - disable connection state tracking + - make sure all logically related requests share the same context (recommended) + For details see this section of the HttpClient tutorial: + https://hc.apache.org/httpcomponents-client-4.5.x/current/tutorial/html/advanced.html#stateful_conn + */ + clientBuilder.disableConnectionState(); + /* By default, if a keep-alive header is not returned by the server then the connection will be kept alive indefinitely. In this situation the default keep alive strategy will return -1. Since we use a connection eviction thread, @@ -115,18 +139,18 @@ public void send(HttpRequest request, HttpClientContext context, ActionListener< SocketAccess.doPrivileged(() -> client.execute(request.httpRequestBase(), context, new FutureCallback<>() { @Override public void completed(HttpResponse response) { - respondUsingUtilityThread(response, request, listener); + respondUsingResponseThread(response, request, listener); } @Override public void failed(Exception ex) { throttlerManager.warn(logger, format("Request from inference entity id [%s] failed", request.inferenceEntityId()), ex); - failUsingUtilityThread(ex, listener); + failUsingResponseThread(ex, listener); } @Override public void cancelled() { - failUsingUtilityThread( + failUsingResponseThread( new CancellationException(format("Request from inference entity id [%s] was cancelled", request.inferenceEntityId())), listener ); @@ -134,8 +158,8 @@ public void cancelled() { })); } - private void respondUsingUtilityThread(HttpResponse response, HttpRequest request, ActionListener listener) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> { + private void respondUsingResponseThread(HttpResponse response, HttpRequest request, ActionListener listener) { + threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> { try { listener.onResponse(HttpResult.create(settings.getMaxResponseSize(), response)); } catch (Exception e) { @@ -145,12 +169,14 @@ private void respondUsingUtilityThread(HttpResponse response, HttpRequest reques e ); listener.onFailure(e); + } finally { + EntityUtils.consumeQuietly(response.getEntity()); } }); } - private void failUsingUtilityThread(Exception exception, ActionListener listener) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception)); + private void failUsingResponseThread(Exception exception, ActionListener listener) { + threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception)); } public void stream(HttpRequest request, HttpContext context, ActionListener listener) throws IOException { @@ -167,12 +193,12 @@ public void completed(Void response) { @Override public void failed(Exception ex) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex)); + threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex)); } @Override public void cancelled() { - threadPool.executor(UTILITY_THREAD_POOL_NAME) + threadPool.executor(INFERENCE_RESPONSE_THREAD_POOL_NAME) .execute( () -> streamingProcessor.failed( new CancellationException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java index 35e50e557cc83..2c9e85d7b1ae0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java @@ -18,7 +18,7 @@ public class RetrySettings { static final Setting RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting( "xpack.inference.http.retry.initial_delay", - TimeValue.timeValueSeconds(1), + TimeValue.timeValueMillis(100), Setting.Property.NodeScope, Setting.Property.Dynamic ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 55b59e3fd1d9f..81d249add0262 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -252,16 +252,20 @@ private void cleanup() { private void handleTasks() { try { - if (shutdown.get()) { - logger.debug("Shutdown requested while handling tasks, cleaning up"); - cleanup(); - return; - } + TimeValue timeToWait; + do { + if (shutdown.get()) { + logger.debug("Shutdown requested while handling tasks, cleaning up"); + cleanup(); + return; + } - var timeToWait = settings.getTaskPollFrequency(); - for (var endpoint : rateLimitGroupings.values()) { - timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); - } + timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + } + // if we execute a task the timeToWait will be 0 so we'll immediately look for more work + } while (timeToWait.compareTo(TimeValue.ZERO) <= 0); scheduleNextHandleTasks(timeToWait); } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 6e88bb458fa9f..cfd9216e17a07 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -39,6 +39,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_RESPONSE_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; @@ -80,6 +81,14 @@ public static ScalingExecutorBuilder[] inferenceUtilityExecutors() { TimeValue.timeValueMinutes(10), false, "xpack.inference.utility_thread_pool" + ), + new ScalingExecutorBuilder( + INFERENCE_RESPONSE_THREAD_POOL_NAME, + 1, + 4, + TimeValue.timeValueMinutes(10), + false, + "xpack.inference.inference_response_thread_pool" ) }; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 88459133ddc71..3af19bf46c62e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -113,6 +113,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; @@ -147,7 +148,7 @@ public void setUp() throws Exception { super.setUp(); randomInferenceEntityId = randomAlphaOfLength(10); inferenceStats = InferenceStatsTests.mockInferenceStats(); - threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); + threadPool = createThreadPool(inferenceUtilityExecutors()); } @After