1919
2020import java .util .Locale ;
2121import java .util .Objects ;
22+ import java .util .function .BiFunction ;
2223import java .util .function .Function ;
24+ import java .util .function .Supplier ;
2325
2426import static org .elasticsearch .core .Strings .format ;
2527import static org .elasticsearch .xpack .inference .external .http .HttpUtils .checkForEmptyBody ;
@@ -124,7 +126,7 @@ protected Exception buildError(String message, Request request, HttpResult resul
124126 protected Exception buildError (String message , Request request , HttpResult result , ErrorResponse errorResponse ) {
125127 var responseStatusCode = result .response ().getStatusLine ().getStatusCode ();
126128 return new ElasticsearchStatusException (
127- errorMessage (message , request , errorResponse , responseStatusCode ),
129+ extractErrorMessage (message , request , errorResponse , responseStatusCode ),
128130 toRestStatus (responseStatusCode )
129131 );
130132 }
@@ -138,22 +140,24 @@ protected Exception buildError(String message, Request request, HttpResult resul
138140 * @param request the request that caused the error
139141 * @param result the HTTP result containing the error response
140142 * @param errorResponse the parsed error response from the HTTP result
141- * @param errorResponseClass the class of the expected error response type
143+ * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
144+ * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
142145 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
143146 */
144147 protected UnifiedChatCompletionException buildChatCompletionError (
145148 String message ,
146149 Request request ,
147150 HttpResult result ,
148151 ErrorResponse errorResponse ,
149- Class <? extends ErrorResponse > errorResponseClass
152+ Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
153+ ChatCompletionErrorBuilder chatCompletionErrorBuilder
150154 ) {
151155 assert request .isStreaming () : "Only streaming requests support this format" ;
152156 var statusCode = result .response ().getStatusLine ().getStatusCode ();
153- var errorMessage = errorMessage (message , request , errorResponse , statusCode );
157+ var errorMessage = extractErrorMessage (message , request , errorResponse , statusCode );
154158 var restStatus = toRestStatus (statusCode );
155159
156- return buildChatCompletionError (errorResponse , errorMessage , restStatus , errorResponseClass );
160+ return buildChatCompletionError (errorResponse , errorMessage , restStatus , errorResponseClassSupplier , chatCompletionErrorBuilder );
157161 }
158162
159163 /**
@@ -164,43 +168,24 @@ protected UnifiedChatCompletionException buildChatCompletionError(
164168 * @param errorResponse the error response parsed from the HTTP result
165169 * @param errorMessage the error message to include in the exception
166170 * @param restStatus the REST status code of the response
167- * @param errorResponseClass the class of the expected error response type
171+ * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
172+ * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
168173 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
169174 */
170175 protected UnifiedChatCompletionException buildChatCompletionError (
171176 ErrorResponse errorResponse ,
172177 String errorMessage ,
173178 RestStatus restStatus ,
174- Class <? extends ErrorResponse > errorResponseClass
179+ Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
180+ ChatCompletionErrorBuilder chatCompletionErrorBuilder
175181 ) {
176- if (errorResponseClass .isInstance (errorResponse )) {
177- return buildProviderSpecificChatCompletionError (errorResponse , errorMessage , restStatus );
182+ if (errorResponseClassSupplier . get () .isInstance (errorResponse )) {
183+ return chatCompletionErrorBuilder . buildProviderSpecificChatCompletionError (errorResponse , errorMessage , restStatus );
178184 } else {
179185 return buildDefaultChatCompletionError (errorResponse , errorMessage , restStatus );
180186 }
181187 }
182188
183- /**
184- * Builds a custom {@link UnifiedChatCompletionException} for a streaming request.
185- * This method is called when a specific error response is found in the HTTP result.
186- * It must be implemented by subclasses to handle specific error response formats.
187- * Only streaming requests should use this method.
188- *
189- * @param errorResponse the error response parsed from the HTTP result
190- * @param errorMessage the error message to include in the exception
191- * @param restStatus the REST status code of the response
192- * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
193- */
194- protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError (
195- ErrorResponse errorResponse ,
196- String errorMessage ,
197- RestStatus restStatus
198- ) {
199- throw new UnsupportedOperationException (
200- "Custom error handling is not implemented. Please override buildProviderSpecificChatCompletionError method."
201- );
202- }
203-
204189 /**
205190 * Builds a default {@link UnifiedChatCompletionException} for a streaming request.
206191 * This method is used when an error response is received but no specific error handling is implemented.
@@ -211,7 +196,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro
211196 * @param restStatus the REST status code of the response
212197 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
213198 */
214- protected UnifiedChatCompletionException buildDefaultChatCompletionError (
199+ private static UnifiedChatCompletionException buildDefaultChatCompletionError (
215200 ErrorResponse errorResponse ,
216201 String errorMessage ,
217202 RestStatus restStatus
@@ -232,21 +217,25 @@ protected UnifiedChatCompletionException buildDefaultChatCompletionError(
232217 * @param inferenceEntityId the ID of the inference entity
233218 * @param message the error message
234219 * @param e the exception that caused the error, can be null
235- * @param errorResponseClass the class of the expected error response type
220+ * @param errorResponseClassSupplier a supplier that provides the class of the expected error response type
221+ * @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response
222+ * @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message
236223 * @return a {@link UnifiedChatCompletionException} representing the mid-stream error
237224 */
238225 protected UnifiedChatCompletionException buildMidStreamChatCompletionError (
239226 String inferenceEntityId ,
240227 String message ,
241228 Exception e ,
242- Class <? extends ErrorResponse > errorResponseClass
229+ Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
230+ BiFunction <String , ErrorResponse , UnifiedChatCompletionException > specificErrorBuilder ,
231+ Function <String , ErrorResponse > midStreamErrorExtractor
243232 ) {
244233 // Extract the error response from the message using the provided method
245- var errorResponse = extractMidStreamChatCompletionErrorResponse (message );
234+ var errorResponse = midStreamErrorExtractor . apply (message );
246235 // Check if the error response matches the expected type
247- if (errorResponseClass .isInstance (errorResponse )) {
236+ if (errorResponseClassSupplier . get () .isInstance (errorResponse )) {
248237 // If it matches, we can build a custom mid-stream error exception
249- return buildProviderSpecificMidStreamChatCompletionError (inferenceEntityId , errorResponse );
238+ return specificErrorBuilder . apply (inferenceEntityId , errorResponse );
250239 } else if (e != null ) {
251240 // If the error response does not match, we can still return an exception based on the original throwable
252241 return UnifiedChatCompletionException .fromThrowable (e );
@@ -256,26 +245,6 @@ protected UnifiedChatCompletionException buildMidStreamChatCompletionError(
256245 }
257246 }
258247
259- /**
260- * Builds a custom mid-stream {@link UnifiedChatCompletionException} for a streaming request.
261- * This method is called when a specific error response is found in the message.
262- * It must be implemented by subclasses to handle specific error response formats.
263- * Only streaming requests should use this method.
264- *
265- * @param inferenceEntityId the ID of the inference entity
266- * @param errorResponse the error response parsed from the message
267- * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
268- */
269- protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError (
270- String inferenceEntityId ,
271- ErrorResponse errorResponse
272- ) {
273- throw new UnsupportedOperationException (
274- "Mid-stream error handling is not implemented for this response handler. "
275- + "Please override buildProviderSpecificMidStreamChatCompletionError method."
276- );
277- }
278-
279248 /**
280249 * Builds a default mid-stream error for a streaming request.
281250 * This method is used when no specific error response is found in the message.
@@ -285,7 +254,7 @@ protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompl
285254 * @param errorResponse the error response extracted from the message
286255 * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error
287256 */
288- protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError (
257+ protected static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError (
289258 String inferenceEntityId ,
290259 ErrorResponse errorResponse
291260 ) {
@@ -297,33 +266,18 @@ protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionErro
297266 );
298267 }
299268
300- /**
301- * Extracts the mid-stream error response from the message.
302- * This method is used to parse the error response from a streaming message.
303- * It must be implemented by subclasses to handle specific error response formats.
304- * Only streaming requests should use this method.
305- *
306- * @param message the message containing the error response
307- * @return an {@link ErrorResponse} object representing the mid-stream error
308- */
309- protected ErrorResponse extractMidStreamChatCompletionErrorResponse (String message ) {
310- throw new UnsupportedOperationException (
311- "Mid-stream error extraction is not implemented. Please override extractMidStreamChatCompletionErrorResponse method."
312- );
313- }
314-
315269 /**
316270 * Creates a string representation of the error type based on the provided ErrorResponse.
317271 * This method is used to generate a human-readable error type for logging or exception messages.
318272 *
319273 * @param errorResponse the ErrorResponse object
320274 * @return a string representing the error type
321275 */
322- protected static String createErrorType (ErrorResponse errorResponse ) {
276+ private static String createErrorType (ErrorResponse errorResponse ) {
323277 return errorResponse != null ? errorResponse .getClass ().getSimpleName () : "unknown" ;
324278 }
325279
326- protected String errorMessage (String message , Request request , ErrorResponse errorResponse , int statusCode ) {
280+ private static String extractErrorMessage (String message , Request request , ErrorResponse errorResponse , int statusCode ) {
327281 return (errorResponse == null
328282 || errorResponse .errorStructureFound () == false
329283 || Strings .isNullOrEmpty (errorResponse .getErrorMessage ()))
0 commit comments