Skip to content

Commit 604d441

Browse files
Enhance LlamaChatCompletionResponseHandler to support mid-stream error handling and improve error response parsing
1 parent 8e7ca13 commit 604d441

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77

88
package org.elasticsearch.xpack.inference.services.llama.completion;
99

10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xcontent.ConstructingObjectParser;
12+
import org.elasticsearch.xcontent.ParseField;
13+
import org.elasticsearch.xcontent.XContentFactory;
14+
import org.elasticsearch.xcontent.XContentParser;
15+
import org.elasticsearch.xcontent.XContentParserConfiguration;
16+
import org.elasticsearch.xcontent.XContentType;
1017
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1118
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1219
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
@@ -16,15 +23,37 @@
1623
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
1724

1825
import java.util.Locale;
26+
import java.util.Optional;
1927

28+
import static org.elasticsearch.core.Strings.format;
29+
30+
/**
31+
* Handles streaming chat completion responses and error parsing for Llama inference endpoints.
32+
* This handler is designed to work with the unified Llama chat completion API.
33+
*/
2034
public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
2135

2236
    // Error `type` reported for errors originating from the Llama integration.
    private static final String LLAMA_ERROR = "llama_error";
    // Error `code` used when the failure occurred mid-stream rather than on the initial response.
    private static final String STREAM_ERROR = "stream_error";
2338

39+
/**
40+
* Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser.
41+
*
42+
* @param requestType the type of request this handler will process
43+
* @param parseFunction the function to parse the response
44+
*/
2445
public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
2546
super(requestType, parseFunction, LlamaErrorResponse::fromResponse);
2647
}
2748

49+
    /**
     * Builds an exception for error responses returned by Llama chat completion requests.
     *
     * @param message the error message to include in the exception
     * @param request the request that caused the error
     * @param result the HTTP result containing the response
     * @param errorResponse the error response parsed from the HTTP result
     * @return an exception representing the error, specific to Llama chat completion
     */
2857
@Override
2958
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
3059
assert request.isStreaming() : "Only streaming requests support this format";
@@ -44,4 +73,90 @@ protected Exception buildError(String message, Request request, HttpResult resul
4473
return super.buildError(message, request, result, errorResponse);
4574
}
4675
}
76+
77+
/**
78+
* Builds an exception for mid-stream errors encountered during Llama chat completion requests.
79+
*
80+
* @param request the request that caused the error
81+
* @param message the error message
82+
* @param e the exception that occurred, if any
83+
* @return a UnifiedChatCompletionException representing the error
84+
*/
85+
@Override
86+
protected Exception buildMidStreamError(Request request, String message, Exception e) {
87+
var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message);
88+
if (errorResponse instanceof StreamingLlamaErrorResponseEntity) {
89+
return new UnifiedChatCompletionException(
90+
RestStatus.INTERNAL_SERVER_ERROR,
91+
format(
92+
"%s for request from inference entity id [%s]. Error message: [%s]",
93+
SERVER_ERROR_OBJECT,
94+
request.getInferenceEntityId(),
95+
errorResponse.getErrorMessage()
96+
),
97+
LLAMA_ERROR,
98+
STREAM_ERROR
99+
);
100+
} else if (e != null) {
101+
return UnifiedChatCompletionException.fromThrowable(e);
102+
} else {
103+
return new UnifiedChatCompletionException(
104+
RestStatus.INTERNAL_SERVER_ERROR,
105+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
106+
createErrorType(errorResponse),
107+
STREAM_ERROR
108+
);
109+
}
110+
}
111+
112+
private static class StreamingLlamaErrorResponseEntity extends ErrorResponse {
113+
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
114+
LLAMA_ERROR,
115+
true,
116+
args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0])
117+
);
118+
private static final ConstructingObjectParser<
119+
LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity,
120+
Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
121+
LLAMA_ERROR,
122+
true,
123+
args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity(
124+
args[0] != null ? (String) args[0] : "unknown"
125+
)
126+
);
127+
128+
static {
129+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
130+
131+
ERROR_PARSER.declareObjectOrNull(
132+
ConstructingObjectParser.optionalConstructorArg(),
133+
ERROR_BODY_PARSER,
134+
null,
135+
new ParseField("error")
136+
);
137+
}
138+
139+
/**
140+
* Parses a streaming Llama error response from a JSON string.
141+
*
142+
* @param response the raw JSON string representing an error
143+
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
144+
*/
145+
private static ErrorResponse fromString(String response) {
146+
try (
147+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
148+
.createParser(XContentParserConfiguration.EMPTY, response)
149+
) {
150+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
151+
} catch (Exception e) {
152+
// swallow the error
153+
}
154+
155+
return ErrorResponse.UNDEFINED_ERROR;
156+
}
157+
158+
StreamingLlamaErrorResponseEntity(String errorMessage) {
159+
super(errorMessage);
160+
}
161+
}
47162
}

0 commit comments

Comments
 (0)