diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e215cefde903c..990f7ffbb739a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -366,6 +366,7 @@ static TransportVersion def(int id) { public static final TransportVersion SIMULATE_INGEST_EFFECTIVE_MAPPING = def(9_140_0_00); public static final TransportVersion RESOLVE_INDEX_MODE_ADDED = def(9_141_0_00); public static final TransportVersion DATA_STREAM_WRITE_INDEX_ONLY_SETTINGS = def(9_142_0_00); + public static final TransportVersion INFERENCE_DIAGNOSTICS_FOR_EIS = def(9_143_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java index afb59f8d4c843..0051b733428a5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import org.apache.http.pool.PoolStats; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.nodes.BaseNodeResponse; @@ -116,29 +117,44 @@ public int hashCode() { public static class NodeResponse extends BaseNodeResponse implements ToXContentFragment { static final String CONNECTION_POOL_STATS_FIELD_NAME = "connection_pool_stats"; + static final String EIS_CONNECTION_POOL_STATS_FIELD_NAME = "eis_connection_pool_stats"; - private final ConnectionPoolStats connectionPoolStats; + private final 
ConnectionPoolStats externalConnectionPoolStats; + private final ConnectionPoolStats eisConnectionPoolStats; - public NodeResponse(DiscoveryNode node, PoolStats poolStats) { + public NodeResponse(DiscoveryNode node, PoolStats externalPoolStats, PoolStats eisPoolStats) { super(node); - connectionPoolStats = ConnectionPoolStats.of(poolStats); + externalConnectionPoolStats = ConnectionPoolStats.of(externalPoolStats); + eisConnectionPoolStats = ConnectionPoolStats.of(eisPoolStats); } public NodeResponse(StreamInput in) throws IOException { super(in); - connectionPoolStats = new ConnectionPoolStats(in); + externalConnectionPoolStats = new ConnectionPoolStats(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS)) { + eisConnectionPoolStats = in.readOptionalWriteable(ConnectionPoolStats::new); + } else { + eisConnectionPoolStats = null; // EIS stats are not available in older versions + } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - connectionPoolStats.writeTo(out); + externalConnectionPoolStats.writeTo(out); + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DIAGNOSTICS_FOR_EIS)) { + out.writeOptionalWriteable(eisConnectionPoolStats); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(CONNECTION_POOL_STATS_FIELD_NAME, connectionPoolStats, params); + builder.field(CONNECTION_POOL_STATS_FIELD_NAME, externalConnectionPoolStats, params); + if (eisConnectionPoolStats != null) { + builder.field(EIS_CONNECTION_POOL_STATS_FIELD_NAME, eisConnectionPoolStats, params); + } return builder; } @@ -147,16 +163,21 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; NodeResponse response = (NodeResponse) o; - return Objects.equals(connectionPoolStats, response.connectionPoolStats); + return Objects.equals(externalConnectionPoolStats, response.externalConnectionPoolStats) 
+ && Objects.equals(eisConnectionPoolStats, response.eisConnectionPoolStats); } @Override public int hashCode() { - return Objects.hash(connectionPoolStats); + return Objects.hash(externalConnectionPoolStats, eisConnectionPoolStats); + } + + ConnectionPoolStats getExternalConnectionPoolStats() { + return externalConnectionPoolStats; } - ConnectionPoolStats getConnectionPoolStats() { - return connectionPoolStats; + ConnectionPoolStats getEisConnectionPoolStats() { + return eisConnectionPoolStats; } static class ConnectionPoolStats implements ToXContentObject, Writeable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 64957328d48dd..38a80cbc3bcfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -487,17 +487,31 @@ public static class Response extends ActionResponse implements ChunkedToXContent private final InferenceServiceResults results; private final boolean isStreaming; private final Flow.Publisher publisher; + private final long elapsedTimeMs; + private final long elapsedTimeNanos; public Response(InferenceServiceResults results) { this.results = results; this.isStreaming = false; this.publisher = null; + this.elapsedTimeMs = -1; + this.elapsedTimeNanos = -1; + } + + public Response(InferenceServiceResults results, long elapsedTimeMs, long elapsedTimeNanos) { + this.results = results; + this.isStreaming = false; + this.publisher = null; + this.elapsedTimeMs = elapsedTimeMs; + this.elapsedTimeNanos = elapsedTimeNanos; } public Response(InferenceServiceResults results, Flow.Publisher publisher) { this.results = results; this.isStreaming = true; this.publisher = publisher; + this.elapsedTimeMs = -1; + this.elapsedTimeNanos = 
-1; } public Response(StreamInput in) throws IOException { @@ -511,6 +525,8 @@ public Response(StreamInput in) throws IOException { // streaming isn't supported via Writeable yet this.isStreaming = false; this.publisher = null; + this.elapsedTimeMs = in.readZLong(); + this.elapsedTimeNanos = in.readZLong(); } @SuppressWarnings("deprecation") @@ -586,6 +602,8 @@ public Flow.Publisher publisher() { public void writeTo(StreamOutput out) throws IOException { // streaming isn't supported via Writeable yet out.writeNamedWriteable(results); + out.writeZLong(elapsedTimeMs); + out.writeZLong(elapsedTimeNanos); } @Override @@ -593,6 +611,12 @@ public Iterator toXContentChunked(ToXContent.Params params return Iterators.concat( ChunkedToXContentHelper.startObject(), results.toXContentChunked(params), + ChunkedToXContentHelper.field("elapsed_time_ms", params1 -> Iterators.single((b, p) -> b.value(elapsedTimeMs)), params), + ChunkedToXContentHelper.field( + "elapsed_time_nanos", + params1 -> Iterators.single((b, p) -> b.value(elapsedTimeNanos)), + params + ), ChunkedToXContentHelper.endObject() ); } @@ -602,12 +626,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Response response = (Response) o; - return Objects.equals(results, response.results); + return Objects.equals(results, response.results) + && elapsedTimeMs == response.elapsedTimeMs + && elapsedTimeNanos == response.elapsedTimeNanos; } @Override public int hashCode() { - return Objects.hash(results); + return Objects.hash(results, elapsedTimeMs, elapsedTimeNanos); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java index a21354eb5a73d..3f797ae06b6d3 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInferenceDiagnosticsActionNodeResponseTests.java @@ -40,7 +40,7 @@ protected GetInferenceDiagnosticsAction.NodeResponse createTestInstance() { protected GetInferenceDiagnosticsAction.NodeResponse mutateInstance(GetInferenceDiagnosticsAction.NodeResponse instance) throws IOException { var select = randomIntBetween(0, 3); - var connPoolStats = instance.getConnectionPoolStats(); + var connPoolStats = instance.getExternalConnectionPoolStats(); return switch (select) { case 0 -> new GetInferenceDiagnosticsAction.NodeResponse( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c3ae4f0d9d6d6..5ce9d9939d40d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -201,6 +201,7 @@ public class InferencePlugin extends Plugin public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; + public static final String UTILITY_RESPONSE_THREAD_POOL_NAME = "inference_utility_response"; private static final Logger log = LogManager.getLogger(InferencePlugin.class); @@ -218,6 +219,8 @@ public class InferencePlugin extends Plugin private final SetOnce modelRegistry = new SetOnce<>(); private List inferenceServiceExtensions; + public record ClientManagers(HttpClientManager externalManager, HttpClientManager eisManager) {} + public InferencePlugin(Settings settings) { this.settings = settings; } @@ -301,6 +304,8 @@ public Collection createComponents(PluginServices services) { 
inferenceServiceSettings.getConnectionTtl() ); + var clientManagers = new ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager); + var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory( serviceComponents.get(), elasticInferenceServiceHttpClientManager, @@ -372,7 +377,7 @@ public Collection createComponents(PluginServices services) { components.add(serviceRegistry); components.add(modelRegistry.get()); - components.add(httpClientManager); + components.add(clientManagers); components.add(inferenceStatsBinding); // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, @@ -495,20 +500,31 @@ protected Settings getSecretsIndexSettings() { @Override public List> getExecutorBuilders(Settings settingsToUse) { - return List.of(inferenceUtilityExecutor(settings)); + return List.of(inferenceUtilityExecutor(settings), inferenceResponseUtilityExecutor(settings)); } public static ExecutorBuilder inferenceUtilityExecutor(Settings settings) { return new ScalingExecutorBuilder( UTILITY_THREAD_POOL_NAME, 0, - 10, + 20, TimeValue.timeValueMinutes(10), false, "xpack.inference.utility_thread_pool" ); } + public static ExecutorBuilder inferenceResponseUtilityExecutor(Settings settings) { + return new ScalingExecutorBuilder( + UTILITY_RESPONSE_THREAD_POOL_NAME, + 0, + 20, + TimeValue.timeValueMinutes(10), + false, + "xpack.inference.utility_response_thread_pool" + ); + } + @Override public List> getSettings() { return List.copyOf(getInferenceSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 269e0f27fd461..ee7a3079f6953 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -306,7 +306,7 @@ private void inferOnServiceWithMetrics( listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler)); } else { recordRequestDurationMetrics(model, timer, request, localNodeId, null); - listener.onResponse(new InferenceAction.Response(inferenceResults)); + listener.onResponse(new InferenceAction.Response(inferenceResults, timer.elapsedMillis(), timer.elapsedNanos())); } }, e -> { recordRequestDurationMetrics(model, timer, request, localNodeId, e); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index cdd322cfe74f3..62c6462d5d5a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -18,7 +18,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; import java.util.List; @@ -31,7 +31,7 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction GetInferenceDiagnosticsAction.NodeResponse, Void> { - private final HttpClientManager httpClientManager; + private final InferencePlugin.ClientManagers httpClientManagers; @Inject public TransportGetInferenceDiagnosticsAction( @@ -39,7 +39,7 @@ public TransportGetInferenceDiagnosticsAction( ClusterService clusterService, 
TransportService transportService, ActionFilters actionFilters, - HttpClientManager httpClientManager + InferencePlugin.ClientManagers managers ) { super( GetInferenceDiagnosticsAction.NAME, @@ -50,7 +50,7 @@ public TransportGetInferenceDiagnosticsAction( threadPool.executor(ThreadPool.Names.MANAGEMENT) ); - this.httpClientManager = Objects.requireNonNull(httpClientManager); + this.httpClientManagers = Objects.requireNonNull(managers); } @Override @@ -74,6 +74,10 @@ protected GetInferenceDiagnosticsAction.NodeResponse newNodeResponse(StreamInput @Override protected GetInferenceDiagnosticsAction.NodeResponse nodeOperation(GetInferenceDiagnosticsAction.NodeRequest request, Task task) { - return new GetInferenceDiagnosticsAction.NodeResponse(transportService.getLocalNode(), httpClientManager.getPoolStats()); + return new GetInferenceDiagnosticsAction.NodeResponse( + transportService.getLocalNode(), + httpClientManagers.externalManager().getPoolStats(), + httpClientManagers.eisManager().getPoolStats() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java index 7936e6779c8d5..2730c6642c4d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java @@ -8,13 +8,16 @@ package org.elasticsearch.xpack.inference.external.http; import org.apache.http.HttpResponse; +import org.apache.http.client.UserTokenHandler; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.protocol.HttpClientContext; import org.apache.http.concurrent.FutureCallback; +import org.apache.http.impl.DefaultConnectionReuseStrategy; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import 
org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager; import org.apache.http.protocol.HttpContext; +import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -30,12 +33,14 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_RESPONSE_THREAD_POOL_NAME; /** * Provides a wrapper around a {@link CloseableHttpAsyncClient} to move the responses to a separate thread for processing. */ public class HttpClient implements Closeable { + public static final String USER_TOKEN = "token"; + private static final Logger logger = LogManager.getLogger(HttpClient.class); enum Status { @@ -71,6 +76,15 @@ private static CloseableHttpAsyncClient createAsyncClient( // The apache client will be shared across all connections because it can be expensive to create it // so we don't want to support cookies to avoid accidental authentication for unauthorized users clientBuilder.disableCookieManagement(); + var userTokenHandler = new UserTokenHandler() { + public Object getUserToken(HttpContext context) { + return USER_TOKEN; + } + + }; + clientBuilder.setUserTokenHandler(userTokenHandler); + clientBuilder.setConnectionReuseStrategy(DefaultConnectionReuseStrategy.INSTANCE); + // clientBuilder.disableConnectionState(); /* By default, if a keep-alive header is not returned by the server then the connection will be kept alive @@ -135,7 +149,7 @@ public void cancelled() { } private void respondUsingUtilityThread(HttpResponse response, HttpRequest request, ActionListener listener) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> { + 
threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> { try { listener.onResponse(HttpResult.create(settings.getMaxResponseSize(), response)); } catch (Exception e) { @@ -145,12 +159,14 @@ private void respondUsingUtilityThread(HttpResponse response, HttpRequest reques e ); listener.onFailure(e); + } finally { + EntityUtils.consumeQuietly(response.getEntity()); } }); } private void failUsingUtilityThread(Exception exception, ActionListener listener) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception)); + threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception)); } public void stream(HttpRequest request, HttpContext context, ActionListener listener) throws IOException { @@ -167,12 +183,12 @@ public void completed(Void response) { @Override public void failed(Exception ex) { - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex)); + threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex)); } @Override public void cancelled() { - threadPool.executor(UTILITY_THREAD_POOL_NAME) + threadPool.executor(UTILITY_RESPONSE_THREAD_POOL_NAME) .execute( () -> streamingProcessor.failed( new CancellationException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index ddf19ff0dc96f..d56e8443e3e14 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -48,7 +48,7 @@ public class HttpClientManager implements Closeable { */ public static final Setting MAX_TOTAL_CONNECTIONS = Setting.intSetting( "xpack.inference.http.max_total_connections", - 50, // 
default + 1000, // default 1, // min Setting.Property.NodeScope, Setting.Property.Dynamic @@ -60,7 +60,7 @@ public class HttpClientManager implements Closeable { */ public static final Setting MAX_ROUTE_CONNECTIONS = Setting.intSetting( "xpack.inference.http.max_route_connections", - 20, // default + 1000, // default 1, // min Setting.Property.NodeScope, Setting.Property.Dynamic diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java index 35e50e557cc83..a5f78b50a11bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java @@ -18,7 +18,7 @@ public class RetrySettings { static final Setting RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting( "xpack.inference.http.retry.initial_delay", - TimeValue.timeValueSeconds(1), + TimeValue.timeValueMillis(5), Setting.Property.NodeScope, Setting.Property.Dynamic ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java index b71887ce6018f..561bf13f57f92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java @@ -32,6 +32,7 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.external.http.HttpClient.USER_TOKEN; public class RetryingHttpSender 
implements RequestSender { @@ -212,14 +213,9 @@ public void send( ResponseHandler responseHandler, ActionListener listener ) { - var retrier = new InternalRetrier( - logger, - request, - HttpClientContext.create(), - hasRequestTimedOutFunction, - responseHandler, - listener - ); + var context = HttpClientContext.create(); + context.setUserToken(USER_TOKEN); + var retrier = new InternalRetrier(logger, request, context, hasRequestTimedOutFunction, responseHandler, listener); retrier.run(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index e3fff14bf95d7..cb986abc7b7b0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -252,16 +252,19 @@ private void cleanup() { private void handleTasks() { try { - if (shutdown.get()) { - logger.debug("Shutdown requested while handling tasks, cleaning up"); - cleanup(); - return; - } + TimeValue timeToWait; + do { + if (shutdown.get()) { + logger.debug("Shutdown requested while handling tasks, cleaning up"); + cleanup(); + return; + } - var 
timeToWait = settings.getTaskPollFrequency(); - for (var endpoint : rateLimitGroupings.values()) { - timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); - } + timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + } + } while (timeToWait.compareTo(TimeValue.ZERO) <= 0); scheduleNextHandleTasks(timeToWait); } catch (Exception e) { @@ -449,10 +452,10 @@ public synchronized TimeValue executeEnqueuedTask() { } private TimeValue executeEnqueuedTaskInternal() { - var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); - if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { - return timeBeforeAvailableToken; - } + // var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); + // if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { + // return timeBeforeAvailableToken; + // } var task = queue.poll(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index fe7c4a9395cd1..e2da72f18e56e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -249,15 +249,17 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(false, maybeDefault, listener); + logger.debug("Returning default inference endpoint [{}] with secrets", inferenceEntityId); + return; + } + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { 
// There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } + listener.onFailure(inferenceNotFoundException(inferenceEntityId)); return; } @@ -277,6 +279,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(false, maybeDefault, listener); + return; + } + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { // There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } + listener.onFailure(inferenceNotFoundException(inferenceEntityId)); return; } @@ -428,7 +432,11 @@ private void getDefaultConfig( if (persistDefaultEndpoints) { storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m))); } else { - listener.onResponse(modelToUnparsedModel(m)); + if (m.getSecrets() != null) { + listener.onResponse(modelToUnparsedModelWithSecrets(m)); + } else { + listener.onResponse(modelToUnparsedModel(m)); + } } break; } @@ -922,6 +930,26 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T } } + private static UnparsedModel modelToUnparsedModelWithSecrets(Model model) { + try (XContentBuilder builder = XContentFactory.jsonBuilder(); var secretsBuilder = XContentFactory.jsonBuilder()) { + model.getConfigurations() + .toXContent(builder, new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString()))); + + 
model.getSecrets() + .toXContent( + secretsBuilder, + new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) + ); + + var modelConfigMap = XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); + var modelSecretsMap = XContentHelper.convertToMap(BytesReference.bytes(secretsBuilder), false, builder.contentType()).v2(); + return unparsedModelFromMap(new ModelConfigMap(modelConfigMap, modelSecretsMap)); + + } catch (IOException ex) { + throw new ElasticsearchException("[{}] Error serializing inference endpoint configuration", model.getInferenceEntityId(), ex); + } + } + private static UnparsedModel modelToUnparsedModel(Model model) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { model.getConfigurations() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 62b01c779db33..9cf5b5019d609 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -206,7 +206,11 @@ private static Map initDefaultEndpoints( DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING, NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null, null), + new ElasticInferenceServiceSparseEmbeddingsServiceSettings( + DEFAULT_ELSER_2_MODEL_ID, + null, + new RateLimitSettings(500000) + ), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 0d8bef246b35d..bee1d7843893f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -77,7 +77,7 @@ public class ElasticInferenceServiceSettings { */ public static final Setting CONNECTION_TTL_SETTING = Setting.timeSetting( "xpack.inference.elastic.http.connection_ttl", - TimeValue.timeValueSeconds(60), + TimeValue.timeValueMinutes(5), Setting.Property.NodeScope ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java index ae52955c1d98f..0b1492b54c6d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -69,6 +69,7 @@ public HttpRequestBase createHttpRequestBase() { traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + httpPost.setHeader(HttpHeaders.CONNECTION, "keep-alive"); return httpPost; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index b9e9e34c44736..882f304ace123 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -13,15 +13,18 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -48,6 +51,7 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.openai.request.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -58,7 +62,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import 
java.util.concurrent.TimeUnit;
 
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
@@ -77,6 +83,13 @@ public class OpenAiService extends SenderService {
 
     public static final String NAME = "openai";
 
+    // Inference endpoint id under which the default text embedding configuration is exposed
+    private static final String DEFAULT_EMBEDDING_ID = ".openai_text_embedding";
+    // OpenAI model used by the default text embedding configuration
+    private static final String DEFAULT_EMBEDDING_MODEL_ID = "text-embedding-3-small";
+    // Dimensions advertised by defaultConfigIds() and configured by defaultConfigs(); one constant keeps the two in sync
+    private static final int DEFAULT_EMBEDDING_DIMENSIONS = 1536;
+
     private static final String SERVICE_NAME = "OpenAI";
     // The task types exposed via the _inference/_services API
     private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
@@ -395,6 +408,63 @@ public Set<TaskType> supportedStreamingTasks() {
         return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
     }
 
+    /**
+     * Lists the ids of the default configurations provided by this service.
+     *
+     * @return one entry for {@link #DEFAULT_EMBEDDING_ID}, described as a float text embedding with
+     *         {@link #DEFAULT_EMBEDDING_DIMENSIONS} dimensions and dot product similarity
+     */
+    @Override
+    public List<DefaultConfigId> defaultConfigIds() {
+        return List.of(
+            new DefaultConfigId(
+                DEFAULT_EMBEDDING_ID,
+                MinimalServiceSettings.textEmbedding(
+                    name(),
+                    DEFAULT_EMBEDDING_DIMENSIONS,
+                    SimilarityMeasure.DOT_PRODUCT,
+                    DenseVectorFieldMapper.ElementType.FLOAT
+                ),
+                this
+            )
+        );
+    }
+
+    /**
+     * Passes the default configurations of this service to the listener; the response is delivered
+     * on the calling thread before this method returns.
+     *
+     * @param listener receives a single-element list holding the model for {@link #DEFAULT_EMBEDDING_ID}
+     */
+    @Override
+    public void defaultConfigs(ActionListener<List<Model>> listener) {
+        listener.onResponse(
+            List.of(
+                new OpenAiEmbeddingsModel(
+                    DEFAULT_EMBEDDING_ID,
+                    TaskType.TEXT_EMBEDDING,
+                    NAME,
+                    // NOTE(review): defaultConfigIds() advertises DOT_PRODUCT similarity; confirm that the nulls
+                    // passed here resolve to the same similarity for this model
+                    new OpenAiEmbeddingsServiceSettings(
+                        DEFAULT_EMBEDDING_MODEL_ID,
+                        null,
+                        null,
+                        null,
+                        DEFAULT_EMBEDDING_DIMENSIONS,
+                        null,
+                        false,
+                        new RateLimitSettings(200000, TimeUnit.MINUTES)
+                    ),
+                    new OpenAiEmbeddingsTaskSettings((String) null),
+                    DEFAULT_SETTINGS,
+                    // NOTE(review): "todo" is a placeholder, not a usable API key; replace it before this default is relied on
+                    new DefaultSecretSettings(new SecureString("todo".toCharArray()))
+                )
+            )
+        );
+    }
+
     /**
      * Model was originally defined in task settings, but it should
      * have been part of the service settings.
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
index 6d47334da43ae..705f30902840b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
@@ -60,7 +60,7 @@ public OpenAiEmbeddingsModel(
     }
 
-    // Should only be used directly for testing
-    OpenAiEmbeddingsModel(
+    // Used by tests and by OpenAiService#defaultConfigs, which builds the default embeddings model directly
+    public OpenAiEmbeddingsModel(
         String inferenceEntityId,
         TaskType taskType,
         String service,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java
index d43f4954edb52..13e73cb70ef1a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java
@@ -30,4 +30,14 @@ public static InferenceTimer start(Clock clock) {
     public long elapsedMillis() {
         return Duration.between(startTime(), clock().instant()).toMillis();
     }
+
+    /**
+     * Returns the time elapsed between {@code startTime()} and the clock's current instant, in nanoseconds.
+     *
+     * @return the elapsed time in nanoseconds
+     * @throws ArithmeticException if the elapsed duration does not fit in a {@code long} number of nanoseconds
+     */
+    public long elapsedNanos() {
+        return Duration.between(startTime(), clock().instant()).toNanos();
+    }
 }