Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;

import java.net.URI;

public class ActionUtils {

public static ActionListener<InferenceServiceResults> wrapFailuresInElasticsearchException(
Expand All @@ -30,7 +27,12 @@ public static ActionListener<InferenceServiceResults> wrapFailuresInElasticsearc
if (unwrappedException instanceof ElasticsearchException esException) {
l.onFailure(esException);
} else {
l.onFailure(createInternalServerError(unwrappedException, errorMessage));
l.onFailure(
createInternalServerError(
unwrappedException,
Strings.format("%s. Cause: %s", errorMessage, unwrappedException.getMessage())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the exception is not an ElasticsearchException we'll grab the cause to return to the API call. We've noticed in the case of apache timeout exceptions it would have been useful to return that information directly to the client instead of only logging it.

)
);
}
});
}
Expand All @@ -39,11 +41,7 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
}

public static String constructFailedToSendRequestMessage(@Nullable URI uri, String message) {
if (uri != null) {
return Strings.format("Failed to send %s request to [%s]", message, uri);
}

public static String constructFailedToSendRequestMessage(String message) {
return Strings.format("Failed to send %s request", message);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl
this.model = Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are to remove the URL parameter.

this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search completion");
this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, serviceComponents.threadPool());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbed
this.model = Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search text embeddings");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search text embeddings");
this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class AlibabaCloudSearchRerankAction implements ExecutableAction {
public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search rerank");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search rerank");
this.sender = Objects.requireNonNull(sender);
this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class AlibabaCloudSearchSparseAction implements ExecutableAction {
public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search sparse embeddings");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search sparse embeddings");
this.sender = Objects.requireNonNull(sender);
requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map
serviceComponents.threadPool(),
timeout
);
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock embeddings");
var errorMessage = constructFailedToSendRequestMessage("Amazon Bedrock embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}

@Override
public ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map<String, Object> taskSettings) {
var overriddenModel = AmazonBedrockChatCompletionModel.of(completionModel, taskSettings);
var requestManager = new AmazonBedrockChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool(), timeout);
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock completion");
var errorMessage = constructFailedToSendRequestMessage("Amazon Bedrock completion");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public AnthropicActionCreator(Sender sender, ServiceComponents serviceComponents
public ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AnthropicChatCompletionModel.of(model, taskSettings);
var requestCreator = AnthropicCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public AzureAiStudioActionCreator(Sender sender, ServiceComponents serviceCompon
public ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings) {
var overriddenModel = AzureAiStudioChatCompletionModel.of(completionModel, taskSettings);
var requestManager = new AzureAiStudioChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(completionModel.uri(), "Azure AI Studio completion");
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio completion");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}

Expand All @@ -46,7 +46,7 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Azure AI Studio embeddings");
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Obj
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), "Azure OpenAI embeddings");
var errorMessage = constructFailedToSendRequestMessage("Azure OpenAI embeddings");
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AzureOpenAiCompletionModel.of(model, taskSettings);
var requestCreator = new AzureOpenAiCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), COMPLETION_ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
@Override
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings");
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
Expand All @@ -55,21 +52,15 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object>
public ExecutableAction create(CohereRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = CohereRerankModel.of(model, taskSettings);
var requestCreator = CohereRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Cohere rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank");
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(CohereCompletionModel model, Map<String, Object> taskSettings) {
// no overridden model as task settings are always empty for cohere completion model
var requestManager = CohereCompletionRequestManager.of(model, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
model.getServiceSettings().uri(),
COMPLETION_ERROR_PREFIX
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ public ElasticInferenceServiceActionCreator(
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, inputType);
var errorMessage = constructFailedToSendRequestMessage(
model.uri(),
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);
return new SenderExecutableAction(sender, requestManager, errorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI embeddings");
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI rerank");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI rerank");
var requestManager = GoogleVertexAiRerankRequestManager.of(model, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponent

@Override
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "IBM WatsonX embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
return new SenderExecutableAction(
sender,
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
Expand All @@ -46,10 +46,7 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Ibm Watsonx rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank");
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,15 @@ public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
@Override
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI embeddings"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI embeddings");
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = JinaAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI rerank");
var requestCreator = JinaAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<Strin
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Mistral embeddings");
var errorMessage = constructFailedToSendRequestMessage("Mistral embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object>
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), "OpenAI embeddings");
var errorMessage = constructFailedToSendRequestMessage("OpenAI embeddings");
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(OpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings);
var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents)
@Override
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI embeddings");
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI rerank");
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ protected void doUnifiedCompletionInfer(
var completionModel = (ElasticInferenceServiceCompletionModel) model;
var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
var errorMessage = constructFailedToSendRequestMessage(
overriddenModel.uri(),
String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);

Expand Down
Loading