Skip to content

Commit cdb3c1c

Browse files
Refactor HuggingFace error handling to improve response structure and add streaming support
1 parent 2fa3dff commit cdb3c1c

File tree

5 files changed

+145
-86
lines changed

5 files changed

+145
-86
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.services.huggingface;
99

1010
import org.elasticsearch.core.Nullable;
11-
import org.elasticsearch.inference.InferenceServiceResults;
1211
import org.elasticsearch.rest.RestStatus;
1312
import org.elasticsearch.xcontent.ConstructingObjectParser;
1413
import org.elasticsearch.xcontent.ParseField;
@@ -21,11 +20,11 @@
2120
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
2221
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
2322
import org.elasticsearch.xpack.inference.external.request.Request;
23+
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
2424
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
2525

2626
import java.util.Locale;
2727
import java.util.Optional;
28-
import java.util.concurrent.Flow;
2928

3029
import static org.elasticsearch.core.Strings.format;
3130

@@ -35,13 +34,10 @@
3534
*/
3635
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
3736

38-
@Override
39-
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
40-
return super.parseResult(request, flow);
41-
}
37+
private static final String HUGGING_FACE_ERROR = "hugging_face_error";
4238

4339
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
44-
super(requestType, parseFunction, HuggingFaceErrorResponse::fromResponse);
40+
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
4541
}
4642

