Skip to content

Commit 8bdd23e

Browse files
committed
address comments
1 parent 924253a commit 8bdd23e

File tree

5 files changed

+107
-69
lines changed

5 files changed

+107
-69
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/UnifiedChatCompletionException.java

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.ElasticsearchWrapperException;
1312
import org.elasticsearch.ExceptionsHelper;
1413
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
1514
import org.elasticsearch.core.Nullable;
@@ -18,6 +17,7 @@
1817

1918
import java.util.Iterator;
2019
import java.util.Locale;
20+
import java.util.Objects;
2121

2222
import static java.util.Collections.emptyIterator;
2323
import static org.elasticsearch.ExceptionsHelper.maybeError;
@@ -41,8 +41,8 @@ public UnifiedChatCompletionException(RestStatus status, String message, String
4141

4242
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code, @Nullable String param) {
4343
super(message, status);
44-
this.message = message;
45-
this.type = type;
44+
this.message = Objects.requireNonNull(message);
45+
this.type = Objects.requireNonNull(type);
4646
this.code = code;
4747
this.param = param;
4848
}
@@ -56,8 +56,8 @@ public UnifiedChatCompletionException(
5656
@Nullable String param
5757
) {
5858
super(message, cause, status);
59-
this.message = message;
60-
this.type = type;
59+
this.message = Objects.requireNonNull(message);
60+
this.type = Objects.requireNonNull(type);
6161
this.code = code;
6262
this.param = param;
6363
}
@@ -85,9 +85,7 @@ private static Iterator<ToXContent> optionalField(String key, String value) {
8585
}
8686

8787
public static UnifiedChatCompletionException fromThrowable(Throwable t) {
88-
if (t instanceof UnifiedChatCompletionException e) {
89-
return e;
90-
} else if (unwrapCause(t) instanceof UnifiedChatCompletionException e) {
88+
if (ExceptionsHelper.unwrapCause(t) instanceof UnifiedChatCompletionException e) {
9189
return e;
9290
} else {
9391
return maybeError(t).map(error -> {
@@ -116,25 +114,4 @@ public static UnifiedChatCompletionException fromThrowable(Throwable t) {
116114
});
117115
}
118116
}
119-
120-
private static Throwable unwrapCause(Throwable t) {
121-
int counter = 0;
122-
Throwable result = t;
123-
while (result instanceof ElasticsearchWrapperException) {
124-
if (result instanceof UnifiedChatCompletionException) {
125-
return result;
126-
}
127-
if (result.getCause() == null) {
128-
return result;
129-
}
130-
if (result.getCause() == result) {
131-
return result;
132-
}
133-
if (counter++ > 10) {
134-
return result;
135-
}
136-
result = result.getCause();
137-
}
138-
return result;
139-
}
140117
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99

1010
import org.elasticsearch.inference.InferenceServiceResults;
1111
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
12+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1213
import org.elasticsearch.xpack.inference.external.http.HttpResult;
14+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1315
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1416
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
1517
import org.elasticsearch.xpack.inference.external.request.Request;
1618
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
1719
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1820

21+
import java.util.Locale;
1922
import java.util.concurrent.Flow;
2023

2124
public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
@@ -37,4 +40,21 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
3740
serverSentEventProcessor.subscribe(openAiProcessor);
3841
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
3942
}
43+
44+
@Override
45+
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
46+
assert request.isStreaming() : "Only streaming requests support this format";
47+
var responseStatusCode = result.response().getStatusLine().getStatusCode();
48+
if (request.isStreaming()) {
49+
var restStatus = toRestStatus(responseStatusCode);
50+
return new UnifiedChatCompletionException(
51+
restStatus,
52+
errorMessage(message, request, result, errorResponse, responseStatusCode),
53+
"error",
54+
restStatus.name().toLowerCase(Locale.ROOT)
55+
);
56+
} else {
57+
return super.buildError(message, request, result, errorResponse);
58+
}
59+
}
4060
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,31 +91,24 @@ protected Exception buildError(String message, Request request, HttpResult resul
9191

9292
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
9393
var responseStatusCode = result.response().getStatusLine().getStatusCode();
94+
return new ElasticsearchStatusException(
95+
errorMessage(message, request, result, errorResponse, responseStatusCode),
96+
toRestStatus(responseStatusCode)
97+
);
98+
}
9499

95-
if (errorResponse == null
100+
protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) {
101+
return (errorResponse == null
96102
|| errorResponse.errorStructureFound() == false
97-
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage())) {
98-
return new ElasticsearchStatusException(
99-
format(
100-
"%s for request from inference entity id [%s] status [%s]",
103+
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
104+
? format("%s for request from inference entity id [%s] status [%s]", message, request.getInferenceEntityId(), statusCode)
105+
: format(
106+
"%s for request from inference entity id [%s] status [%s]. Error message: [%s]",
101107
message,
102108
request.getInferenceEntityId(),
103-
responseStatusCode
104-
),
105-
toRestStatus(responseStatusCode)
106-
);
107-
}
108-
109-
return new ElasticsearchStatusException(
110-
format(
111-
"%s for request from inference entity id [%s] status [%s]. Error message: [%s]",
112-
message,
113-
request.getInferenceEntityId(),
114-
responseStatusCode,
115-
errorResponse.getErrorMessage()
116-
),
117-
toRestStatus(responseStatusCode)
118-
);
109+
statusCode,
110+
errorResponse.getErrorMessage()
111+
);
119112
}
120113

