1919
2020import java .util .Locale ;
2121import java .util .Objects ;
22- import java .util .function .BiFunction ;
2322import java .util .function .Function ;
24- import java .util .function .Supplier ;
2523
2624import static org .elasticsearch .core .Strings .format ;
2725import static org .elasticsearch .xpack .inference .external .http .HttpUtils .checkForEmptyBody ;
@@ -38,7 +36,6 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3836 public static final String SERVER_ERROR_OBJECT = "Received an error response" ;
3937 public static final String BAD_REQUEST = "Received a bad request status code" ;
4038 public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code" ;
41- protected static final String ERROR_TYPE = "error" ;
4239 protected static final String STREAM_ERROR = "stream_error" ;
4340
4441 protected final String requestType ;
@@ -140,47 +137,22 @@ protected Exception buildError(String message, Request request, HttpResult resul
140137 * @param request the request that caused the error
141138 * @param result the HTTP result containing the error response
142139 * @param errorResponse the parsed error response from the HTTP result
143- * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
144- * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
145140 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
146141 */
147142 protected UnifiedChatCompletionException buildChatCompletionError (
148143 String message ,
149144 Request request ,
150145 HttpResult result ,
151- ErrorResponse errorResponse ,
152- Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
153- ChatCompletionErrorBuilder chatCompletionErrorBuilder
146+ ErrorResponse errorResponse
154147 ) {
155148 assert request .isStreaming () : "Only streaming requests support this format" ;
156149 var statusCode = result .response ().getStatusLine ().getStatusCode ();
157150 var errorMessage = extractErrorMessage (message , request , errorResponse , statusCode );
158151 var restStatus = toRestStatus (statusCode );
159152
160- return buildChatCompletionError (errorResponse , errorMessage , restStatus , errorResponseClassSupplier , chatCompletionErrorBuilder );
161- }
162-
163- /**
164- * Builds a {@link UnifiedChatCompletionException} for a streaming request.
165- * This method is used when an error response is received from the external service.
166- * Only streaming requests should use this method.
167- *
168- * @param errorResponse the error response parsed from the HTTP result
169- * @param errorMessage the error message to include in the exception
170- * @param restStatus the REST status code of the response
171- * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
172- * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
173- * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
174- */
175- protected UnifiedChatCompletionException buildChatCompletionError (
176- ErrorResponse errorResponse ,
177- String errorMessage ,
178- RestStatus restStatus ,
179- Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
180- ChatCompletionErrorBuilder chatCompletionErrorBuilder
181- ) {
182- if (errorResponse .errorStructureFound () && errorResponseClassSupplier .get ().isInstance (errorResponse )) {
183- return chatCompletionErrorBuilder .buildProviderSpecificChatCompletionError (errorResponse , errorMessage , restStatus );
153+ if (errorResponse .errorStructureFound ()
154+ && errorResponse instanceof UnifiedChatCompletionExceptionConvertible chatCompletionExceptionConvertible ) {
155+ return chatCompletionExceptionConvertible .toUnifiedChatCompletionException (errorMessage , restStatus );
184156 } else {
185157 return buildDefaultChatCompletionError (errorResponse , errorMessage , restStatus );
186158 }
@@ -196,7 +168,7 @@ protected UnifiedChatCompletionException buildChatCompletionError(
196168 * @param restStatus the REST status code of the response
197169 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
198170 */
199- private static UnifiedChatCompletionException buildDefaultChatCompletionError (
171+ protected static UnifiedChatCompletionException buildDefaultChatCompletionError (
200172 ErrorResponse errorResponse ,
201173 String errorMessage ,
202174 RestStatus restStatus
@@ -217,31 +189,27 @@ private static UnifiedChatCompletionException buildDefaultChatCompletionError(
217189 * @param inferenceEntityId the ID of the inference entity
218190 * @param message the error message
219191 * @param e the exception that caused the error, can be null
220- * @param errorResponseClassSupplier a supplier that provides the class of the expected error response type
221- * @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response
222192 * @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message
223193 * @return a {@link UnifiedChatCompletionException} representing the mid-stream error
224194 */
225195 protected UnifiedChatCompletionException buildMidStreamChatCompletionError (
226196 String inferenceEntityId ,
227197 String message ,
228198 Exception e ,
229- Supplier <Class <? extends ErrorResponse >> errorResponseClassSupplier ,
230- BiFunction <String , ErrorResponse , UnifiedChatCompletionException > specificErrorBuilder ,
231199 Function <String , ErrorResponse > midStreamErrorExtractor
232200 ) {
233201 // Extract the error response from the message using the provided method
234- var errorResponse = midStreamErrorExtractor .apply (message );
202+ var error = midStreamErrorExtractor .apply (message );
235203 // Check if the error response matches the expected type
236- if (errorResponse .errorStructureFound () && errorResponseClassSupplier . get (). isInstance ( errorResponse ) ) {
204+ if (error .errorStructureFound () && error instanceof MidStreamUnifiedChatCompletionExceptionConvertible midStreamError ) {
237205 // If it matches, we can build a custom mid-stream error exception
238- return specificErrorBuilder . apply (inferenceEntityId , errorResponse );
206+ return midStreamError . toUnifiedChatCompletionException (inferenceEntityId );
239207 } else if (e != null ) {
240208 // If the error response does not match, we can still return an exception based on the original throwable
241209 return UnifiedChatCompletionException .fromThrowable (e );
242210 } else {
243211 // If no specific error response is found, we return a default mid-stream error
244- return buildDefaultMidStreamChatCompletionError (inferenceEntityId , errorResponse );
212+ return buildDefaultMidStreamChatCompletionError (inferenceEntityId , error );
245213 }
246214 }
247215
@@ -277,7 +245,7 @@ private static String createErrorType(ErrorResponse errorResponse) {
277245 return errorResponse != null ? errorResponse .getClass ().getSimpleName () : "unknown" ;
278246 }
279247
280- private static String extractErrorMessage (String message , Request request , ErrorResponse errorResponse , int statusCode ) {
248+ protected static String extractErrorMessage (String message , Request request , ErrorResponse errorResponse , int statusCode ) {
281249 return (errorResponse == null
282250 || errorResponse .errorStructureFound () == false
283251 || Strings .isNullOrEmpty (errorResponse .getErrorMessage ()))
@@ -291,7 +259,7 @@ private static String extractErrorMessage(String message, Request request, Error
291259 );
292260 }
293261
294- public static RestStatus toRestStatus (int statusCode ) {
262+ protected static RestStatus toRestStatus (int statusCode ) {
295263 RestStatus code = null ;
296264 if (statusCode < 500 ) {
297265 code = RestStatus .fromCode (statusCode );
0 commit comments