|
27 | 27 | import org.elasticsearch.inference.MinimalServiceSettings; |
28 | 28 | import org.elasticsearch.inference.Model; |
29 | 29 | import org.elasticsearch.inference.TaskType; |
| 30 | +import org.elasticsearch.inference.UnifiedCompletionRequest; |
30 | 31 | import org.elasticsearch.test.ESTestCase; |
31 | 32 | import org.elasticsearch.test.http.MockResponse; |
32 | 33 | import org.elasticsearch.test.http.MockWebServer; |
33 | 34 | import org.elasticsearch.threadpool.ThreadPool; |
34 | 35 | import org.elasticsearch.xcontent.ToXContent; |
| 36 | +import org.elasticsearch.xcontent.XContentFactory; |
35 | 37 | import org.elasticsearch.xcontent.XContentType; |
36 | 38 | import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
37 | 39 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; |
| 40 | +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; |
38 | 41 | import org.elasticsearch.xpack.core.ml.search.WeightedToken; |
39 | 42 | import org.elasticsearch.xpack.inference.external.http.HttpClientManager; |
40 | 43 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; |
|
44 | 47 | import org.elasticsearch.xpack.inference.logging.ThrottlerManager; |
45 | 48 | import org.elasticsearch.xpack.inference.registry.ModelRegistry; |
46 | 49 | import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; |
| 50 | +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; |
47 | 51 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
48 | 52 | import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization; |
49 | 53 | import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; |
50 | 54 | import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationTests; |
| 55 | +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; |
| 56 | +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; |
51 | 57 | import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; |
| 58 | +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; |
52 | 59 | import org.hamcrest.MatcherAssert; |
53 | 60 | import org.hamcrest.Matchers; |
54 | 61 | import org.junit.After; |
|
61 | 68 | import java.util.Map; |
62 | 69 | import java.util.concurrent.TimeUnit; |
63 | 70 |
|
| 71 | +import static org.elasticsearch.ExceptionsHelper.unwrapCause; |
64 | 72 | import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; |
65 | 73 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; |
| 74 | +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; |
66 | 75 | import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; |
67 | 76 | import static org.elasticsearch.xpack.inference.Utils.getModelListenerForException; |
68 | 77 | import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; |
|
76 | 85 | import static org.hamcrest.CoreMatchers.is; |
77 | 86 | import static org.hamcrest.Matchers.equalTo; |
78 | 87 | import static org.hamcrest.Matchers.hasSize; |
| 88 | +import static org.hamcrest.Matchers.isA; |
79 | 89 | import static org.mockito.ArgumentMatchers.any; |
80 | 90 | import static org.mockito.Mockito.doAnswer; |
81 | 91 | import static org.mockito.Mockito.mock; |
@@ -949,6 +959,62 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo |
949 | 959 | } |
950 | 960 | } |
951 | 961 |
|
/**
 * Verifies how a failed unified-completion call is reported: the mock gateway answers with HTTP 404
 * and a JSON {@code error} body, and the test expects the returned result to carry a
 * {@code UnifiedChatCompletionException} whose chunked XContent serializes to the exact JSON below.
 */
| 962 | + public void testUnifiedCompletionError() throws Exception { |
| 963 | + var eisGatewayUrl = getUrl(webServer); |
| 964 | + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); |
| 965 | + try (var service = createService(senderFactory, eisGatewayUrl)) { |
// Error payload the mock web server returns for the request issued below.
| 966 | + var responseJson = """ |
| 967 | + { |
| 968 | + "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it." |
| 969 | + }"""; |
| 970 | + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); |
// Completion model with inference entity id "id", configured against the mock gateway URL.
// The id is echoed in the expected error message asserted at the end of the test.
| 971 | + var model = new ElasticInferenceServiceCompletionModel( |
| 972 | + "id", |
| 973 | + TaskType.COMPLETION, |
| 974 | + "elastic", |
| 975 | + new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), |
| 976 | + EmptyTaskSettings.INSTANCE, |
| 977 | + EmptySecretSettings.INSTANCE, |
| 978 | + new ElasticInferenceServiceComponents(eisGatewayUrl) |
| 979 | + ); |
| 980 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
// Send a single "user" message with content "hello" through the unified completion entry point.
| 981 | + service.unifiedCompletionInfer( |
| 982 | + model, |
| 983 | + UnifiedCompletionRequest.of( |
| 984 | + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) |
| 985 | + ), |
| 986 | + InferenceAction.Request.DEFAULT_TIMEOUT, |
| 987 | + listener |
| 988 | + ); |
| 989 | + |
// actionGet returns a result here rather than throwing; the failure is inspected on that result
// through the InferenceEventsAssertion chain below.
| 990 | + var result = listener.actionGet(TIMEOUT); |
| 991 | + |
| 992 | + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { |
// Strip any wrapper exception (ExceptionsHelper.unwrapCause) before checking the concrete type.
| 993 | + e = unwrapCause(e); |
| 994 | + assertThat(e, isA(UnifiedChatCompletionException.class)); |
| 995 | + try (var builder = XContentFactory.jsonBuilder()) { |
// Write every chunk of the exception's XContent into one builder. The lambda passed to
// forEachRemaining cannot throw a checked IOException, so it is rethrown as a RuntimeException.
| 996 | + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { |
| 997 | + try { |
| 998 | + xContent.toXContent(builder, EMPTY_PARAMS); |
| 999 | + } catch (IOException ex) { |
| 1000 | + throw new RuntimeException(ex); |
| 1001 | + } |
| 1002 | + }); |
| 1003 | + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); |
| 1004 | + |
// The expected JSON is a text block whose lines end in "\" continuations, so it is compared as
// a single line: code "not_found", type "error", and a message that embeds the entity id, the
// 404 status and the error text returned by the mock server.
| 1005 | + assertThat(json, is(""" |
| 1006 | + {\ |
| 1007 | + "error":{\ |
| 1008 | + "code":"not_found",\ |
| 1009 | + "message":"Received an unsuccessful status code for request from inference entity id [id] status \ |
| 1010 | + [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\ |
| 1011 | + "type":"error"\ |
| 1012 | + }}""")); |
| 1013 | + } |
| 1014 | + }); |
| 1015 | + } |
| 1016 | + } |
| 1017 | + |
/**
 * Convenience overload: delegates to {@code createServiceWithMockSender(...)} with the authorization
 * returned by {@code ElasticInferenceServiceAuthorizationTests.createEnabledAuth()}.
 * NOTE(review): presumably this yields a service whose authorization is enabled; confirm against
 * the overload and createEnabledAuth(), neither of which is visible here.
 */
952 | 1018 | private ElasticInferenceService createServiceWithMockSender() { |
953 | 1019 | return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth()); |
954 | 1020 | } |
|
0 commit comments