121114
public static RestStatus toRestStatus(int statusCode) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.openai;
99

10-
import org.elasticsearch.common.Strings;
1110
import org.elasticsearch.core.Nullable;
1211
import org.elasticsearch.inference.InferenceServiceResults;
1312
import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -30,8 +29,6 @@
3029
import java.util.Optional;
3130
import java.util.concurrent.Flow;
3231

33-
import static org.elasticsearch.core.Strings.format;
34-
3532
public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
3633
public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
3734
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
@@ -52,22 +49,7 @@ protected Exception buildError(String message, Request request, HttpResult resul
5249
assert request.isStreaming() : "Only streaming requests support this format";
5350
var responseStatusCode = result.response().getStatusLine().getStatusCode();
5451
if (request.isStreaming()) {
55-
var errorMessage = (errorResponse == null
56-
|| errorResponse.errorStructureFound() == false
57-
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
58-
? format(
59-
"%s for request from inference entity id [%s] status [%s]",
60-
message,
61-
request.getInferenceEntityId(),
62-
responseStatusCode
63-
)
64-
: format(
65-
"%s for request from inference entity id [%s] status [%s]. Error message: [%s]",
66-
message,
67-
request.getInferenceEntityId(),
68-
responseStatusCode,
69-
errorResponse.getErrorMessage()
70-
);
52+
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
7153
var restStatus = toRestStatus(responseStatusCode);
7254
return errorResponse instanceof OpenAiErrorResponse oer
7355
? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param())

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@
2727
import org.elasticsearch.inference.MinimalServiceSettings;
2828
import org.elasticsearch.inference.Model;
2929
import org.elasticsearch.inference.TaskType;
30+
import org.elasticsearch.inference.UnifiedCompletionRequest;
3031
import org.elasticsearch.test.ESTestCase;
3132
import org.elasticsearch.test.http.MockResponse;
3233
import org.elasticsearch.test.http.MockWebServer;
3334
import org.elasticsearch.threadpool.ThreadPool;
3435
import org.elasticsearch.xcontent.ToXContent;
36+
import org.elasticsearch.xcontent.XContentFactory;
3537
import org.elasticsearch.xcontent.XContentType;
3638
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3739
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
40+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
3841
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
3942
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
4043
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -44,11 +47,15 @@
4447
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
4548
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
4649
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
50+
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
4751
import org.elasticsearch.xpack.inference.services.ServiceFields;
4852
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization;
4953
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
5054
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationTests;
55+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
56+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
5157
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
58+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
5259
import org.hamcrest.MatcherAssert;
5360
import org.hamcrest.Matchers;
5461
import org.junit.After;
@@ -61,8 +68,10 @@
6168
import java.util.Map;
6269
import java.util.concurrent.TimeUnit;
6370

71+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
6472
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
6573
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
74+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
6675
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
6776
import static org.elasticsearch.xpack.inference.Utils.getModelListenerForException;
6877
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
@@ -76,6 +85,7 @@
7685
import static org.hamcrest.CoreMatchers.is;
7786
import static org.hamcrest.Matchers.equalTo;
7887
import static org.hamcrest.Matchers.hasSize;
88+
import static org.hamcrest.Matchers.isA;
7989
import static org.mockito.ArgumentMatchers.any;
8090
import static org.mockito.Mockito.doAnswer;
8191
import static org.mockito.Mockito.mock;
@@ -949,6 +959,62 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
949959
}
950960
}
951961

962+
public void testUnifiedCompletionError() throws Exception {
963+
var eisGatewayUrl = getUrl(webServer);
964+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
965+
try (var service = createService(senderFactory, eisGatewayUrl)) {
966+
var responseJson = """
967+
{
968+
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
969+
}""";
970+
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
971+
var model = new ElasticInferenceServiceCompletionModel(
972+
"id",
973+
TaskType.COMPLETION,
974+
"elastic",
975+
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
976+
EmptyTaskSettings.INSTANCE,
977+
EmptySecretSettings.INSTANCE,
978+
new ElasticInferenceServiceComponents(eisGatewayUrl)
979+
);
980+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
981+
service.unifiedCompletionInfer(
982+
model,
983+
UnifiedCompletionRequest.of(
984+
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
985+
),
986+
InferenceAction.Request.DEFAULT_TIMEOUT,
987+
listener
988+
);
989+
990+
var result = listener.actionGet(TIMEOUT);
991+
992+
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
993+
e = unwrapCause(e);
994+
assertThat(e, isA(UnifiedChatCompletionException.class));
995+
try (var builder = XContentFactory.jsonBuilder()) {
996+
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
997+
try {
998+
xContent.toXContent(builder, EMPTY_PARAMS);
999+
} catch (IOException ex) {
1000+
throw new RuntimeException(ex);
1001+
}
1002+
});
1003+
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
1004+
1005+
assertThat(json, is("""
1006+
{\
1007+
"error":{\
1008+
"code":"not_found",\
1009+
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
1010+
[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
1011+
"type":"error"\
1012+
}}"""));
1013+
}
1014+
});
1015+
}
1016+
}
1017+
9521018
private ElasticInferenceService createServiceWithMockSender() {
9531019
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
9541020
}

0 commit comments

Comments
 (0)