Skip to content

Commit aae528a

Browse files
Add HuggingFaceChatCompletionResponseHandler and associated tests
1 parent bd2e601 commit aae528a

File tree

4 files changed

+316
-5
lines changed

4 files changed

+316
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.rest.RestStatus;
13+
import org.elasticsearch.xcontent.ConstructingObjectParser;
14+
import org.elasticsearch.xcontent.ParseField;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
import org.elasticsearch.xcontent.XContentParserConfiguration;
18+
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
20+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
21+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
22+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
23+
import org.elasticsearch.xpack.inference.external.request.Request;
24+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
25+
26+
import java.util.Locale;
27+
import java.util.Optional;
28+
import java.util.concurrent.Flow;
29+
30+
import static org.elasticsearch.core.Strings.format;
31+
32+
/**
33+
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
34+
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
35+
*/
36+
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
37+
38+
@Override
39+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
40+
return super.parseResult(request, flow);
41+
}
42+
43+
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
44+
super(requestType, parseFunction, HuggingFaceErrorResponse::fromResponse);
45+
}
46+
47+
@Override
48+
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
49+
assert request.isStreaming() : "Only streaming requests support this format";
50+
var responseStatusCode = result.response().getStatusLine().getStatusCode();
51+
if (request.isStreaming()) {
52+
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
53+
var restStatus = toRestStatus(responseStatusCode);
54+
return errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse
55+
? new UnifiedChatCompletionException(
56+
restStatus,
57+
errorMessage,
58+
createErrorType(errorResponse),
59+
extractErrorCode(huggingFaceErrorResponse)
60+
)
61+
: new UnifiedChatCompletionException(
62+
restStatus,
63+
errorMessage,
64+
createErrorType(errorResponse),
65+
restStatus.name().toLowerCase(Locale.ROOT)
66+
);
67+
} else {
68+
return super.buildError(message, request, result, errorResponse);
69+
}
70+
}
71+
72+
@Override
73+
protected Exception buildMidStreamError(Request request, String message, Exception e) {
74+
var errorResponse = HuggingFaceErrorResponse.fromString(message);
75+
if (errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse) {
76+
return new UnifiedChatCompletionException(
77+
RestStatus.INTERNAL_SERVER_ERROR,
78+
format(
79+
"%s for request from inference entity id [%s]. Error message: [%s]",
80+
SERVER_ERROR_OBJECT,
81+
request.getInferenceEntityId(),
82+
errorResponse.getErrorMessage()
83+
),
84+
createErrorType(errorResponse),
85+
extractErrorCode(huggingFaceErrorResponse)
86+
);
87+
} else if (e != null) {
88+
return UnifiedChatCompletionException.fromThrowable(e);
89+
} else {
90+
return new UnifiedChatCompletionException(
91+
RestStatus.INTERNAL_SERVER_ERROR,
92+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
93+
createErrorType(errorResponse),
94+
"stream_error"
95+
);
96+
}
97+
}
98+
99+
private static String extractErrorCode(HuggingFaceErrorResponse huggingFaceErrorResponse) {
100+
return huggingFaceErrorResponse.httpStatusCode() != null ? String.valueOf(huggingFaceErrorResponse.httpStatusCode()) : null;
101+
}
102+
103+
private static class HuggingFaceErrorResponse extends ErrorResponse {
104+
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
105+
"hugging_face_error",
106+
true,
107+
args -> Optional.ofNullable((HuggingFaceErrorResponse) args[0])
108+
);
109+
private static final ConstructingObjectParser<HuggingFaceErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
110+
"hugging_face_error",
111+
true,
112+
args -> new HuggingFaceErrorResponse((String) args[0], (Integer) args[1])
113+
);
114+
115+
static {
116+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
117+
ERROR_BODY_PARSER.declareIntOrNull(ConstructingObjectParser.optionalConstructorArg(), -1, new ParseField("http_status_code"));
118+
119+
ERROR_PARSER.declareObjectOrNull(
120+
ConstructingObjectParser.optionalConstructorArg(),
121+
ERROR_BODY_PARSER,
122+
null,
123+
new ParseField("error")
124+
);
125+
}
126+
127+
private static ErrorResponse fromResponse(HttpResult response) {
128+
try (
129+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
130+
.createParser(XContentParserConfiguration.EMPTY, response.body())
131+
) {
132+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
133+
} catch (Exception e) {
134+
// swallow the error
135+
}
136+
137+
return ErrorResponse.UNDEFINED_ERROR;
138+
}
139+
140+
private static ErrorResponse fromString(String response) {
141+
try (
142+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
143+
.createParser(XContentParserConfiguration.EMPTY, response)
144+
) {
145+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
146+
} catch (Exception e) {
147+
// swallow the error
148+
}
149+
150+
return ErrorResponse.UNDEFINED_ERROR;
151+
}
152+
153+
@Nullable
154+
private final Integer httpStatusCode;
155+
156+
HuggingFaceErrorResponse(String errorMessage, @Nullable Integer httpStatusCode) {
157+
super(errorMessage);
158+
this.httpStatusCode = httpStatusCode;
159+
}
160+
161+
@Nullable
162+
public Integer httpStatusCode() {
163+
return httpStatusCode;
164+
}
165+
166+
}
167+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
4141
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
4242
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
43-
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
4443
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
4544
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4645
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -68,7 +67,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
6867
TaskType.COMPLETION,
6968
TaskType.CHAT_COMPLETION
7069
);
71-
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
70+
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
7271
"hugging face chat completion",
7372
OpenAiChatCompletionResponseEntity::fromResponse
7473
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Objects;
3030
import java.util.Optional;
3131
import java.util.concurrent.Flow;
32+
import java.util.function.Function;
3233

