77
88package org .elasticsearch .xpack .inference .services .llama .completion ;
99
10+ import org .elasticsearch .rest .RestStatus ;
11+ import org .elasticsearch .xcontent .ConstructingObjectParser ;
12+ import org .elasticsearch .xcontent .ParseField ;
13+ import org .elasticsearch .xcontent .XContentFactory ;
14+ import org .elasticsearch .xcontent .XContentParser ;
15+ import org .elasticsearch .xcontent .XContentParserConfiguration ;
16+ import org .elasticsearch .xcontent .XContentType ;
1017import org .elasticsearch .xpack .core .inference .results .UnifiedChatCompletionException ;
1118import org .elasticsearch .xpack .inference .external .http .HttpResult ;
1219import org .elasticsearch .xpack .inference .external .http .retry .ErrorResponse ;
1623import org .elasticsearch .xpack .inference .services .openai .OpenAiUnifiedChatCompletionResponseHandler ;
1724
1825import java .util .Locale ;
26+ import java .util .Optional ;
1927
28+ import static org .elasticsearch .core .Strings .format ;
29+
/**
 * Handles streaming chat completion responses and error parsing for Llama inference endpoints.
 * This handler is designed to work with the unified Llama chat completion API.
 */
2034public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
2135
    // Error-type code attached to exceptions for errors originating from the Llama endpoint.
    private static final String LLAMA_ERROR = "llama_error";
    // Error-type code attached to exceptions for errors detected while consuming a stream.
    private static final String STREAM_ERROR = "stream_error";

    /**
     * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser.
     *
     * @param requestType the type of request this handler will process
     * @param parseFunction the function to parse the response
     */
    public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        // Delegates to the OpenAI-compatible base handler, supplying the Llama-specific error parser.
        super(requestType, parseFunction, LlamaErrorResponse::fromResponse);
    }
2748
    /**
     * Builds an exception for non-streaming-format errors returned for a Llama chat completion request.
     *
     * @param message the error message to include in the exception
     * @param request the request that caused the error
     * @param result the HTTP result containing the response
     * @param errorResponse the error response parsed from the HTTP result
     * @return an exception representing the error, specific to Llama chat completion
     */
2857 @ Override
2958 protected Exception buildError (String message , Request request , HttpResult result , ErrorResponse errorResponse ) {
3059 assert request .isStreaming () : "Only streaming requests support this format" ;
@@ -44,4 +73,90 @@ protected Exception buildError(String message, Request request, HttpResult resul
4473 return super .buildError (message , request , result , errorResponse );
4574 }
4675 }
76+
77+ /**
78+ * Builds an exception for mid-stream errors encountered during Llama chat completion requests.
79+ *
80+ * @param request the request that caused the error
81+ * @param message the error message
82+ * @param e the exception that occurred, if any
83+ * @return a UnifiedChatCompletionException representing the error
84+ */
85+ @ Override
86+ protected Exception buildMidStreamError (Request request , String message , Exception e ) {
87+ var errorResponse = StreamingLlamaErrorResponseEntity .fromString (message );
88+ if (errorResponse instanceof StreamingLlamaErrorResponseEntity ) {
89+ return new UnifiedChatCompletionException (
90+ RestStatus .INTERNAL_SERVER_ERROR ,
91+ format (
92+ "%s for request from inference entity id [%s]. Error message: [%s]" ,
93+ SERVER_ERROR_OBJECT ,
94+ request .getInferenceEntityId (),
95+ errorResponse .getErrorMessage ()
96+ ),
97+ LLAMA_ERROR ,
98+ STREAM_ERROR
99+ );
100+ } else if (e != null ) {
101+ return UnifiedChatCompletionException .fromThrowable (e );
102+ } else {
103+ return new UnifiedChatCompletionException (
104+ RestStatus .INTERNAL_SERVER_ERROR ,
105+ format ("%s for request from inference entity id [%s]" , SERVER_ERROR_OBJECT , request .getInferenceEntityId ()),
106+ createErrorType (errorResponse ),
107+ STREAM_ERROR
108+ );
109+ }
110+ }
111+
112+ private static class StreamingLlamaErrorResponseEntity extends ErrorResponse {
113+ private static final ConstructingObjectParser <Optional <ErrorResponse >, Void > ERROR_PARSER = new ConstructingObjectParser <>(
114+ LLAMA_ERROR ,
115+ true ,
116+ args -> Optional .ofNullable ((LlamaChatCompletionResponseHandler .StreamingLlamaErrorResponseEntity ) args [0 ])
117+ );
118+ private static final ConstructingObjectParser <
119+ LlamaChatCompletionResponseHandler .StreamingLlamaErrorResponseEntity ,
120+ Void > ERROR_BODY_PARSER = new ConstructingObjectParser <>(
121+ LLAMA_ERROR ,
122+ true ,
123+ args -> new LlamaChatCompletionResponseHandler .StreamingLlamaErrorResponseEntity (
124+ args [0 ] != null ? (String ) args [0 ] : "unknown"
125+ )
126+ );
127+
128+ static {
129+ ERROR_BODY_PARSER .declareString (ConstructingObjectParser .optionalConstructorArg (), new ParseField ("message" ));
130+
131+ ERROR_PARSER .declareObjectOrNull (
132+ ConstructingObjectParser .optionalConstructorArg (),
133+ ERROR_BODY_PARSER ,
134+ null ,
135+ new ParseField ("error" )
136+ );
137+ }
138+
139+ /**
140+ * Parses a streaming Llama error response from a JSON string.
141+ *
142+ * @param response the raw JSON string representing an error
143+ * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
144+ */
145+ private static ErrorResponse fromString (String response ) {
146+ try (
147+ XContentParser parser = XContentFactory .xContent (XContentType .JSON )
148+ .createParser (XContentParserConfiguration .EMPTY , response )
149+ ) {
150+ return ERROR_PARSER .apply (parser , null ).orElse (ErrorResponse .UNDEFINED_ERROR );
151+ } catch (Exception e ) {
152+ // swallow the error
153+ }
154+
155+ return ErrorResponse .UNDEFINED_ERROR ;
156+ }
157+
158+ StreamingLlamaErrorResponseEntity (String errorMessage ) {
159+ super (errorMessage );
160+ }
161+ }
47162}
0 commit comments