diff --git a/docs/changelog/137219.yaml b/docs/changelog/137219.yaml new file mode 100644 index 0000000000000..ec6587af5cf81 --- /dev/null +++ b/docs/changelog/137219.yaml @@ -0,0 +1,5 @@ +pr: 137219 +summary: Perform query field validation for rerank task type +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 0147c62823f0d..f191102bff869 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -90,6 +90,11 @@ private static InferenceInputs createInput( case RERANK -> { ValidationException validationException = new ValidationException(); service.validateRerankParameters(returnDocuments, topN, validationException); + + if (query == null) { + validationException.addValidationError("Rerank task type requires a non-null query field"); + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 659b935aff8d8..0c50ae9360b91 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -21,8 +21,10 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.After; @@ -34,9 +36,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -101,7 +106,113 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep verifyNoMoreInteractions(sender); } - private static final class TestSenderService extends SenderService { + public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.RERANK); + + var exception = expectThrows( + ValidationException.class, + () -> testService.infer( + model, + null, + null, + null, + List.of("test input"), + false, + Map.of(), + InputType.SEARCH, + null, + new PlainActionFuture<>() + ) + ); + + assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field")); + } + } + + public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var queryString = "a query"; + var testInput = "test input"; + var doInferCalled = new AtomicReference<>(false); + + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) { + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var queryDocs = inputs.castTo(QueryAndDocsInputs.class); + assertThat(queryDocs.getQuery(), is(queryString)); + assertThat(queryDocs.getChunks(), is(List.of(testInput))); + doInferCalled.set(true); + listener.onResponse(mock(InferenceServiceResults.class)); + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.RERANK); + + PlainActionFuture listener = new PlainActionFuture<>(); + + testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener); + assertNotNull(listener.actionGet(TIMEOUT)); + assertTrue(doInferCalled.get()); + } + } + + public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var testInput = "test input"; + var doInferCalled = new AtomicReference<>(false); + + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) { + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var castedInput = inputs.castTo(ChatCompletionInput.class); + assertThat(castedInput.getInputs(), is(List.of(testInput))); + doInferCalled.set(true); + listener.onResponse(mock(InferenceServiceResults.class)); + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.COMPLETION); + + PlainActionFuture listener = new PlainActionFuture<>(); + + testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener); + assertNotNull(listener.actionGet(TIMEOUT)); + assertTrue(doInferCalled.get()); + } + } + + private static class TestSenderService extends SenderService { TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f91c731268377..f4c0e97bd56ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -14,8 +14,11 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; @@ -53,6 +56,10 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; @@ -68,6 +75,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; @@ -115,17 +123,17 @@ public void shutdown() throws IOException { public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOException { try (var service = createCohereService()) { ActionListener modelListener = ActionListener.wrap(model -> { - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); - MatcherAssert.assertThat( + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + assertThat( embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START)) ); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); service.parseRequestConfig( @@ -145,19 +153,19 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() throws IOExce public void testParseRequestConfig_CreatesACohereEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createCohereService()) { ActionListener modelListener = ActionListener.wrap(model -> { - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); - MatcherAssert.assertThat( + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + assertThat( embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START)) ); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); service.parseRequestConfig( @@ -178,19 +186,19 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWhenChunkingSett public void testParseRequestConfig_CreatesACohereEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createCohereService()) { ActionListener modelListener = ActionListener.wrap(model -> { - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); - MatcherAssert.assertThat( + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + assertThat( embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START)) ); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); service.parseRequestConfig( @@ -211,14 +219,14 @@ public void testParseRequestConfig_OptionalTaskSettings() throws IOException { try (var service = createCohereService()) { ActionListener modelListener = ActionListener.wrap(model -> { - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + assertThat(embeddingsModel.getTaskSettings(), equalTo(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); service.parseRequestConfig( @@ -256,8 +264,8 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { - MatcherAssert.assertThat(e, instanceOf(exceptionClass)); - MatcherAssert.assertThat(e.getMessage(), is(expectedMessage)); + assertThat(e, instanceOf(exceptionClass)); + assertThat(e.getMessage(), is(expectedMessage)); }); } @@ -332,12 +340,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { try (var service = createCohereService()) { var modelListener = ActionListener.wrap((model) -> { - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); @@ -367,13 +375,13 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -393,14 +401,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -419,14 +427,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -470,12 +478,12 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -495,17 +503,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.BYTE)); - MatcherAssert.assertThat( - embeddingsModel.getTaskSettings(), - is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE)) - ); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.BYTE)); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -527,12 +532,12 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -552,13 +557,13 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -576,12 +581,12 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -603,13 +608,13 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa persistedConfig.secrets() ); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -622,12 +627,12 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -642,13 +647,13 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWhenChunkingSe var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } } @@ -662,13 +667,13 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWhenChunkingSe var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); - MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } } @@ -701,13 +706,13 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWithoutUrl() t var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT)); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -722,11 +727,11 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS)); assertNull(embeddingsModel.getSecretSettings()); } } @@ -740,11 +745,11 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -761,12 +766,12 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class)); + assertThat(model, instanceOf(CohereEmbeddingsModel.class)); var embeddingsModel = (CohereEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -795,7 +800,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat( + assertThat( thrownException.getMessage(), is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); @@ -866,23 +871,115 @@ public void testInfer_SendsRequest() throws IOException { var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) ); } } + public void testInfer_ReturnsValidationException_WhenSendingRerankRequest_WithoutQueryField() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + var secret = "secret"; + var modelName = "model"; + + var model = new CohereRerankModel( + "id", + new CohereRerankServiceSettings("abc", modelName, null, CohereServiceSettings.CohereApiVersion.V2), + new CohereRerankTaskSettings(null, null, null), + new DefaultSecretSettings(new SecureString(secret.toCharArray())) + ); + + var exception = expectThrows( + ValidationException.class, + () -> service.infer( + model, + // null query string will trigger validation error + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + new PlainActionFuture<>() + ) + ); + + assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field")); + } + } + + public void testInfer_SendsRerankRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "index": "d0760819-5a73-4d58-b163-3956d3648b62", + "results": [ + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var secret = "secret"; + var modelName = "model"; + + var model = new CohereRerankModel( + "id", + new CohereRerankServiceSettings(getUrl(webServer), modelName, null, CohereServiceSettings.CohereApiVersion.V2), + new CohereRerankTaskSettings(null, null, null), + new DefaultSecretSettings(new SecureString(secret.toCharArray())) + ); + + var queryString = "a query"; + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + queryString, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationRerank(List.of()))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is(Strings.format("Bearer %s", secret))); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, is(Map.of("query", queryString, "documents", List.of("abc"), "model", modelName))); + } + } + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -966,9 +1063,9 @@ public void testInfer_UnauthorisedResponse() throws IOException { ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); - MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [invalid api token]")); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [invalid api token]")); + assertThat(webServer.requests(), hasSize(1)); } } @@ -1031,16 +1128,13 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) ); @@ -1105,17 +1199,14 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) ); @@ -1180,20 +1271,14 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec var result = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"))) - ); + assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float")))); } } @@ -1255,9 +1340,9 @@ public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecif listener.actionGet(TIMEOUT); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests(), hasSize(1)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) ); @@ -1368,16 +1453,13 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { ); } - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) ); @@ -1467,16 +1549,13 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { ); } - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( + assertThat( requestMap, is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"), "input_type", "search_query")) );