| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the Elastic License  | 
 | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License  | 
 | 5 | + * 2.0.  | 
 | 6 | + */  | 
 | 7 | + | 
 | 8 | +package org.elasticsearch.xpack.inference.integration;  | 
 | 9 | + | 
 | 10 | +import org.elasticsearch.ResourceNotFoundException;  | 
 | 11 | +import org.elasticsearch.action.support.PlainActionFuture;  | 
 | 12 | +import org.elasticsearch.common.settings.Settings;  | 
 | 13 | +import org.elasticsearch.core.TimeValue;  | 
 | 14 | +import org.elasticsearch.inference.InferenceService;  | 
 | 15 | +import org.elasticsearch.inference.MinimalServiceSettings;  | 
 | 16 | +import org.elasticsearch.inference.Model;  | 
 | 17 | +import org.elasticsearch.inference.TaskType;  | 
 | 18 | +import org.elasticsearch.inference.UnparsedModel;  | 
 | 19 | +import org.elasticsearch.plugins.Plugin;  | 
 | 20 | +import org.elasticsearch.reindex.ReindexPlugin;  | 
 | 21 | +import org.elasticsearch.test.ESSingleNodeTestCase;  | 
 | 22 | +import org.elasticsearch.test.http.MockResponse;  | 
 | 23 | +import org.elasticsearch.test.http.MockWebServer;  | 
 | 24 | +import org.elasticsearch.threadpool.ThreadPool;  | 
 | 25 | +import org.elasticsearch.xpack.inference.external.http.HttpClientManager;  | 
 | 26 | +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;  | 
 | 27 | +import org.elasticsearch.xpack.inference.logging.ThrottlerManager;  | 
 | 28 | +import org.elasticsearch.xpack.inference.registry.ModelRegistry;  | 
 | 29 | +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;  | 
 | 30 | +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;  | 
 | 31 | +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;  | 
 | 32 | +import org.junit.After;  | 
 | 33 | +import org.junit.Before;  | 
 | 34 | + | 
 | 35 | +import java.util.Collection;  | 
 | 36 | +import java.util.EnumSet;  | 
 | 37 | +import java.util.List;  | 
 | 38 | +import java.util.concurrent.TimeUnit;  | 
 | 39 | + | 
 | 40 | +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;  | 
 | 41 | +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;  | 
 | 42 | +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;  | 
 | 43 | +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;  | 
 | 44 | +import static org.hamcrest.CoreMatchers.is;  | 
 | 45 | +import static org.mockito.Mockito.mock;  | 
 | 46 | + | 
 | 47 | +public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {  | 
 | 48 | +    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);  | 
 | 49 | + | 
 | 50 | +    private ModelRegistry modelRegistry;  | 
 | 51 | +    private final MockWebServer webServer = new MockWebServer();  | 
 | 52 | +    private ThreadPool threadPool;  | 
 | 53 | +    private String gatewayUrl;  | 
 | 54 | + | 
 | 55 | +    @Before  | 
 | 56 | +    public void createComponents() throws Exception {  | 
 | 57 | +        threadPool = createThreadPool(inferenceUtilityPool());  | 
 | 58 | +        webServer.start();  | 
 | 59 | +        gatewayUrl = getUrl(webServer);  | 
 | 60 | +        modelRegistry = new ModelRegistry(client());  | 
 | 61 | +    }  | 
 | 62 | + | 
 | 63 | +    @After  | 
 | 64 | +    public void shutdown() {  | 
 | 65 | +        terminate(threadPool);  | 
 | 66 | +        webServer.close();  | 
 | 67 | +    }  | 
 | 68 | + | 
 | 69 | +    @Override  | 
 | 70 | +    protected boolean resetNodeAfterTest() {  | 
 | 71 | +        return true;  | 
 | 72 | +    }  | 
 | 73 | + | 
 | 74 | +    @Override  | 
 | 75 | +    protected Collection<Class<? extends Plugin>> getPlugins() {  | 
 | 76 | +        return pluginList(ReindexPlugin.class);  | 
 | 77 | +    }  | 
 | 78 | + | 
 | 79 | +    public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {  | 
 | 80 | +        String responseJson = """  | 
 | 81 | +            {  | 
 | 82 | +                "models": [  | 
 | 83 | +                    {  | 
 | 84 | +                      "model_name": "rainbow-sprinkles",  | 
 | 85 | +                      "task_types": ["chat"]  | 
 | 86 | +                    }  | 
 | 87 | +                ]  | 
 | 88 | +            }  | 
 | 89 | +            """;  | 
 | 90 | + | 
 | 91 | +        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 92 | + | 
 | 93 | +        try (var service = createElasticInferenceService()) {  | 
 | 94 | +            service.waitForAuthorizationToComplete(TIMEOUT);  | 
 | 95 | +            assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));  | 
 | 96 | +            assertThat(  | 
 | 97 | +                service.defaultConfigIds(),  | 
 | 98 | +                is(  | 
 | 99 | +                    List.of(  | 
 | 100 | +                        new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)  | 
 | 101 | +                    )  | 
 | 102 | +                )  | 
 | 103 | +            );  | 
 | 104 | +            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));  | 
 | 105 | + | 
 | 106 | +            PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();  | 
 | 107 | +            service.defaultConfigs(listener);  | 
 | 108 | +            assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));  | 
 | 109 | +        }  | 
 | 110 | +    }  | 
 | 111 | + | 
 | 112 | +    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {  | 
 | 113 | +        {  | 
 | 114 | +            String responseJson = """  | 
 | 115 | +                {  | 
 | 116 | +                    "models": [  | 
 | 117 | +                        {  | 
 | 118 | +                          "model_name": "rainbow-sprinkles",  | 
 | 119 | +                          "task_types": ["chat"]  | 
 | 120 | +                        }  | 
 | 121 | +                    ]  | 
 | 122 | +                }  | 
 | 123 | +                """;  | 
 | 124 | + | 
 | 125 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 126 | + | 
 | 127 | +            try (var service = createElasticInferenceService()) {  | 
 | 128 | +                service.waitForAuthorizationToComplete(TIMEOUT);  | 
 | 129 | +                assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));  | 
 | 130 | +                assertThat(  | 
 | 131 | +                    service.defaultConfigIds(),  | 
 | 132 | +                    is(  | 
 | 133 | +                        List.of(  | 
 | 134 | +                            new InferenceService.DefaultConfigId(  | 
 | 135 | +                                ".rainbow-sprinkles-elastic",  | 
 | 136 | +                                MinimalServiceSettings.chatCompletion(),  | 
 | 137 | +                                service  | 
 | 138 | +                            )  | 
 | 139 | +                        )  | 
 | 140 | +                    )  | 
 | 141 | +                );  | 
 | 142 | +                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));  | 
 | 143 | + | 
 | 144 | +                PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();  | 
 | 145 | +                service.defaultConfigs(listener);  | 
 | 146 | +                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));  | 
 | 147 | + | 
 | 148 | +                var getModelListener = new PlainActionFuture<UnparsedModel>();  | 
 | 149 | +                // persists the default endpoints  | 
 | 150 | +                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);  | 
 | 151 | + | 
 | 152 | +                var inferenceEntity = getModelListener.actionGet(TIMEOUT);  | 
 | 153 | +                assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));  | 
 | 154 | +                assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));  | 
 | 155 | +            }  | 
 | 156 | +        }  | 
 | 157 | +        {  | 
 | 158 | +            String noAuthorizationResponseJson = """  | 
 | 159 | +                {  | 
 | 160 | +                    "models": []  | 
 | 161 | +                }  | 
 | 162 | +                """;  | 
 | 163 | + | 
 | 164 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));  | 
 | 165 | + | 
 | 166 | +            try (var service = createElasticInferenceService()) {  | 
 | 167 | +                service.waitForAuthorizationToComplete(TIMEOUT);  | 
 | 168 | +                assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));  | 
 | 169 | +                assertTrue(service.defaultConfigIds().isEmpty());  | 
 | 170 | +                assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));  | 
 | 171 | + | 
 | 172 | +                var getModelListener = new PlainActionFuture<UnparsedModel>();  | 
 | 173 | +                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);  | 
 | 174 | + | 
 | 175 | +                var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));  | 
 | 176 | +                assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));  | 
 | 177 | +            }  | 
 | 178 | +        }  | 
 | 179 | +    }  | 
 | 180 | + | 
 | 181 | +    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {  | 
 | 182 | +        {  | 
 | 183 | +            String responseJson = """  | 
 | 184 | +                {  | 
 | 185 | +                    "models": [  | 
 | 186 | +                        {  | 
 | 187 | +                          "model_name": "rainbow-sprinkles",  | 
 | 188 | +                          "task_types": ["chat"]  | 
 | 189 | +                        },  | 
 | 190 | +                        {  | 
 | 191 | +                          "model_name": "elser-v2",  | 
 | 192 | +                          "task_types": ["embed/text/sparse"]  | 
 | 193 | +                        }  | 
 | 194 | +                    ]  | 
 | 195 | +                }  | 
 | 196 | +                """;  | 
 | 197 | + | 
 | 198 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));  | 
 | 199 | + | 
 | 200 | +            try (var service = createElasticInferenceService()) {  | 
 | 201 | +                service.waitForAuthorizationToComplete(TIMEOUT);  | 
 | 202 | +                assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));  | 
 | 203 | +                assertThat(  | 
 | 204 | +                    service.defaultConfigIds(),  | 
 | 205 | +                    is(  | 
 | 206 | +                        List.of(  | 
 | 207 | +                            new InferenceService.DefaultConfigId(  | 
 | 208 | +                                ".rainbow-sprinkles-elastic",  | 
 | 209 | +                                MinimalServiceSettings.chatCompletion(),  | 
 | 210 | +                                service  | 
 | 211 | +                            )  | 
 | 212 | +                        )  | 
 | 213 | +                    )  | 
 | 214 | +                );  | 
 | 215 | +                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));  | 
 | 216 | + | 
 | 217 | +                PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();  | 
 | 218 | +                service.defaultConfigs(listener);  | 
 | 219 | +                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));  | 
 | 220 | + | 
 | 221 | +                var getModelListener = new PlainActionFuture<UnparsedModel>();  | 
 | 222 | +                // persists the default endpoints  | 
 | 223 | +                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);  | 
 | 224 | + | 
 | 225 | +                var inferenceEntity = getModelListener.actionGet(TIMEOUT);  | 
 | 226 | +                assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));  | 
 | 227 | +                assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));  | 
 | 228 | +            }  | 
 | 229 | +        }  | 
 | 230 | +        {  | 
 | 231 | +            String noAuthorizationResponseJson = """  | 
 | 232 | +                {  | 
 | 233 | +                    "models": [  | 
 | 234 | +                        {  | 
 | 235 | +                          "model_name": "elser-v2",  | 
 | 236 | +                          "task_types": ["embed/text/sparse"]  | 
 | 237 | +                        }  | 
 | 238 | +                    ]  | 
 | 239 | +                }  | 
 | 240 | +                """;  | 
 | 241 | + | 
 | 242 | +            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));  | 
 | 243 | + | 
 | 244 | +            try (var service = createElasticInferenceService()) {  | 
 | 245 | +                service.waitForAuthorizationToComplete(TIMEOUT);  | 
 | 246 | +                assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));  | 
 | 247 | +                assertTrue(service.defaultConfigIds().isEmpty());  | 
 | 248 | +                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));  | 
 | 249 | + | 
 | 250 | +                var getModelListener = new PlainActionFuture<UnparsedModel>();  | 
 | 251 | +                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);  | 
 | 252 | + | 
 | 253 | +                var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));  | 
 | 254 | +                assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));  | 
 | 255 | +            }  | 
 | 256 | +        }  | 
 | 257 | +    }  | 
 | 258 | + | 
 | 259 | +    private ElasticInferenceService createElasticInferenceService() {  | 
 | 260 | +        var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));  | 
 | 261 | +        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);  | 
 | 262 | + | 
 | 263 | +        return new ElasticInferenceService(  | 
 | 264 | +            senderFactory,  | 
 | 265 | +            createWithEmptySettings(threadPool),  | 
 | 266 | +            new ElasticInferenceServiceComponents(gatewayUrl),  | 
 | 267 | +            modelRegistry,  | 
 | 268 | +            new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)  | 
 | 269 | +        );  | 
 | 270 | +    }  | 
 | 271 | +}  | 
0 commit comments