diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index 0c551f67cc531..0ce7ea0e7a31b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ContextPreservingActionListener; @@ -42,6 +43,7 @@ public class HttpRequestSender implements Sender { */ public static class Factory { private final HttpRequestSender httpRequestSender; + private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5); public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { Objects.requireNonNull(serviceComponents); @@ -68,7 +70,8 @@ public Factory(ServiceComponents serviceComponents, HttpClientManager httpClient httpClientManager, requestSender, service, - startCompleted + startCompleted, + START_COMPLETED_WAIT_TIME ); } @@ -77,45 +80,95 @@ public Sender createSender() { } } - private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5); + private static final Logger logger = LogManager.getLogger(HttpRequestSender.class); private final ThreadPool threadPool; private final HttpClientManager manager; - private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean startInitiated = new AtomicBoolean(false); + private final AtomicBoolean startCompleted = new AtomicBoolean(false); private final RequestSender requestSender; private final RequestExecutor service; - private final CountDownLatch startCompleted; + private final CountDownLatch startCompletedLatch; + private final TimeValue startCompletedWaitTime; - private HttpRequestSender( + // Visible for testing + protected HttpRequestSender( ThreadPool threadPool, HttpClientManager httpClientManager, RequestSender requestSender, RequestExecutor service, - CountDownLatch startCompleted + CountDownLatch startCompletedLatch, + TimeValue startCompletedWaitTime ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); this.requestSender = Objects.requireNonNull(requestSender); this.service = Objects.requireNonNull(service); - this.startCompleted = Objects.requireNonNull(startCompleted); + this.startCompletedLatch = Objects.requireNonNull(startCompletedLatch); + this.startCompletedWaitTime = Objects.requireNonNull(startCompletedWaitTime); } /** - * Start various internal services. This is required before sending requests. + * Start various internal services asynchronously. This is required before sending requests. */ - public void start() { - if (started.compareAndSet(false, true)) { + @Override + public void startAsynchronously(ActionListener listener) { + if (startInitiated.compareAndSet(false, true)) { + var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> startInternal(preservedListener)); + } else if (startCompleted.get() == false) { + var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); + // wait on another thread so we don't potential block a transport thread + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> waitForStartToCompleteWithListener(preservedListener)); + } else { + listener.onResponse(null); + } + } + + private void startInternal(ActionListener listener) { + try { // The manager must be started before the executor service. That way we guarantee that the http client // is ready prior to the service attempting to use the http client to send a request manager.start(); threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(service::start); waitForStartToComplete(); + startCompleted.set(true); + listener.onResponse(null); + } catch (Exception ex) { + listener.onFailure(ex); + } + } + + private void waitForStartToCompleteWithListener(ActionListener listener) { + try { + waitForStartToComplete(); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); } } + /** + * Start various internal services. This is required before sending requests. + * + * NOTE: This method blocks until the startup is complete. + */ + @Override + public void startSynchronously() { + if (startInitiated.compareAndSet(false, true)) { + ActionListener listener = ActionListener.wrap( + unused -> {}, + exception -> logger.error("Http sender failed to start", exception) + ); + startInternal(listener); + } + // Handle the case where start*() was already called and this would return immediately because the started flag is already true + waitForStartToComplete(); + } + private void waitForStartToComplete() { try { - if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) { + if (startCompletedLatch.await(startCompletedWaitTime.getMillis(), TimeUnit.MILLISECONDS) == false) { throw new IllegalStateException("Http sender startup did not complete in time"); } } catch (InterruptedException e) { @@ -145,7 +198,7 @@ public void send( @Nullable TimeValue timeout, ActionListener listener ) { - assert started.get() : "call start() before sending a request"; + assert startInitiated.get() : "call start() before sending a request"; waitForStartToComplete(); service.execute(requestCreator, inferenceInputs, timeout, listener); } @@ -167,7 +220,7 @@ public void sendWithoutQueuing( @Nullable TimeValue timeout, ActionListener listener ) { - assert started.get() : "call start() before sending a request"; + assert startInitiated.get() : "call start() before sending a request"; waitForStartToComplete(); var preservedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java index 3975a554586b7..8a17becc5fa61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java @@ -18,7 +18,9 @@ import java.io.Closeable; public interface Sender extends Closeable { - void start(); + void startSynchronously(); + + void startAsynchronously(ActionListener listener); void send( RequestManager requestCreator, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index f0b25bd427b69..655f70df89f32 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.IOUtils; @@ -73,10 +74,11 @@ public void infer( @Nullable TimeValue timeout, ActionListener listener ) { - timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); - init(); - var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); - doInfer(model, inferenceInput, taskSettings, timeout, listener); + SubscribableListener.newForked(this::init).andThen((inferListener) -> { + var resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); + doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener); + }).addListener(listener); } private static InferenceInputs createInput( @@ -121,8 +123,9 @@ public void unifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - init(); - doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); + SubscribableListener.newForked(this::init).andThen((completionInferListener) -> { + doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, completionInferListener); + }).addListener(listener); } @Override @@ -135,16 +138,16 @@ public void chunkedInfer( TimeValue timeout, ActionListener> listener ) { - init(); - - ValidationException validationException = new ValidationException(); - validateInputType(inputType, model, validationException); - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } + SubscribableListener.newForked(this::init).>andThen((chunkedInferListener) -> { + ValidationException validationException = new ValidationException(); + validateInputType(inputType, model, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } - // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, input, taskSettings, inputType, timeout, listener); + // a non-null query is not supported and is dropped by all providers + doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener); + }).addListener(listener); } protected abstract void doInfer( @@ -176,8 +179,9 @@ protected abstract void doChunkedInfer( ); public void start(Model model, ActionListener listener) { - init(); - doStart(model, listener); + SubscribableListener.newForked(this::init) + .andThen((doStartListener) -> doStart(model, doStartListener)) + .addListener(listener); } @Override @@ -189,8 +193,8 @@ protected void doStart(Model model, ActionListener listener) { listener.onResponse(true); } - private void init() { - sender.start(); + private void init(ActionListener listener) { + sender.startAsynchronously(listener); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 5e3ba81b7e1bc..7f0d96c24e84b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -75,6 +75,12 @@ import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.providerAllowsTaskType; +/** + * TODO we should remove AmazonBedrockService's dependency on SenderService. Bedrock leverages its own SDK with handles sending requests + * and already implements rate limiting. + * + * https://github.com/elastic/ml-team/issues/1706 + */ public class AmazonBedrockService extends SenderService { public static final String NAME = "amazonbedrock"; private static final String SERVICE_NAME = "Amazon Bedrock"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java index 3f3b0db571bae..f9fc01ff3e9d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java @@ -75,7 +75,7 @@ public Factory( public Sender createSender() { // ensure this is started - bedrockRequestSender.start(); + bedrockRequestSender.startSynchronously(); return bedrockRequestSender; } } @@ -98,7 +98,13 @@ protected AmazonBedrockRequestSender( } @Override - public void start() { + public void startAsynchronously(ActionListener listener) { + + throw new UnsupportedOperationException("not implemented"); + } + + @Override + public void startSynchronously() { if (started.compareAndSet(false, true)) { // The manager must be started before the executor service. That way we guarantee that the http client // is ready prior to the service attempting to use the http client to send a request diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 948a9a2180a36..02800105ef83d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; @@ -82,37 +83,33 @@ public void getAuthorization(ActionListener newListener = ActionListener.wrap(results -> { - if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { - logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); - listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity)); - } else { - var errorMessage = Strings.format( - "%s Received an invalid response type from the Elastic Inference Service: %s", - FAILED_TO_RETRIEVE_MESSAGE, - results.getClass().getSimpleName() - ); - - logger.warn(errorMessage); - listener.onFailure(new ElasticsearchException(errorMessage)); - } - requestCompleteLatch.countDown(); - }, e -> { + var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> { // unwrap because it's likely a retry exception var exception = ExceptionsHelper.unwrapCause(e); logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); - listener.onFailure(e); - requestCompleteLatch.countDown(); + authModelListener.onFailure(e); }); - var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); - var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); + SubscribableListener.newForked(sender::startAsynchronously).andThen((authListener) -> { + var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); + var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); + sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); + }).andThenApply(authResult -> { + if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); + return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity); + } + + var errorMessage = Strings.format( + "%s Received an invalid response type from the Elastic Inference Service: %s", + FAILED_TO_RETRIEVE_MESSAGE, + authResult.getClass().getSimpleName() + ); - sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener); + logger.warn(errorMessage); + throw new ElasticsearchException(errorMessage); + }).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown)); } catch (Exception e) { logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); requestCompleteLatch.countDown(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index f56b54ecc916d..027a19aca6d1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -24,6 +24,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.HttpClient; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -36,9 +38,11 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.EnumSet; import java.util.List; import java.util.Locale; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -61,7 +65,10 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class HttpRequestSenderTests extends ESTestCase { @@ -102,20 +109,145 @@ public void testCreateSender_ReturnsSameRequestExecutorInstance() { } public void testCreateSender_CanCallStartMultipleTimes() throws Exception { - var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + var mockManager = createMockHttpClientManager(); + + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + sender.startSynchronously(); + sender.startSynchronously(); + } + + verify(mockManager, times(1)).start(); + } + + private HttpClientManager createMockHttpClientManager() { + var mockManager = mock(HttpClientManager.class); + when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class)); + + return mockManager; + } + + public void testStart_ThrowsExceptionWaitingForStartToComplete_WhenAnErrorOccurs() throws IOException { + var mockManager = createMockHttpClientManager(); + doThrow(new Error("failed")).when(mockManager).start(); + + var senderFactory = new HttpRequestSender.Factory( + ServiceComponentsTests.createWithEmptySettings(threadPool), + mockManager, + mockClusterServiceEmpty() + ); + + try (var sender = senderFactory.createSender()) { + var exception = expectThrows(Error.class, sender::startSynchronously); + + assertThat(exception.getMessage(), is("failed")); + } + } + + public void testStart_ThrowsExceptionWaitingForStartToComplete() { + var mockManager = createMockHttpClientManager(); + doThrow(new IllegalArgumentException("failed")).when(mockManager).start(); + + // Force the startup to never complete + var latch = new CountDownLatch(1); + var sender = new HttpRequestSender( + threadPool, + mockManager, + mock(RequestSender.class), + mock(RequestExecutor.class), + latch, + // Override the wait time so we don't block the test for too long + TimeValue.timeValueMillis(1) + ); + + var exception = expectThrows(IllegalStateException.class, sender::startSynchronously); + + assertThat(exception.getMessage(), is("Http sender startup did not complete in time")); + } + + public void testStartAsync_WaitsAsyncForStartToComplete_ThrowsWhenItTimesOut_ThenSucceeds() { + var mockManager = createMockHttpClientManager(); + var latch = new CountDownLatch(1); + var sender = new HttpRequestSender( + threadPool, + mockManager, + mock(RequestSender.class), + mock(RequestExecutor.class), + latch, + // Override the wait time so we don't block the test for too long + TimeValue.timeValueMillis(1) + ); + + var listener = new PlainActionFuture(); + sender.startAsynchronously(listener); + + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), is("Http sender startup did not complete in time")); + + // simulate the start completing + latch.countDown(); + + var listenerCompleted = new PlainActionFuture(); + sender.startAsynchronously(listenerCompleted); + assertNull(listenerCompleted.actionGet(TIMEOUT)); + + verify(mockManager, times(1)).start(); + } + + public void testCreateSender_CanCallStartAsyncMultipleTimes() throws Exception { + var mockManager = createMockHttpClientManager(); + var asyncCalls = 3; + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); - sender.start(); - sender.start(); + var listenerList = new ArrayList>(); + + for (int i = 0; i < asyncCalls; i++) { + PlainActionFuture listener = new PlainActionFuture<>(); + listenerList.add(listener); + sender.startAsynchronously(listener); + } + + for (int i = 0; i < asyncCalls; i++) { + PlainActionFuture listener = listenerList.get(i); + assertNull(listener.actionGet(TIMEOUT)); + } } + + verify(mockManager, times(1)).start(); + } + + public void testCreateSender_CanCallStartAsyncAndSyncMultipleTimes() throws Exception { + var mockManager = createMockHttpClientManager(); + var asyncCalls = 3; + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), mockManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + var listenerList = new ArrayList>(); + + for (int i = 0; i < asyncCalls; i++) { + PlainActionFuture listener = new PlainActionFuture<>(); + listenerList.add(listener); + sender.startAsynchronously(listener); + sender.startSynchronously(); + } + + for (int i = 0; i < asyncCalls; i++) { + PlainActionFuture listener = listenerList.get(i); + assertNull(listener.actionGet(TIMEOUT)); + } + } + + verify(mockManager, times(1)).start(); } public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -167,7 +299,7 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce var senderFactory = createSenderFactory(clientManager, threadRef); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -240,7 +372,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception { try (var sender = senderFactory.createSender()) { assertThat(sender, instanceOf(HttpRequestSender.class)); - sender.start(); + sender.startSynchronously(); PlainActionFuture listener = new PlainActionFuture<>(); sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener); @@ -253,8 +385,7 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception { } public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws Exception { - var mockManager = mock(HttpClientManager.class); - when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class)); + var mockManager = createMockHttpClientManager(); var senderFactory = new HttpRequestSender.Factory( ServiceComponentsTests.createWithEmptySettings(threadPool), @@ -263,7 +394,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws ); try (var sender = senderFactory.createSender()) { - sender.start(); + sender.startSynchronously(); PlainActionFuture listener = new PlainActionFuture<>(); sender.send(RequestManagerTests.createMock(), new EmbeddingsInput(List.of(), null), TimeValue.timeValueNanos(1), listener); @@ -276,8 +407,7 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws } public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs() throws Exception { - var mockManager = mock(HttpClientManager.class); - when(mockManager.getHttpClient()).thenReturn(mock(HttpClient.class)); + var mockManager = createMockHttpClientManager(); var senderFactory = new HttpRequestSender.Factory( ServiceComponentsTests.createWithEmptySettings(threadPool), @@ -286,7 +416,7 @@ public void testSendWithoutQueuingWithTimeout_Throws_WhenATimeoutOccurs() throws ); try (var sender = senderFactory.createSender()) { - sender.start(); + sender.startSynchronously(); PlainActionFuture listener = new PlainActionFuture<>(); sender.sendWithoutQueuing( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java index da5cc53cc94f7..b6494bd6d5a2b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 69e5228a927e7..552482c104f97 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -43,6 +43,8 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -65,7 +67,7 @@ public void shutdown() throws IOException { } public void testStart_InitializesTheSender() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -75,7 +77,7 @@ public void testStart_InitializesTheSender() throws IOException { service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); verify(factory, times(1)).createSender(); } @@ -85,7 +87,7 @@ public void testStart_InitializesTheSender() throws IOException { } public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -95,11 +97,12 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); - service.start(mock(Model.class), listener); - listener.actionGet(TIMEOUT); + PlainActionFuture listener2 = new PlainActionFuture<>(); + service.start(mock(Model.class), listener2); + listener2.actionGet(TIMEOUT); verify(factory, times(1)).createSender(); - verify(sender, times(2)).start(); + verify(sender, times(2)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -108,7 +111,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } public void test_nullTimeoutUsesClusterSetting() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); + var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -147,7 +151,7 @@ protected void doInfer( } public void test_providedTimeoutPropagateProperly() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -185,6 +189,17 @@ protected void doInfer( } } + public static Sender createMockSender() { + var sender = mock(Sender.class); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(0); + listener.onResponse(null); + return Void.TYPE; + }).when(sender).startAsynchronously(any()); + + return sender; + } + private static class TestSenderService extends SenderService { TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { super(factory, serviceComponents, clusterService); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/action/Ai21ActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/action/Ai21ActionCreatorTests.java index b985b53dce02a..a17ae01a2741d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/action/Ai21ActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/action/Ai21ActionCreatorTests.java @@ -73,7 +73,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -117,7 +117,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 7636c16cd2c6f..c9f5ea738b33b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -362,21 +362,21 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t ); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.CLASSIFICATION, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.CLASSIFICATION, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Input type [classification] is not supported for [AlibabaCloud AI Search];") @@ -406,21 +406,21 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi ); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.CLASSIFICATION, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.CLASSIFICATION, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Input type [classification] is not supported for [AlibabaCloud AI Search];") @@ -450,21 +450,21 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc ); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - "hi", - Boolean.TRUE, - 10, - List.of("a"), - false, - new HashMap<>(), - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + + service.infer( + model, + "hi", + Boolean.TRUE, + 10, + List.of("a"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 36a4ce7b18c46..04c8bca2287e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockMockRequestSender; @@ -78,6 +77,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettingsTests.getAmazonBedrockSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure; import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettingsTests.createChatCompletionRequestSettingsMap; @@ -87,6 +87,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -952,7 +953,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( } public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -991,7 +992,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); verifyNoMoreInteractions(factory); @@ -999,7 +1000,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc } public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1048,7 +1049,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException } public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1101,7 +1102,7 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException } public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1152,7 +1153,7 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { } public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1193,7 +1194,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel } private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1238,7 +1239,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si } public void testInfer_UnauthorizedResponse() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1323,7 +1324,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { } private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 2e8d0bbd2ea14..f77553d3abf4c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -65,10 +65,15 @@ public InputType getInputType() { } @Override - public void start() { + public void startSynchronously() { // do nothing } + @Override + public void startAsynchronously(ActionListener listener) { + throw new UnsupportedOperationException("not supported"); + } + @Override public void send( RequestManager requestCreator, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index fed601805a748..d608b6ec9fb8e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -81,9 +81,9 @@ public void testCreateSender_CanCallStartMultipleTimes() throws Exception { var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); try (var sender = createSender(senderFactory)) { - sender.start(); - sender.start(); - sender.start(); + sender.startSynchronously(); + sender.startSynchronously(); + sender.startSynchronously(); } } @@ -92,7 +92,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); var model = AmazonBedrockEmbeddingsModelTests.createModel( "test_id", @@ -123,7 +123,7 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text")); var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); var model = AmazonBedrockChatCompletionModelTests.createModel( "test_id", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 7de44857cf58e..b6570ff381381 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -65,10 +64,12 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -447,7 +448,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti } public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -476,7 +477,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicActionCreatorTests.java index ab08613a9acfe..d226dbd54820e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicActionCreatorTests.java @@ -72,7 +72,7 @@ public void testCreate_ChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -138,7 +138,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicChatCompletionActionTests.java index ba71d1482eb29..0f770ea9cedc8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/action/AnthropicChatCompletionActionTests.java @@ -84,7 +84,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -193,7 +193,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 127b7d1c4cfae..06668b91e1965 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -44,7 +44,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -84,6 +83,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.API_KEY_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER; @@ -92,6 +92,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -1181,7 +1182,7 @@ private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewT } public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1210,7 +1211,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -1219,7 +1220,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc } public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1228,28 +1229,28 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.CLASSIFICATION, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + + service.infer( + mockModel, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.CLASSIFICATION, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Input type [classification] is not supported for [Azure AI Studio];") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index daceedd2b8207..975661d180795 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -90,7 +90,7 @@ public void testEmbeddingsRequestAction() throws IOException { final var serviceComponents = getServiceComponents(); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson)); @@ -129,7 +129,7 @@ public void testChatCompletionRequestAction() throws IOException { final var serviceComponents = getServiceComponents(); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson)); final var webserverUrl = getUrl(webServer); @@ -166,7 +166,7 @@ public void testRerankRequestAction() throws IOException { final var serviceComponents = getServiceComponents(); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); final var webserverUrl = getUrl(webServer); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index cf3ac6979b8e3..b5d8fb887c62c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -41,7 +41,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -73,6 +72,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap; @@ -84,6 +84,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -748,7 +749,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( } public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -777,7 +778,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index c1d69580af8fd..1724747f9861a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -87,7 +87,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -138,7 +138,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -190,7 +190,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -246,7 +246,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // note - there is no complete documentation on Azure's error messages // but this error and response has been verified manually via CURL @@ -323,7 +323,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // note - there is no complete documentation on Azure's error messages // but this error and response has been verified manually via CURL @@ -400,7 +400,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -452,7 +452,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -510,7 +510,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -566,7 +566,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // "choices" missing String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiCompletionActionTests.java index 42d18bac183fa..c3ff37195128e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiCompletionActionTests.java @@ -81,7 +81,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index 0c6b7f62a96b3..20ac31398de2f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -88,7 +88,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 9642e7b85cdc7..d3f00e9c06c5e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -44,7 +44,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -79,6 +78,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; @@ -88,6 +88,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -770,7 +771,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( } public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -799,7 +800,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 8ab76bb728802..f3d816d9a118d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -76,7 +76,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -157,7 +157,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 8285613f48b80..5b8293725417e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -79,7 +79,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -232,7 +232,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index ca068d3e1859d..f5191ecac5ce6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -84,7 +84,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -168,7 +168,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index a44a692e9b912..760312fe2d97b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; @@ -90,6 +89,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; @@ -407,7 +407,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -436,7 +436,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -445,31 +445,25 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException } public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException { - var sender = mock(Sender.class); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - try (var service = createServiceWithMockSender()) { var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), "my-rerank-model-id"); PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - "search query", - Boolean.TRUE, - 10, - List.of("doc1", "doc2", "doc3"), - false, - new HashMap<>(), - InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + "search query", + Boolean.TRUE, + 10, + List.of("doc1", "doc2", "doc3"), + false, + new HashMap<>(), + InputType.SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this service;") @@ -478,7 +472,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc } public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -512,7 +506,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -1433,7 +1427,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 7ee6d817f899c..1d91713eef931 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -81,7 +81,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -144,7 +144,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // This will fail because the expected output is {"data": [{...}]} String responseJson = """ @@ -192,7 +192,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -264,7 +264,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -326,7 +326,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -383,7 +383,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // This will fail because the expected output is {"data": [[...]]} String responseJson = """ @@ -428,7 +428,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -463,7 +463,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJsonContentTooLarge = """ { @@ -534,7 +534,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 0a6a01ae18d13..e3d24ea2ec8f7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.junit.After; import org.junit.Before; @@ -40,6 +39,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -197,14 +197,14 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { } } - @SuppressWarnings("unchecked") public void testGetAuthorization_OnResponseCalledOnce() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); var logger = mock(Logger.class); var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); - ActionListener listener = mock(ActionListener.class); + PlainActionFuture listener = new PlainActionFuture<>(); + ActionListener onlyOnceListener = ActionListener.assertOnce(listener); String responseJson = """ { "models": [ @@ -218,10 +218,14 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var sender = senderFactory.createSender()) { - authHandler.getAuthorization(listener, sender); + authHandler.getAuthorization(onlyOnceListener, sender); authHandler.waitForAuthRequestCompletion(TIMEOUT); - verify(listener, times(1)).onResponse(any()); + var authResponse = listener.actionGet(TIMEOUT); + assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertTrue(authResponse.isAuthorized()); + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); @@ -231,7 +235,7 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { } public void testGetAuthorization_InvalidResponse() throws IOException { - var senderMock = mock(Sender.class); + var senderMock = createMockSender(); var senderFactory = mock(HttpRequestSender.Factory.class); when(senderFactory.createSender()).thenReturn(senderMock); @@ -257,6 +261,5 @@ public void testGetAuthorization_InvalidResponse() throws IOException { var message = loggerArgsCaptor.getValue(); assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service.")); } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 65ccc52cefc30..be087f73f8d5b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -43,7 +43,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -76,6 +75,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -86,6 +86,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -652,7 +653,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti } public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -681,7 +682,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -690,7 +691,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx } public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatDoesNotAcceptTaskType() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -699,28 +700,28 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for model [model];") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioCompletionActionTests.java index 1517e4b6b85af..f92505f7ebb10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioCompletionActionTests.java @@ -78,7 +78,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -203,7 +203,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index a36f06841ddcf..38e33c546b291 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -87,7 +87,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = senderFactory.createSender()) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 7acbe17340757..e5cf901e56116 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.junit.After; import org.junit.Before; @@ -30,8 +29,10 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -54,7 +55,7 @@ public void shutdown() throws IOException { } public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -83,7 +84,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index b28a6e6636c70..67534d83a2e15 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -85,6 +84,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -93,6 +93,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -253,7 +254,7 @@ public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_Withou } public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -282,7 +283,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -1064,21 +1065,20 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of("abc"), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 70632a439fdea..1194f3a5fa95c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -84,7 +84,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ [ @@ -144,7 +144,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ [ @@ -204,7 +204,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -260,7 +260,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]] String responseJson = """ @@ -322,7 +322,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOE var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -381,7 +381,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJsonContentTooLarge = """ { @@ -450,7 +450,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -500,7 +500,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -548,7 +548,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java index ed61b409e75ad..daa174782c83a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java @@ -86,7 +86,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -186,7 +186,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 7bd68f0ba0510..6e24981e3f3b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -79,6 +79,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -87,6 +88,7 @@ import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -591,7 +593,7 @@ public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunki } public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -620,7 +622,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -629,7 +631,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept } public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -639,28 +641,27 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); MatcherAssert.assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index c2ec051fd7a42..a0980a8151036 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -92,7 +92,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = senderFactory.createSender()) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 998559d102ab7..d408e269219cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; @@ -75,6 +74,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; @@ -82,6 +82,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -769,7 +770,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( } public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -798,7 +799,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java index 709a4a0630ba0..5bf65870e1dfc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -80,7 +80,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -113,7 +113,7 @@ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() thr var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ [ @@ -145,7 +145,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws I var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -201,7 +201,7 @@ public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() thr var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index c8f6d0ee0e2fe..db94ec5c9c2f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -47,7 +47,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; @@ -86,6 +85,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_KEY_FIELD; import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettingsTests.getServiceSettingsMap; @@ -96,6 +96,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -243,7 +244,7 @@ public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOExc } public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -272,7 +273,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -982,7 +983,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si } public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1011,7 +1012,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -1020,7 +1021,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I } public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1030,28 +1031,27 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java index 2db2a3298dee5..b43d2f4e99e99 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java @@ -73,7 +73,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -121,7 +121,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java index 4f8e99ecfe8fc..d93e1371f1b1a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java @@ -48,7 +48,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISynta var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c2cda175831ed..0a03c7a231e8c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -49,7 +49,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -92,6 +91,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; @@ -105,6 +105,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -352,7 +353,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { } public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -381,7 +382,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -390,7 +391,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException } public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -399,28 +400,28 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( thrownException.getMessage(), is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -429,7 +430,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { } public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -461,7 +462,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -470,7 +471,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { } public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -504,7 +505,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index df4e5cf4e3822..5cdcf402835b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -78,7 +78,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -135,7 +135,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -191,7 +191,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -254,7 +254,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -316,7 +316,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -380,7 +380,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -443,7 +443,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -513,7 +513,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -578,7 +578,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); var contentTooLargeErrorMessage = "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" @@ -665,7 +665,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); var contentTooLargeErrorMessage = "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" @@ -752,7 +752,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiChatCompletionActionTests.java index 0f16daa93a189..fb5dc530d69f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiChatCompletionActionTests.java @@ -89,7 +89,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -242,7 +242,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 6758f66f2917e..c34ebbde4ac8f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -86,7 +86,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index d7f5726af85e0..9bbfcb4d58667 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -41,7 +41,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; @@ -73,6 +72,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; @@ -80,6 +80,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -715,7 +716,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( } public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -744,7 +745,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); @@ -753,7 +754,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio } public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOException { - var sender = mock(Sender.class); + var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -770,28 +771,27 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.CLUSTERING, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) + service.infer( + model, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.CLUSTERING, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + + var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); MatcherAssert.assertThat( thrownException.getMessage(), is("Validation Failed: 1: Input type [clustering] is not supported for [Voyage AI];") ); verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + verify(sender, times(1)).startAsynchronously(any()); } verify(sender, times(1)).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index e2802c3569d86..afbb5532d5fdd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -74,7 +74,7 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index ba97ba3b70d00..7050c2d131e17 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -88,7 +88,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -185,7 +185,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ { @@ -282,7 +282,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); + sender.startSynchronously(); String responseJson = """ {