4743
@Override
@@ -51,12 +47,12 @@ protected Exception buildError(String message, Request request, HttpResult resul
5147
if (request.isStreaming()) {
5248
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
5349
var restStatus = toRestStatus(responseStatusCode);
54-
return errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse
50+
return errorResponse instanceof HuggingFaceErrorResponseEntity
5551
? new UnifiedChatCompletionException(
5652
restStatus,
5753
errorMessage,
58-
createErrorType(errorResponse),
59-
extractErrorCode(huggingFaceErrorResponse)
54+
HUGGING_FACE_ERROR,
55+
restStatus.name().toLowerCase(Locale.ROOT)
6056
)
6157
: new UnifiedChatCompletionException(
6258
restStatus,
@@ -71,8 +67,8 @@ protected Exception buildError(String message, Request request, HttpResult resul
7167

7268
@Override
7369
protected Exception buildMidStreamError(Request request, String message, Exception e) {
74-
var errorResponse = HuggingFaceErrorResponse.fromString(message);
75-
if (errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse) {
70+
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
71+
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
7672
return new UnifiedChatCompletionException(
7773
RestStatus.INTERNAL_SERVER_ERROR,
7874
format(
@@ -81,8 +77,8 @@ protected Exception buildMidStreamError(Request request, String message, Excepti
8177
request.getInferenceEntityId(),
8278
errorResponse.getErrorMessage()
8379
),
84-
createErrorType(errorResponse),
85-
extractErrorCode(huggingFaceErrorResponse)
80+
HUGGING_FACE_ERROR,
81+
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
8682
);
8783
} else if (e != null) {
8884
return UnifiedChatCompletionException.fromThrowable(e);
@@ -96,25 +92,40 @@ protected Exception buildMidStreamError(Request request, String message, Excepti
9692
}
9793
}
9894

99-
private static String extractErrorCode(HuggingFaceErrorResponse huggingFaceErrorResponse) {
100-
return huggingFaceErrorResponse.httpStatusCode() != null ? String.valueOf(huggingFaceErrorResponse.httpStatusCode()) : null;
95+
private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
96+
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
97+
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
98+
: null;
10199
}
102100

103-
private static class HuggingFaceErrorResponse extends ErrorResponse {
101+
/**
102+
* Represents a structured error response specifically for streaming operations
103+
* using HuggingFace APIs. This is separate from non-streaming error responses,
104+
* which are handled by {@link HuggingFaceErrorResponseEntity}.
105+
* An example error response for failed field validation for streaming operation would look like
106+
* <code>
107+
* {
108+
* "error": "Input validation error: cannot compile regex from schema",
109+
* "http_status_code": 422
110+
* }
111+
* </code>
112+
*/
113+
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
104114
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
105-
"hugging_face_error",
115+
HUGGING_FACE_ERROR,
106116
true,
107-
args -> Optional.ofNullable((HuggingFaceErrorResponse) args[0])
108-
);
109-
private static final ConstructingObjectParser<HuggingFaceErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
110-
"hugging_face_error",
111-
true,
112-
args -> new HuggingFaceErrorResponse((String) args[0], (Integer) args[1])
117+
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
113118
);
119+
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
120+
new ConstructingObjectParser<>(
121+
HUGGING_FACE_ERROR,
122+
true,
123+
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
124+
);
114125

115126
static {
116-
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
117-
ERROR_BODY_PARSER.declareIntOrNull(ConstructingObjectParser.optionalConstructorArg(), -1, new ParseField("http_status_code"));
127+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
128+
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));
118129

119130
ERROR_PARSER.declareObjectOrNull(
120131
ConstructingObjectParser.optionalConstructorArg(),
@@ -124,19 +135,12 @@ private static class HuggingFaceErrorResponse extends ErrorResponse {
124135
);
125136
}
126137

127-
private static ErrorResponse fromResponse(HttpResult response) {
128-
try (
129-
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
130-
.createParser(XContentParserConfiguration.EMPTY, response.body())
131-
) {
132-
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
133-
} catch (Exception e) {
134-
// swallow the error
135-
}
136-
137-
return ErrorResponse.UNDEFINED_ERROR;
138-
}
139-
138+
/**
139+
* Parses a streaming HuggingFace error response from a JSON string.
140+
*
141+
* @param response the raw JSON string representing an error
142+
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
143+
*/
140144
private static ErrorResponse fromString(String response) {
141145
try (
142146
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
@@ -153,7 +157,7 @@ private static ErrorResponse fromString(String response) {
153157
@Nullable
154158
private final Integer httpStatusCode;
155159

156-
HuggingFaceErrorResponse(String errorMessage, @Nullable Integer httpStatusCode) {
160+
StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
157161
super(errorMessage);
158162
this.httpStatusCode = httpStatusCode;
159163
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3636
builder.startObject();
3737
unifiedRequestEntity.toXContent(builder, params);
3838

39-
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
39+
if (model.getServiceSettings().modelId() != null) {
40+
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
41+
}
4042

4143
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
4244
builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public HuggingFaceErrorResponseEntity(String message) {
2121
}
2222

2323
/**
24+
* Represents a structured error response specifically for non-streaming operations
25+
* using HuggingFace APIs. This is separate from streaming error responses,
26+
* which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity.
2427
* An example error response for invalid auth would look like
2528
* <code>
2629
* {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,34 +37,33 @@ public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
3737
public void testFailValidationWithAllFields() throws IOException {
3838
var responseJson = """
3939
{
40-
"error": {
41-
"message": "a message",
42-
"http_status_code": 422
43-
}
40+
"error": "a message",
41+
"type": "validation"
4442
}
4543
""";
4644

4745
var errorJson = invalidResponseJson(responseJson);
4846

4947
assertThat(errorJson, is("""
50-
{"error":{"code":"422","message":"Received a server error status code for request from inference entity id [id] status [500]. \
51-
Error message: [a message]","type":"HuggingFaceErrorResponse"}}"""));
48+
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
49+
inference entity id [id] status [500]. \
50+
Error message: [a message]",\
51+
"type":"hugging_face_error"}}"""));
5252
}
5353

5454
public void testFailValidationWithoutOptionalFields() throws IOException {
5555
var responseJson = """
5656
{
57-
"error": {
58-
"message": "a message"
59-
}
57+
"error": "a message"
6058
}
6159
""";
6260

6361
var errorJson = invalidResponseJson(responseJson);
6462

6563
assertThat(errorJson, is("""
66-
{"error":{"message":"Received a server error status code for request from inference entity id [id] status [500]. \
67-
Error message: [a message]","type":"HuggingFaceErrorResponse"}}"""));
64+
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
65+
inference entity id [id] status [500]. \
66+
Error message: [a message]","type":"hugging_face_error"}}"""));
6867
}
6968

7069
public void testFailValidationWithInvalidJson() throws IOException {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,12 @@ public void testUnifiedCompletionInfer() throws Exception {
347347
}
348348
}
349349

350-
public void testUnifiedCompletionError() throws Exception {
350+
public void testUnifiedCompletionNonStreamingError() throws Exception {
351351
String responseJson = """
352352
{
353-
"error": {
354-
"message": "The model `gpt-4awero` does not exist or you do not have access to it.",
355-
"http_status_code": "404"
356-
}
357-
}""";
353+
"error": "Model not found."
354+
}
355+
""";
358356
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
359357

360358
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@@ -383,10 +381,10 @@ public void testUnifiedCompletionError() throws Exception {
383381
assertThat(json, is("""
384382
{\
385383
"error":{\
386-
"code":"404",\
384+
"code":"not_found",\
387385
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
388-
[404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
389-
"type":"HuggingFaceErrorResponse"\
386+
[404]. Error message: [Model not found.]",\
387+
"type":"hugging_face_error"\
390388
}}"""));
391389
} catch (IOException ex) {
392390
throw new RuntimeException(ex);
@@ -400,16 +398,89 @@ public void testUnifiedCompletionError() throws Exception {
400398
public void testMidStreamUnifiedCompletionError() throws Exception {
401399
String responseJson = """
402400
event: error
403-
data: { "error": { "message": "Timed out waiting for more data" } }
401+
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
402+
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\
403+
If it should be supported, please open an issue.","http_status_code":422}}
404+
405+
""";
406+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
407+
testStreamError("""
408+
{\
409+
"error":{\
410+
"code":"422",\
411+
"message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \
412+
cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \
413+
JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\
414+
"type":"hugging_face_error"\
415+
}}""");
416+
}
417+
418+
public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception {
419+
String responseJson = """
420+
event: error
421+
data: {"error":{"http_status_code":422}}
422+
423+
""";
424+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
425+
testStreamError("""
426+
{\
427+
"error":{\
428+
"code":"422",\
429+
"message":"Received an error response for request from inference entity id [id]. Error message: \
430+
[unknown]",\
431+
"type":"hugging_face_error"\
432+
}}""");
433+
}
434+
435+
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception {
436+
String responseJson = """
437+
event: error
438+
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
439+
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \
440+
Outlines.\\nIf it should be supported, please open an issue."}}
404441
405442
""";
406443
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
407444
testStreamError("""
408445
{\
409446
"error":{\
410447
"message":"Received an error response for request from inference entity id [id]. Error message: \
411-
[Timed out waiting for more data]",\
412-
"type":"HuggingFaceErrorResponse"\
448+
[Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
449+
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\
450+
by Outlines.\\nIf it should be supported, please open an issue.]",\
451+
"type":"hugging_face_error"\
452+
}}""");
453+
}
454+
455+
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception {
456+
String responseJson = """
457+
event: error
458+
data: {"error":{}}
459+
460+
""";
461+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
462+
testStreamError("""
463+
{\
464+
"error":{\
465+
"message":"Received an error response for request from inference entity id [id]. Error message: \
466+
[unknown]",\
467+
"type":"hugging_face_error"\
468+
}}""");
469+
}
470+
471+
public void testUnifiedCompletionMalformedError() throws Exception {
472+
String responseJson = """
473+
data: { invalid json }
474+
475+
""";
476+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
477+
testStreamError("""
478+
{\
479+
"error":{\
480+
"code":"bad_request",\
481+
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
482+
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
483+
"type":"x_content_parse_exception"\
413484
}}""");
414485
}
415486

@@ -448,22 +519,6 @@ private void testStreamError(String expectedResponse) throws Exception {
448519
}
449520
}
450521

451-
public void testUnifiedCompletionMalformedError() throws Exception {
452-
String responseJson = """
453-
data: { invalid json }
454-
455-
""";
456-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
457-
testStreamError("""
458-
{\
459-
"error":{\
460-
"code":"bad_request",\
461-
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
462-
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
463-
"type":"x_content_parse_exception"\
464-
}}""");
465-
}
466-
467522
public void testInfer_StreamRequest() throws Exception {
468523
String responseJson = """
469524
data: {\
@@ -517,10 +572,7 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
517572
String responseJson = """
518573
{
519574
"error": {
520-
"message": "You didn't provide an API key...",
521-
"type": "invalid_request_error",
522-
"param": null,
523-
"code": null
575+
"message": "You didn't provide an API key..."
524576
}
525577
}""";
526578
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
@@ -540,8 +592,7 @@ public void testInfer_StreamRequestRetry() throws Exception {
540592
webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
541593
{
542594
"error": {
543-
"message": "server busy",
544-
"type": "server_busy"
595+
"message": "server busy"
545596
}
546597
}"""));
547598
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""

0 commit comments

Comments
 (0)