Skip to content

Commit 3b743c8

Browse files
[8.x] Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs (#125023) (#125332)
* Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs (#125023) * Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs * Update docs/changelog/125023.yaml * Fix unit tests * Fixing unit test import --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 9293f35 commit 3b743c8

File tree

3 files changed

+173
-15
lines changed

3 files changed

+173
-15
lines changed

docs/changelog/125023.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125023
2+
summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInputs`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,18 @@
1414
import org.elasticsearch.action.ActionListener;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.InferenceServiceResults;
17-
import org.elasticsearch.inference.TaskType;
1817
import org.elasticsearch.rest.RestStatus;
1918
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2019
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
2120
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
22-
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
21+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2322
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2423
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2524
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2625
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
2726

2827
import java.util.Objects;
2928

30-
import static org.elasticsearch.core.Strings.format;
3129
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
3230
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
3331
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
@@ -51,18 +49,8 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl
5149

5250
@Override
5351
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
54-
if (inferenceInputs instanceof EmbeddingsInput == false) {
55-
listener.onFailure(
56-
new ElasticsearchStatusException(
57-
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
58-
RestStatus.INTERNAL_SERVER_ERROR
59-
)
60-
);
61-
return;
62-
}
63-
64-
var docsOnlyInput = (EmbeddingsInput) inferenceInputs;
65-
if (docsOnlyInput.getInputs().size() % 2 == 0) {
52+
var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
53+
if (completionInput.getInputs().size() % 2 == 0) {
6654
listener.onFailure(
6755
new ElasticsearchStatusException(
6856
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.action.support.PlainActionFuture;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.inference.InferenceServiceResults;
17+
import org.elasticsearch.inference.InputType;
18+
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.rest.RestStatus;
20+
import org.elasticsearch.test.ESTestCase;
21+
import org.elasticsearch.test.http.MockWebServer;
22+
import org.elasticsearch.threadpool.ThreadPool;
23+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
24+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
25+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
26+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
27+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
28+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
29+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
30+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
31+
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
32+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
33+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
34+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
35+
import org.junit.After;
36+
import org.junit.Before;
37+
38+
import java.io.IOException;
39+
import java.util.List;
40+
import java.util.concurrent.TimeUnit;
41+
42+
import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
43+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
44+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
45+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
46+
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
47+
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
48+
import static org.hamcrest.MatcherAssert.assertThat;
49+
import static org.hamcrest.Matchers.is;
50+
import static org.mockito.ArgumentMatchers.any;
51+
import static org.mockito.Mockito.doAnswer;
52+
import static org.mockito.Mockito.doThrow;
53+
import static org.mockito.Mockito.mock;
54+
55+
public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {
56+
57+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
58+
private final MockWebServer webServer = new MockWebServer();
59+
private ThreadPool threadPool;
60+
private HttpClientManager clientManager;
61+
62+
@Before
63+
public void init() throws IOException {
64+
webServer.start();
65+
threadPool = createThreadPool(inferenceUtilityPool());
66+
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
67+
}
68+
69+
@After
70+
public void shutdown() throws IOException {
71+
clientManager.close();
72+
terminate(threadPool);
73+
webServer.close();
74+
}
75+
76+
public void testExecute_Success() {
77+
var sender = mock(Sender.class);
78+
79+
var resultString = randomAlphaOfLength(100);
80+
doAnswer(invocation -> {
81+
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
82+
listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));
83+
84+
return Void.TYPE;
85+
}).when(sender).send(any(), any(), any(), any());
86+
var action = createAction(threadPool, sender);
87+
88+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
89+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
90+
91+
var result = listener.actionGet(TIMEOUT);
92+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
93+
}
94+
95+
public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
96+
var sender = mock(Sender.class);
97+
doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
98+
var action = createAction(threadPool, sender);
99+
100+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
101+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
102+
103+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
104+
assertThat(thrownException.getMessage(), is("error"));
105+
}
106+
107+
public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
108+
var sender = mock(Sender.class);
109+
doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
110+
var action = createAction(threadPool, sender);
111+
112+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
113+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
114+
115+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
116+
assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
117+
}
118+
119+
public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
120+
var action = createAction(threadPool, mock(Sender.class));
121+
122+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
123+
assertThrows(IllegalArgumentException.class, () -> {
124+
action.execute(
125+
new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST),
126+
InferenceAction.Request.DEFAULT_TIMEOUT,
127+
listener
128+
);
129+
});
130+
}
131+
132+
public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
133+
var action = createAction(threadPool, mock(Sender.class));
134+
135+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
136+
action.execute(
137+
new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
138+
InferenceAction.Request.DEFAULT_TIMEOUT,
139+
listener
140+
);
141+
142+
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
143+
assertThat(
144+
thrownException.getMessage(),
145+
is(
146+
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
147+
+ "all preceding inputs are the completion history as pairs of user input and the assistant's response."
148+
)
149+
);
150+
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
151+
}
152+
153+
private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
154+
var model = AlibabaCloudSearchCompletionModelTests.createModel(
155+
"completion_test",
156+
TaskType.COMPLETION,
157+
AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
158+
AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
159+
getSecretSettingsMap("secret")
160+
);
161+
162+
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
163+
return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
164+
}
165+
}

0 commit comments

Comments
 (0)