Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125023.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125023
summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInput`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
Expand All @@ -51,18 +49,8 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs instanceof EmbeddingsInput == false) {
listener.onFailure(
new ElasticsearchStatusException(
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
RestStatus.INTERNAL_SERVER_ERROR
)
);
return;
}

var docsOnlyInput = (EmbeddingsInput) inferenceInputs;
if (docsOnlyInput.getInputs().size() % 2 == 0) {
var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
if (completionInput.getInputs().size() % 2 == 0) {
listener.onFailure(
new ElasticsearchStatusException(
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;

/**
 * Unit tests for {@link AlibabaCloudSearchCompletionAction}.
 *
 * <p>Every test drives {@code execute} against a mocked {@link Sender} and inspects what reaches the
 * {@link PlainActionFuture} listener (or what is thrown synchronously).
 */
public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {

    /** Upper bound used when waiting on a listener future. */
    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
    // NOTE(review): webServer and clientManager are started/created in init() and released in shutdown(),
    // but no test in this class references them; they look like fixture boilerplate - confirm before removing.
    private final MockWebServer webServer = new MockWebServer();
    private ThreadPool threadPool;
    private HttpClientManager clientManager;

    /**
     * Starts the mock web server and creates the thread pool and HTTP client manager before each test.
     *
     * @throws IOException if the mock web server cannot be started
     */
    @Before
    public void init() throws IOException {
        webServer.start();
        threadPool = createThreadPool(inferenceUtilityPool());
        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
    }

    /**
     * Releases the per-test fixtures in the same order as before (client manager, thread pool, web server).
     *
     * <p>The steps are chained with {@code try/finally} so that a failure in an earlier step cannot prevent
     * the later resources from being released.
     *
     * @throws IOException if closing the client manager or the web server fails
     */
    @After
    public void shutdown() throws IOException {
        try {
            clientManager.close();
        } finally {
            try {
                terminate(threadPool);
            } finally {
                webServer.close();
            }
        }
    }

    /** A single-element {@link ChatCompletionInput} yields the completion result the sender reports. */
    public void testExecute_Success() {
        var sender = mock(Sender.class);

        var resultString = randomAlphaOfLength(100);
        doAnswer(invocation -> {
            // The listener is the fourth argument of Sender#send as stubbed below.
            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
            listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));

            return Void.TYPE;
        }).when(sender).send(any(), any(), any(), any());
        var action = createAction(threadPool, sender);

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

        var result = listener.actionGet(TIMEOUT);
        assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
    }

    /** An {@link ElasticsearchException} thrown by the sender reaches the listener with its message unchanged. */
    public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
        var sender = mock(Sender.class);
        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
        var action = createAction(threadPool, sender);

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(thrownException.getMessage(), is("error"));
    }

    /**
     * A non-Elasticsearch exception thrown by the sender reaches the listener as an
     * {@link ElasticsearchException} carrying the generic "failed to send request" message.
     */
    public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
        var sender = mock(Sender.class);
        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
        var action = createAction(threadPool, sender);

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
    }

    /**
     * Passing an {@link EmbeddingsInput} instead of a {@link ChatCompletionInput} makes {@code execute}
     * throw {@link IllegalArgumentException} directly to the caller.
     */
    public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
        var action = createAction(threadPool, mock(Sender.class));

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        // expectThrows is used for consistency with the other tests in this class.
        expectThrows(
            IllegalArgumentException.class,
            () -> action.execute(
                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST),
                InferenceAction.Request.DEFAULT_TIMEOUT,
                listener
            )
        );
    }

    /** An even number of inputs is rejected through the listener with a {@code BAD_REQUEST} status. */
    public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
        var action = createAction(threadPool, mock(Sender.class));

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        action.execute(
            new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
            InferenceAction.Request.DEFAULT_TIMEOUT,
            listener
        );

        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(
            thrownException.getMessage(),
            is(
                "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
                    + "all preceding inputs are the completion history as pairs of user input and the assistant's response."
            )
        );
        assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
    }

    /**
     * Builds the action under test around a fixed completion model.
     *
     * @param threadPool thread pool handed to the service components
     * @param sender     sender the action will delegate to (a mock in every test here)
     * @return a new {@link AlibabaCloudSearchCompletionAction}
     */
    private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
        var model = AlibabaCloudSearchCompletionModelTests.createModel(
            "completion_test",
            TaskType.COMPLETION,
            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
            AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
            getSecretSettingsMap("secret")
        );

        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
        return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
    }
}
Loading