| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the Elastic License  | 
 | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License  | 
 | 5 | + * 2.0.  | 
 | 6 | + */  | 
 | 7 | + | 
 | 8 | +package org.elasticsearch.xpack.inference.services.llama.action;  | 
 | 9 | + | 
 | 10 | +import org.apache.http.HttpHeaders;  | 
 | 11 | +import org.elasticsearch.ElasticsearchException;  | 
 | 12 | +import org.elasticsearch.action.support.PlainActionFuture;  | 
 | 13 | +import org.elasticsearch.common.settings.Settings;  | 
 | 14 | +import org.elasticsearch.core.TimeValue;  | 
 | 15 | +import org.elasticsearch.inference.InferenceServiceResults;  | 
 | 16 | +import org.elasticsearch.test.ESTestCase;  | 
 | 17 | +import org.elasticsearch.test.http.MockResponse;  | 
 | 18 | +import org.elasticsearch.test.http.MockWebServer;  | 
 | 19 | +import org.elasticsearch.threadpool.ThreadPool;  | 
 | 20 | +import org.elasticsearch.xcontent.XContentType;  | 
 | 21 | +import org.elasticsearch.xpack.core.inference.action.InferenceAction;  | 
 | 22 | +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;  | 
 | 23 | +import org.elasticsearch.xpack.inference.InputTypeTests;  | 
 | 24 | +import org.elasticsearch.xpack.inference.common.TruncatorTests;  | 
 | 25 | +import org.elasticsearch.xpack.inference.external.http.HttpClientManager;  | 
 | 26 | +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;  | 
 | 27 | +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;  | 
 | 28 | +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;  | 
 | 29 | +import org.elasticsearch.xpack.inference.external.http.sender.Sender;  | 
 | 30 | +import org.elasticsearch.xpack.inference.logging.ThrottlerManager;  | 
 | 31 | +import org.elasticsearch.xpack.inference.services.ServiceComponents;  | 
 | 32 | +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;  | 
 | 33 | +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingModelTests;  | 
 | 34 | +import org.junit.After;  | 
 | 35 | +import org.junit.Before;  | 
 | 36 | + | 
 | 37 | +import java.io.IOException;  | 
 | 38 | +import java.util.List;  | 
 | 39 | +import java.util.Map;  | 
 | 40 | +import java.util.concurrent.TimeUnit;  | 
 | 41 | + | 
 | 42 | +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;  | 
 | 43 | +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;  | 
 | 44 | +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;  | 
 | 45 | +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;  | 
 | 46 | +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;  | 
 | 47 | +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;  | 
 | 48 | +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;  | 
 | 49 | +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;  | 
 | 50 | +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;  | 
 | 51 | +import static org.hamcrest.Matchers.contains;  | 
 | 52 | +import static org.hamcrest.Matchers.equalTo;  | 
 | 53 | +import static org.hamcrest.Matchers.hasSize;  | 
 | 54 | +import static org.hamcrest.Matchers.instanceOf;  | 
 | 55 | +import static org.hamcrest.Matchers.is;  | 
 | 56 | +import static org.mockito.Mockito.mock;  | 
 | 57 | + | 
 | 58 | +public class LlamaActionCreatorTests extends ESTestCase {  | 
 | 59 | + | 
 | 60 | +    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);  | 
 | 61 | +    private final MockWebServer webServer = new MockWebServer();  | 
 | 62 | +    private ThreadPool threadPool;  | 
 | 63 | +    private HttpClientManager clientManager;  | 
 | 64 | + | 
 | 65 | +    @Before  | 
 | 66 | +    public void init() throws Exception {  | 
 | 67 | +        webServer.start();  | 
 | 68 | +        threadPool = createThreadPool(inferenceUtilityPool());  | 
 | 69 | +        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));  | 
 | 70 | +    }  | 
 | 71 | + | 
 | 72 | +    @After  | 
 | 73 | +    public void shutdown() throws IOException {  | 
 | 74 | +        clientManager.close();  | 
 | 75 | +        terminate(threadPool);  | 
 | 76 | +        webServer.close();  | 
 | 77 | +    }  | 
 | 78 | + | 
 | 79 | +    public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {  | 
 | 80 | +        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);  | 
 | 81 | + | 
 | 82 | +        try (var sender = createSender(senderFactory)) {  | 
 | 83 | +            sender.start();  | 
 | 84 | + | 
 | 85 | +            String responseJson = """  | 
 | 86 | +                {  | 
 | 87 | +                    "embeddings": [  | 
 | 88 | +                        [  | 
 | 89 | +                            -0.0123,  | 
 | 90 | +                            0.123  | 
 | 91 | +                        ]  | 
 | 92 | +                    ]  | 
 | 93 | +                {  | 
 | 94 | +                """;  | 
 | 95 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 96 | + | 
 | 97 | +            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));  | 
 | 98 | + | 
 | 99 | +            var result = listener.actionGet(TIMEOUT);  | 
 | 100 | + | 
 | 101 | +            assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));  | 
 | 102 | + | 
 | 103 | +            assertEmbeddingsRequest();  | 
 | 104 | +        }  | 
 | 105 | +    }  | 
 | 106 | + | 
 | 107 | +    public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException {  | 
 | 108 | +        var settings = buildSettingsWithRetryFields(  | 
 | 109 | +            TimeValue.timeValueMillis(1),  | 
 | 110 | +            TimeValue.timeValueMinutes(1),  | 
 | 111 | +            TimeValue.timeValueSeconds(0)  | 
 | 112 | +        );  | 
 | 113 | +        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);  | 
 | 114 | + | 
 | 115 | +        try (var sender = createSender(senderFactory)) {  | 
 | 116 | +            sender.start();  | 
 | 117 | + | 
 | 118 | +            String responseJson = """  | 
 | 119 | +                [  | 
 | 120 | +                    {  | 
 | 121 | +                        "embeddings": [  | 
 | 122 | +                            [  | 
 | 123 | +                                -0.0123,  | 
 | 124 | +                                0.123  | 
 | 125 | +                            ]  | 
 | 126 | +                        ]  | 
 | 127 | +                    {  | 
 | 128 | +                ]  | 
 | 129 | +                """;  | 
 | 130 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 131 | + | 
 | 132 | +            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));  | 
 | 133 | + | 
 | 134 | +            var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));  | 
 | 135 | +            assertThat(  | 
 | 136 | +                thrownException.getMessage(),  | 
 | 137 | +                is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")  | 
 | 138 | +            );  | 
 | 139 | + | 
 | 140 | +            assertEmbeddingsRequest();  | 
 | 141 | +        }  | 
 | 142 | +    }  | 
 | 143 | + | 
 | 144 | +    public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException {  | 
 | 145 | +        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);  | 
 | 146 | + | 
 | 147 | +        try (var sender = createSender(senderFactory)) {  | 
 | 148 | +            sender.start();  | 
 | 149 | + | 
 | 150 | +            String responseJson = """  | 
 | 151 | +                {  | 
 | 152 | +                    "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9",  | 
 | 153 | +                    "choices": [  | 
 | 154 | +                        {  | 
 | 155 | +                            "finish_reason": "length",  | 
 | 156 | +                            "index": 0,  | 
 | 157 | +                            "logprobs": null,  | 
 | 158 | +                            "message": {  | 
 | 159 | +                                "content": "Hello there, how may I assist you today?",  | 
 | 160 | +                                "refusal": null,  | 
 | 161 | +                                "role": "assistant",  | 
 | 162 | +                                "annotations": null,  | 
 | 163 | +                                "audio": null,  | 
 | 164 | +                                "function_call": null,  | 
 | 165 | +                                "tool_calls": null  | 
 | 166 | +                            }  | 
 | 167 | +                        }  | 
 | 168 | +                    ],  | 
 | 169 | +                    "created": 1750157476,  | 
 | 170 | +                    "model": "llama3.2:3b",  | 
 | 171 | +                    "object": "chat.completion",  | 
 | 172 | +                    "service_tier": null,  | 
 | 173 | +                    "system_fingerprint": "fp_ollama",  | 
 | 174 | +                    "usage": {  | 
 | 175 | +                        "completion_tokens": 10,  | 
 | 176 | +                        "prompt_tokens": 30,  | 
 | 177 | +                        "total_tokens": 40,  | 
 | 178 | +                        "completion_tokens_details": null,  | 
 | 179 | +                        "prompt_tokens_details": null  | 
 | 180 | +                    }  | 
 | 181 | +                }  | 
 | 182 | +                """;  | 
 | 183 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 184 | + | 
 | 185 | +            PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(sender, createWithEmptySettings(threadPool));  | 
 | 186 | + | 
 | 187 | +            var result = listener.actionGet(TIMEOUT);  | 
 | 188 | + | 
 | 189 | +            assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));  | 
 | 190 | + | 
 | 191 | +            assertCompletionRequest();  | 
 | 192 | +        }  | 
 | 193 | +    }  | 
 | 194 | + | 
 | 195 | +    public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException {  | 
 | 196 | +        var settings = buildSettingsWithRetryFields(  | 
 | 197 | +            TimeValue.timeValueMillis(1),  | 
 | 198 | +            TimeValue.timeValueMinutes(1),  | 
 | 199 | +            TimeValue.timeValueSeconds(0)  | 
 | 200 | +        );  | 
 | 201 | +        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);  | 
 | 202 | + | 
 | 203 | +        try (var sender = createSender(senderFactory)) {  | 
 | 204 | +            sender.start();  | 
 | 205 | + | 
 | 206 | +            String responseJson = """  | 
 | 207 | +                {  | 
 | 208 | +                    "invalid_field": "unexpected"  | 
 | 209 | +                }  | 
 | 210 | +                """;  | 
 | 211 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 212 | + | 
 | 213 | +            PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(  | 
 | 214 | +                sender,  | 
 | 215 | +                new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())  | 
 | 216 | +            );  | 
 | 217 | + | 
 | 218 | +            var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));  | 
 | 219 | +            assertThat(  | 
 | 220 | +                thrownException.getMessage(),  | 
 | 221 | +                is("Failed to send Llama completion request from inference entity id [id]. Cause: Required [choices]")  | 
 | 222 | +            );  | 
 | 223 | + | 
 | 224 | +            assertCompletionRequest();  | 
 | 225 | +        }  | 
 | 226 | +    }  | 
 | 227 | + | 
 | 228 | +    private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) {  | 
 | 229 | +        var model = LlamaEmbeddingModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret");  | 
 | 230 | +        var actionCreator = new LlamaActionCreator(sender, threadPool);  | 
 | 231 | +        var action = actionCreator.create(model);  | 
 | 232 | + | 
 | 233 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 234 | +        action.execute(  | 
 | 235 | +            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),  | 
 | 236 | +            InferenceAction.Request.DEFAULT_TIMEOUT,  | 
 | 237 | +            listener  | 
 | 238 | +        );  | 
 | 239 | +        return listener;  | 
 | 240 | +    }  | 
 | 241 | + | 
 | 242 | +    private PlainActionFuture<InferenceServiceResults> createCompletionFuture(Sender sender, ServiceComponents threadPool) {  | 
 | 243 | +        var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret");  | 
 | 244 | +        var actionCreator = new LlamaActionCreator(sender, threadPool);  | 
 | 245 | +        var action = actionCreator.create(model);  | 
 | 246 | + | 
 | 247 | +        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();  | 
 | 248 | +        action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);  | 
 | 249 | +        return listener;  | 
 | 250 | +    }  | 
 | 251 | + | 
 | 252 | +    private void assertCompletionRequest() throws IOException {  | 
 | 253 | +        assertCommonRequestProperties();  | 
 | 254 | + | 
 | 255 | +        var requestMap = entityAsMap(webServer.requests().get(0).getBody());  | 
 | 256 | +        assertThat(requestMap.size(), is(4));  | 
 | 257 | +        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));  | 
 | 258 | +        assertThat(requestMap.get("model"), is("model"));  | 
 | 259 | +        assertThat(requestMap.get("n"), is(1));  | 
 | 260 | +        assertThat(requestMap.get("stream"), is(false));  | 
 | 261 | +    }  | 
 | 262 | + | 
 | 263 | +    @SuppressWarnings("unchecked")  | 
 | 264 | +    private void assertEmbeddingsRequest() throws IOException {  | 
 | 265 | +        assertCommonRequestProperties();  | 
 | 266 | + | 
 | 267 | +        var requestMap = entityAsMap(webServer.requests().get(0).getBody());  | 
 | 268 | +        assertThat(requestMap.size(), is(2));  | 
 | 269 | +        assertThat(requestMap.get("contents"), instanceOf(List.class));  | 
 | 270 | +        var inputList = (List<String>) requestMap.get("contents");  | 
 | 271 | +        assertThat(inputList, contains("abc"));  | 
 | 272 | +    }  | 
 | 273 | + | 
 | 274 | +    private void assertCommonRequestProperties() {  | 
 | 275 | +        assertThat(webServer.requests(), hasSize(1));  | 
 | 276 | +        assertNull(webServer.requests().get(0).getUri().getQuery());  | 
 | 277 | +        assertThat(  | 
 | 278 | +            webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),  | 
 | 279 | +            equalTo(XContentType.JSON.mediaTypeWithoutParameters())  | 
 | 280 | +        );  | 
 | 281 | +        assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));  | 
 | 282 | +    }  | 
 | 283 | +}  | 
0 commit comments