| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the Elastic License  | 
 | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License  | 
 | 5 | + * 2.0.  | 
 | 6 | + */  | 
 | 7 | + | 
 | 8 | +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;  | 
 | 9 | + | 
 | 10 | +import org.elasticsearch.ElasticsearchException;  | 
 | 11 | +import org.elasticsearch.ElasticsearchStatusException;  | 
 | 12 | +import org.elasticsearch.action.ActionListener;  | 
 | 13 | +import org.elasticsearch.action.support.PlainActionFuture;  | 
 | 14 | +import org.elasticsearch.common.settings.Settings;  | 
 | 15 | +import org.elasticsearch.core.TimeValue;  | 
 | 16 | +import org.elasticsearch.inference.InferenceServiceResults;  | 
 | 17 | +import org.elasticsearch.inference.InputType;  | 
 | 18 | +import org.elasticsearch.inference.TaskType;  | 
 | 19 | +import org.elasticsearch.rest.RestStatus;  | 
 | 20 | +import org.elasticsearch.test.ESTestCase;  | 
 | 21 | +import org.elasticsearch.test.http.MockWebServer;  | 
 | 22 | +import org.elasticsearch.threadpool.ThreadPool;  | 
 | 23 | +import org.elasticsearch.xpack.core.inference.action.InferenceAction;  | 
 | 24 | +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;  | 
 | 25 | +import org.elasticsearch.xpack.inference.external.action.ExecutableAction;  | 
 | 26 | +import org.elasticsearch.xpack.inference.external.http.HttpClientManager;  | 
 | 27 | +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;  | 
 | 28 | +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;  | 
 | 29 | +import org.elasticsearch.xpack.inference.external.http.sender.Sender;  | 
 | 30 | +import org.elasticsearch.xpack.inference.logging.ThrottlerManager;  | 
 | 31 | +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;  | 
 | 32 | +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;  | 
 | 33 | +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;  | 
 | 34 | +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;  | 
 | 35 | +import org.junit.After;  | 
 | 36 | +import org.junit.Before;  | 
 | 37 | + | 
 | 38 | +import java.io.IOException;  | 
 | 39 | +import java.util.List;  | 
 | 40 | +import java.util.concurrent.TimeUnit;  | 
 | 41 | + | 
 | 42 | +import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;  | 
 | 43 | +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;  | 
 | 44 | +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;  | 
 | 45 | +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;  | 
 | 46 | +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;  | 
 | 47 | +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;  | 
 | 48 | +import static org.hamcrest.MatcherAssert.assertThat;  | 
 | 49 | +import static org.hamcrest.Matchers.is;  | 
 | 50 | +import static org.mockito.ArgumentMatchers.any;  | 
 | 51 | +import static org.mockito.Mockito.doAnswer;  | 
 | 52 | +import static org.mockito.Mockito.doThrow;  | 
 | 53 | +import static org.mockito.Mockito.mock;  | 
 | 54 | + | 
 | 55 | +public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {  | 
 | 56 | + | 
 | 57 | +    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);  | 
 | 58 | +    private final MockWebServer webServer = new MockWebServer();  | 
 | 59 | +    private ThreadPool threadPool;  | 
 | 60 | +    private HttpClientManager clientManager;  | 
 | 61 | + | 
 | 62 | +    @Before  | 
 | 63 | +    public void init() throws IOException {  | 
 | 64 | +        webServer.start();  | 
 | 65 | +        threadPool = createThreadPool(inferenceUtilityPool());  | 
 | 66 | +        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));  | 
 | 67 | +    }  | 
 | 68 | + | 
 | 69 | +    @After  | 
 | 70 | +    public void shutdown() throws IOException {  | 
 | 71 | +        clientManager.close();  | 
 | 72 | +        terminate(threadPool);  | 
 | 73 | +        webServer.close();  | 
 | 74 | +    }  | 
 | 75 | + | 
 | 76 | +    public void testExecute_Success() {  | 
 | 77 | +        var sender = mock(Sender.class);  | 
 | 78 | + | 
 | 79 | +        var resultString = randomAlphaOfLength(100);  | 
 | 80 | +        doAnswer(invocation -> {  | 
 | 81 | +            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);  | 
 | 82 | +            listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));  | 
 | 83 | + | 
 | 84 | +            return Void.TYPE;  | 
 | 85 | +        }).when(sender).send(any(), any(), any(), any());  | 
 | 86 | +        var action = createAction(threadPool, sender);  | 
 | 87 | + | 
 | 88 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 89 | +        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);  | 
 | 90 | + | 
 | 91 | +        var result = listener.actionGet(TIMEOUT);  | 
 | 92 | +        assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));  | 
 | 93 | +    }  | 
 | 94 | + | 
 | 95 | +    public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {  | 
 | 96 | +        var sender = mock(Sender.class);  | 
 | 97 | +        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());  | 
 | 98 | +        var action = createAction(threadPool, sender);  | 
 | 99 | + | 
 | 100 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 101 | +        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);  | 
 | 102 | + | 
 | 103 | +        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));  | 
 | 104 | +        assertThat(thrownException.getMessage(), is("error"));  | 
 | 105 | +    }  | 
 | 106 | + | 
 | 107 | +    public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {  | 
 | 108 | +        var sender = mock(Sender.class);  | 
 | 109 | +        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());  | 
 | 110 | +        var action = createAction(threadPool, sender);  | 
 | 111 | + | 
 | 112 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 113 | +        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);  | 
 | 114 | + | 
 | 115 | +        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));  | 
 | 116 | +        assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));  | 
 | 117 | +    }  | 
 | 118 | + | 
 | 119 | +    public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {  | 
 | 120 | +        var action = createAction(threadPool, mock(Sender.class));  | 
 | 121 | + | 
 | 122 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 123 | +        assertThrows(IllegalArgumentException.class, () -> {  | 
 | 124 | +            action.execute(  | 
 | 125 | +                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST),  | 
 | 126 | +                InferenceAction.Request.DEFAULT_TIMEOUT,  | 
 | 127 | +                listener  | 
 | 128 | +            );  | 
 | 129 | +        });  | 
 | 130 | +    }  | 
 | 131 | + | 
 | 132 | +    public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {  | 
 | 133 | +        var action = createAction(threadPool, mock(Sender.class));  | 
 | 134 | + | 
 | 135 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 136 | +        action.execute(  | 
 | 137 | +            new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),  | 
 | 138 | +            InferenceAction.Request.DEFAULT_TIMEOUT,  | 
 | 139 | +            listener  | 
 | 140 | +        );  | 
 | 141 | + | 
 | 142 | +        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));  | 
 | 143 | +        assertThat(  | 
 | 144 | +            thrownException.getMessage(),  | 
 | 145 | +            is(  | 
 | 146 | +                "Alibaba Completion's inputs must be an odd number. The last input is the current query, "  | 
 | 147 | +                    + "all preceding inputs are the completion history as pairs of user input and the assistant's response."  | 
 | 148 | +            )  | 
 | 149 | +        );  | 
 | 150 | +        assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));  | 
 | 151 | +    }  | 
 | 152 | + | 
 | 153 | +    private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {  | 
 | 154 | +        var model = AlibabaCloudSearchCompletionModelTests.createModel(  | 
 | 155 | +            "completion_test",  | 
 | 156 | +            TaskType.COMPLETION,  | 
 | 157 | +            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),  | 
 | 158 | +            AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),  | 
 | 159 | +            getSecretSettingsMap("secret")  | 
 | 160 | +        );  | 
 | 161 | + | 
 | 162 | +        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);  | 
 | 163 | +        return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);  | 
 | 164 | +    }  | 
 | 165 | +}  | 
0 commit comments