|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License |
| 4 | + * 2.0; you may not use this file except in compliance with the Elastic License |
| 5 | + * 2.0. |
| 6 | + */ |
| 7 | + |
| 8 | +package org.elasticsearch.xpack.inference.external.http.retry; |
| 9 | + |
| 10 | +import org.elasticsearch.rest.RestStatus; |
| 11 | +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; |
| 12 | +import org.elasticsearch.xpack.inference.external.http.HttpResult; |
| 13 | +import org.elasticsearch.xpack.inference.external.request.Request; |
| 14 | + |
| 15 | +import java.util.Locale; |
| 16 | +import java.util.Objects; |
| 17 | + |
| 18 | +import static org.elasticsearch.core.Strings.format; |
| 19 | +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT; |
| 20 | +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.toRestStatus; |
| 21 | + |
| 22 | +public class ChatCompletionErrorResponseHandler { |
| 23 | + private static final String STREAM_ERROR = "stream_error"; |
| 24 | + |
| 25 | + private final UnifiedChatCompletionErrorParser unifiedChatCompletionErrorParser; |
| 26 | + |
| 27 | + public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser errorParser) { |
| 28 | + this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser); |
| 29 | + } |
| 30 | + |
| 31 | + public void checkForErrorObject(Request request, HttpResult result) { |
| 32 | + var errorEntity = unifiedChatCompletionErrorParser.parse(result); |
| 33 | + |
| 34 | + if (errorEntity.errorStructureFound()) { |
| 35 | + // We don't really know what happened because the status code was 200 so we'll return a failure and let the |
| 36 | + // client retry if necessary |
| 37 | + // If we did want to retry here, we'll need to determine if this was a streaming request, if it was |
| 38 | + // we shouldn't retry because that would replay the entire streaming request and the client would get |
| 39 | + // duplicate chunks back |
| 40 | + throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity)); |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) { |
| 45 | + var errorResponse = unifiedChatCompletionErrorParser.parse(result); |
| 46 | + return buildChatCompletionErrorInternal(message, request, result, errorResponse); |
| 47 | + } |
| 48 | + |
| 49 | + private UnifiedChatCompletionException buildChatCompletionErrorInternal( |
| 50 | + String message, |
| 51 | + Request request, |
| 52 | + HttpResult result, |
| 53 | + UnifiedChatCompletionErrorResponse errorResponse |
| 54 | + ) { |
| 55 | + assert request.isStreaming() : "Only streaming requests support this format"; |
| 56 | + var statusCode = result.response().getStatusLine().getStatusCode(); |
| 57 | + var errorMessage = BaseResponseHandler.constructErrorMessage(message, request, errorResponse, statusCode); |
| 58 | + var restStatus = toRestStatus(statusCode); |
| 59 | + |
| 60 | + if (errorResponse.errorStructureFound()) { |
| 61 | + return new UnifiedChatCompletionException( |
| 62 | + restStatus, |
| 63 | + errorMessage, |
| 64 | + errorResponse.type(), |
| 65 | + errorResponse.code(), |
| 66 | + errorResponse.param() |
| 67 | + ); |
| 68 | + } else { |
| 69 | + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * Builds a default {@link UnifiedChatCompletionException} for a streaming request. |
| 75 | + * This method is used when an error response is received we were unable to parse it in the format we were expecting. |
| 76 | + * Only streaming requests should use this method. |
| 77 | + * |
| 78 | + * @param errorResponse the error response extracted from the HTTP result |
| 79 | + * @param errorMessage the error message to include in the exception |
| 80 | + * @param restStatus the REST status code of the response |
| 81 | + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response |
| 82 | + */ |
| 83 | + private static UnifiedChatCompletionException buildDefaultChatCompletionError( |
| 84 | + ErrorResponse errorResponse, |
| 85 | + String errorMessage, |
| 86 | + RestStatus restStatus |
| 87 | + ) { |
| 88 | + return new UnifiedChatCompletionException( |
| 89 | + restStatus, |
| 90 | + errorMessage, |
| 91 | + createErrorType(errorResponse), |
| 92 | + restStatus.name().toLowerCase(Locale.ROOT) |
| 93 | + ); |
| 94 | + } |
| 95 | + |
| 96 | + /** |
| 97 | + * Builds a mid-stream error for a streaming request. |
| 98 | + * This method is used when an error occurs while processing a streaming response. |
| 99 | + * Only streaming requests should use this method. |
| 100 | + * |
| 101 | + * @param inferenceEntityId the ID of the inference entity |
| 102 | + * @param message the error message |
| 103 | + * @param e the exception that caused the error, can be null |
| 104 | + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error |
| 105 | + */ |
| 106 | + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { |
| 107 | + var error = unifiedChatCompletionErrorParser.parse(message); |
| 108 | + |
| 109 | + if (error.errorStructureFound()) { |
| 110 | + return new UnifiedChatCompletionException( |
| 111 | + RestStatus.INTERNAL_SERVER_ERROR, |
| 112 | + format( |
| 113 | + "%s for request from inference entity id [%s]. Error message: [%s]", |
| 114 | + SERVER_ERROR_OBJECT, |
| 115 | + inferenceEntityId, |
| 116 | + error.getErrorMessage() |
| 117 | + ), |
| 118 | + error.type(), |
| 119 | + error.code(), |
| 120 | + error.param() |
| 121 | + ); |
| 122 | + } else if (e != null) { |
| 123 | + // If the error response does not match, we can still return an exception based on the original throwable |
| 124 | + return UnifiedChatCompletionException.fromThrowable(e); |
| 125 | + } else { |
| 126 | + // If no specific error response is found, we return a default mid-stream error |
| 127 | + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error); |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + /** |
| 132 | + * Builds a default mid-stream error for a streaming request. |
| 133 | + * This method is used when no specific error response is found in the message. |
| 134 | + * Only streaming requests should use this method. |
| 135 | + * |
| 136 | + * @param inferenceEntityId the ID of the inference entity |
| 137 | + * @param errorResponse the error response extracted from the message |
| 138 | + * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error |
| 139 | + */ |
| 140 | + private static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( |
| 141 | + String inferenceEntityId, |
| 142 | + ErrorResponse errorResponse |
| 143 | + ) { |
| 144 | + return new UnifiedChatCompletionException( |
| 145 | + RestStatus.INTERNAL_SERVER_ERROR, |
| 146 | + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), |
| 147 | + createErrorType(errorResponse), |
| 148 | + STREAM_ERROR |
| 149 | + ); |
| 150 | + } |
| 151 | + |
| 152 | + /** |
| 153 | + * Creates a string representation of the error type based on the provided ErrorResponse. |
| 154 | + * This method is used to generate a human-readable error type for logging or exception messages. |
| 155 | + * |
| 156 | + * @param errorResponse the ErrorResponse object |
| 157 | + * @return a string representing the error type |
| 158 | + */ |
| 159 | + private static String createErrorType(ErrorResponse errorResponse) { |
| 160 | + return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; |
| 161 | + } |
| 162 | +} |
0 commit comments