Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.Objects;
import java.util.function.Function;
import java.util.function.BiFunction;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
Expand All @@ -37,17 +37,21 @@ public abstract class BaseResponseHandler implements ResponseHandler {

protected final String requestType;
private final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to understand the context better, why do we need the inference ID as part of the error in the custom service PR? Is there a reason we need the full Request here instead of passing just the inference ID?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can just pass the inference id to limit the scope 👍

private final boolean canHandleStreamingResponses;

public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction
) {
this(requestType, parseFunction, errorParseFunction, false);
}

public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
this.requestType = Objects.requireNonNull(requestType);
Expand Down Expand Up @@ -96,7 +100,7 @@ public void validateResponse(
protected abstract void checkForFailureStatusCode(Request request, HttpResult result);

private void checkForErrorObject(Request request, HttpResult result) {
var errorEntity = errorParseFunction.apply(result);
var errorEntity = errorParseFunction.apply(request, result);

if (errorEntity.errorStructureFound()) {
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
Expand All @@ -109,7 +113,7 @@ private void checkForErrorObject(Request request, HttpResult result) {
}

protected Exception buildError(String message, Request request, HttpResult result) {
var errorEntityMsg = errorParseFunction.apply(result);
var errorEntityMsg = errorParseFunction.apply(request, result);
return buildError(message, request, result, errorEntityMsg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -68,4 +69,8 @@ public static ErrorResponse fromResponse(HttpResult response, String defaultMess
public static ErrorResponse fromResponse(HttpResult response) {
return fromResponse(response, "");
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit unclear why we drop the request here? Is this where we would make a follow-up change to utilize the request?

return fromResponse(response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

public class AlibabaCloudSearchErrorResponseEntity extends ErrorResponse {
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchErrorResponseEntity.class);
Expand All @@ -23,6 +24,10 @@ private AlibabaCloudSearchErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add this to the ErrorResponse class to avoid duplication?

return fromResponse(response);
}

/**
* An example error response for invalid auth would look like
* <code>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.elasticsearch.xpack.inference.services.openai.OpenAiStreamingProcessor;

import java.util.concurrent.Flow;
import java.util.function.Function;
import java.util.function.BiFunction;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
Expand Down Expand Up @@ -55,7 +55,7 @@ public class AzureMistralOpenAiExternalResponseHandler extends BaseResponseHandl
public AzureMistralOpenAiExternalResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

public class CohereErrorResponseEntity extends ErrorResponse {

private CohereErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* An example error response for invalid auth would look like
* <code>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;

Expand All @@ -37,6 +38,10 @@ private ElasticInferenceServiceErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

public static ErrorResponse fromResponse(HttpResult response) {
return fromParser(
() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.Map;
import java.util.Objects;
Expand All @@ -23,6 +24,10 @@ private GoogleAiStudioErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* An example error response for invalid auth would look like
* <code>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.Map;
import java.util.Objects;
Expand All @@ -23,6 +24,10 @@ private GoogleVertexAiErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* An example error response for invalid auth would look like
* <code>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

public class HuggingFaceErrorResponseEntity extends ErrorResponse {

public HuggingFaceErrorResponseEntity(String message) {
super(message);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* An example error response for invalid auth would look like
* <code>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.Map;
import java.util.Objects;
Expand All @@ -23,6 +24,10 @@ private IbmWatsonxErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

@SuppressWarnings("unchecked")
public static ErrorResponse fromResponse(HttpResult response) {
try (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

public class JinaAIErrorResponseEntity extends ErrorResponse {

private JinaAIErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* Parse an HTTP response into a JinaAIErrorResponseEntity
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.function.Function;
import java.util.function.BiFunction;

public class OpenAiChatCompletionResponseHandler extends OpenAiResponseHandler {
public OpenAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
Expand All @@ -23,7 +23,7 @@ public OpenAiChatCompletionResponseHandler(String requestType, ResponseParser pa
protected OpenAiChatCompletionResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction
) {
super(requestType, parseFunction, errorParseFunction, true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.concurrent.Flow;
import java.util.function.Function;
import java.util.function.BiFunction;

import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;

Expand Down Expand Up @@ -50,7 +50,7 @@ public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, b
protected OpenAiResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ private static class OpenAiErrorResponse extends ErrorResponse {
);
}

private static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

private static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

public class VoyageAIErrorResponseEntity extends ErrorResponse {

private VoyageAIErrorResponseEntity(String errorMessage) {
super(errorMessage);
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}

/**
* Parse an HTTP response into a VoyageAIErrorResponseEntity
*
Expand Down