Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125023.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125023
summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInputs`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
Expand All @@ -51,18 +49,8 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
listener.onFailure(
new ElasticsearchStatusException(
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
RestStatus.INTERNAL_SERVER_ERROR
)
);
return;
}

var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
if (docsOnlyInput.getInputs().size() % 2 == 0) {
var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
if (completionInput.getInputs().size() % 2 == 0) {
listener.onFailure(
new ElasticsearchStatusException(
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;

public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {

private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;

@Before
public void init() throws IOException {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}

@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}

public void testExecute_Success() {
var sender = mock(Sender.class);

var resultString = randomAlphaOfLength(100);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));

return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
}

public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("error"));
}

public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
var sender = mock(Sender.class);
doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
}

public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
var action = createAction(threadPool, mock(Sender.class));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
assertThrows(IllegalArgumentException.class, () -> {
action.execute(new DocumentsOnlyInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
});
}

public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
var action = createAction(threadPool, mock(Sender.class));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(
new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is(
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
+ "all preceding inputs are the completion history as pairs of user input and the assistant's response."
)
);
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
}

private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
var model = AlibabaCloudSearchCompletionModelTests.createModel(
"completion_test",
TaskType.COMPLETION,
AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
getSecretSettingsMap("secret")
);

var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
}
}