3334
import static org.elasticsearch.core.Strings.format;
3435

@@ -37,6 +38,14 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
3738
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
3839
}
3940

41+
public OpenAiUnifiedChatCompletionResponseHandler(
42+
String requestType,
43+
ResponseParser parseFunction,
44+
Function<HttpResult, ErrorResponse> errorParseFunction
45+
) {
46+
super(requestType, parseFunction, errorParseFunction);
47+
}
48+
4049
@Override
4150
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
4251
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
@@ -59,15 +68,19 @@ protected Exception buildError(String message, Request request, HttpResult resul
5968
: new UnifiedChatCompletionException(
6069
restStatus,
6170
errorMessage,
62-
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
71+
createErrorType(errorResponse),
6372
restStatus.name().toLowerCase(Locale.ROOT)
6473
);
6574
} else {
6675
return super.buildError(message, request, result, errorResponse);
6776
}
6877
}
6978

70-
private static Exception buildMidStreamError(Request request, String message, Exception e) {
79+
protected static String createErrorType(ErrorResponse errorResponse) {
80+
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
81+
}
82+
83+
protected Exception buildMidStreamError(Request request, String message, Exception e) {
7184
var errorResponse = OpenAiErrorResponse.fromString(message);
7285
if (errorResponse instanceof OpenAiErrorResponse oer) {
7386
return new UnifiedChatCompletionException(
@@ -88,7 +101,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
88101
return new UnifiedChatCompletionException(
89102
RestStatus.INTERNAL_SERVER_ERROR,
90103
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
91-
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
104+
createErrorType(errorResponse),
92105
"stream_error"
93106
);
94107
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.apache.http.StatusLine;
12+
import org.elasticsearch.common.bytes.BytesReference;
13+
import org.elasticsearch.common.xcontent.XContentHelper;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
19+
import org.elasticsearch.xpack.inference.external.request.Request;
20+
21+
import java.io.IOException;
22+
import java.nio.charset.StandardCharsets;
23+
24+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
25+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
26+
import static org.hamcrest.Matchers.is;
27+
import static org.hamcrest.Matchers.isA;
28+
import static org.mockito.Mockito.mock;
29+
import static org.mockito.Mockito.when;
30+
31+
public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
32+
private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler(
33+
"chat completions",
34+
(a, b) -> mock()
35+
);
36+
37+
public void testFailValidationWithAllFields() throws IOException {
38+
var responseJson = """
39+
{
40+
"error": {
41+
"message": "a message",
42+
"http_status_code": 422
43+
}
44+
}
45+
""";
46+
47+
var errorJson = invalidResponseJson(responseJson);
48+
49+
assertThat(errorJson, is("""
50+
{"error":{"code":"422","message":"Received a server error status code for request from inference entity id [id] status [500]. \
51+
Error message: [a message]","type":"HuggingFaceErrorResponse"}}"""));
52+
}
53+
54+
public void testFailValidationWithoutOptionalFields() throws IOException {
55+
var responseJson = """
56+
{
57+
"error": {
58+
"message": "a message"
59+
}
60+
}
61+
""";
62+
63+
var errorJson = invalidResponseJson(responseJson);
64+
65+
assertThat(errorJson, is("""
66+
{"error":{"message":"Received a server error status code for request from inference entity id [id] status [500]. \
67+
Error message: [a message]","type":"HuggingFaceErrorResponse"}}"""));
68+
}
69+
70+
public void testFailValidationWithInvalidJson() throws IOException {
71+
var responseJson = """
72+
what? this isn't a json
73+
""";
74+
75+
var errorJson = invalidResponseJson(responseJson);
76+
77+
assertThat(errorJson, is("""
78+
{"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\
79+
[500]","type":"ErrorResponse"}}"""));
80+
}
81+
82+
private String invalidResponseJson(String responseJson) throws IOException {
83+
var exception = invalidResponse(responseJson);
84+
assertThat(exception, isA(RetryException.class));
85+
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
86+
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
87+
}
88+
89+
private Exception invalidResponse(String responseJson) {
90+
return expectThrows(
91+
RetryException.class,
92+
() -> responseHandler.validateResponse(
93+
mock(),
94+
mock(),
95+
mockRequest(),
96+
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
97+
true
98+
)
99+
);
100+
}
101+
102+
private static Request mockRequest() {
103+
var request = mock(Request.class);
104+
when(request.getInferenceEntityId()).thenReturn("id");
105+
when(request.isStreaming()).thenReturn(true);
106+
return request;
107+
}
108+
109+
private static HttpResponse mock500Response() {
110+
int statusCode = 500;
111+
var statusLine = mock(StatusLine.class);
112+
when(statusLine.getStatusCode()).thenReturn(statusCode);
113+
114+
var response = mock(HttpResponse.class);
115+
when(response.getStatusLine()).thenReturn(statusLine);
116+
117+
return response;
118+
}
119+
120+
private String toJson(UnifiedChatCompletionException e) throws IOException {
121+
try (var builder = XContentFactory.jsonBuilder()) {
122+
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
123+
try {
124+
xContent.toXContent(builder, EMPTY_PARAMS);
125+
} catch (IOException ex) {
126+
throw new RuntimeException(ex);
127+
}
128+
});
129+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
130+
}
131+
}
132+
}

0 commit comments

Comments
 (0)