Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;

import java.net.URI;

public class ActionUtils {

public static ActionListener<InferenceServiceResults> wrapFailuresInElasticsearchException(
Expand All @@ -30,7 +27,12 @@ public static ActionListener<InferenceServiceResults> wrapFailuresInElasticsearc
if (unwrappedException instanceof ElasticsearchException esException) {
l.onFailure(esException);
} else {
l.onFailure(createInternalServerError(unwrappedException, errorMessage));
l.onFailure(
createInternalServerError(
unwrappedException,
Strings.format("%s. Cause: %s", errorMessage, unwrappedException.getMessage())
)
);
}
});
}
Expand All @@ -39,11 +41,7 @@ public static ElasticsearchStatusException createInternalServerError(Throwable e
return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR, e);
}

public static String constructFailedToSendRequestMessage(@Nullable URI uri, String message) {
if (uri != null) {
return Strings.format("Failed to send %s request to [%s]", message, uri);
}

public static String constructFailedToSendRequestMessage(String message) {
return Strings.format("Failed to send %s request", message);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl
this.model = Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search completion");
this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, serviceComponents.threadPool());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbed
this.model = Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search text embeddings");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search text embeddings");
this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class AlibabaCloudSearchRerankAction implements ExecutableAction {
public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search rerank");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search rerank");
this.sender = Objects.requireNonNull(sender);
this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class AlibabaCloudSearchSparseAction implements ExecutableAction {
public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search sparse embeddings");
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search sparse embeddings");
this.sender = Objects.requireNonNull(sender);
requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map
serviceComponents.threadPool(),
timeout
);
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock embeddings");
var errorMessage = constructFailedToSendRequestMessage("Amazon Bedrock embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}

@Override
public ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map<String, Object> taskSettings) {
var overriddenModel = AmazonBedrockChatCompletionModel.of(completionModel, taskSettings);
var requestManager = new AmazonBedrockChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool(), timeout);
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock completion");
var errorMessage = constructFailedToSendRequestMessage("Amazon Bedrock completion");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public AnthropicActionCreator(Sender sender, ServiceComponents serviceComponents
public ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AnthropicChatCompletionModel.of(model, taskSettings);
var requestCreator = AnthropicCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public AzureAiStudioActionCreator(Sender sender, ServiceComponents serviceCompon
public ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings) {
var overriddenModel = AzureAiStudioChatCompletionModel.of(completionModel, taskSettings);
var requestManager = new AzureAiStudioChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(completionModel.uri(), "Azure AI Studio completion");
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio completion");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}

Expand All @@ -46,7 +46,7 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Azure AI Studio embeddings");
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Obj
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), "Azure OpenAI embeddings");
var errorMessage = constructFailedToSendRequestMessage("Azure OpenAI embeddings");
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AzureOpenAiCompletionModel.of(model, taskSettings);
var requestCreator = new AzureOpenAiCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), COMPLETION_ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
@Override
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings");
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
Expand All @@ -55,21 +52,15 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object>
public ExecutableAction create(CohereRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = CohereRerankModel.of(model, taskSettings);
var requestCreator = CohereRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Cohere rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank");
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(CohereCompletionModel model, Map<String, Object> taskSettings) {
// no overridden model as task settings are always empty for cohere completion model
var requestManager = CohereCompletionRequestManager.of(model, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
model.getServiceSettings().uri(),
COMPLETION_ERROR_PREFIX
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ public ElasticInferenceServiceActionCreator(
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, inputType);
var errorMessage = constructFailedToSendRequestMessage(
model.uri(),
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);
return new SenderExecutableAction(sender, requestManager, errorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI embeddings");
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI rerank");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI rerank");
var requestManager = GoogleVertexAiRerankRequestManager.of(model, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponent

@Override
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "IBM WatsonX embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
return new SenderExecutableAction(
sender,
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
Expand All @@ -46,10 +46,7 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Ibm Watsonx rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank");
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,15 @@ public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
@Override
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI embeddings"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI embeddings");
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = JinaAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI rerank"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI rerank");
var requestCreator = JinaAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<Strin
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Mistral embeddings");
var errorMessage = constructFailedToSendRequestMessage("Mistral embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object>
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), "OpenAI embeddings");
var errorMessage = constructFailedToSendRequestMessage("OpenAI embeddings");
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(OpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings);
var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ protected void doUnifiedCompletionInfer(
var completionModel = (ElasticInferenceServiceCompletionModel) model;
var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
var errorMessage = constructFailedToSendRequestMessage(
overriddenModel.uri(),
String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ protected void doInfer(
) {
if (model instanceof GoogleAiStudioCompletionModel completionModel) {
var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
completionModel.uri(inputs.stream()),
"Google AI Studio completion"
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google AI Studio completion");
var action = new SingleInputSenderExecutableAction(
getSender(),
requestManager,
Expand All @@ -299,7 +296,7 @@ protected void doInfer(
getServiceComponents().truncator(),
getServiceComponents().threadPool()
);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Google AI Studio embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google AI Studio embeddings");
var action = new SenderExecutableAction(getSender(), requestManager, failedToSendRequestErrorMessage);
action.execute(inputs, timeout, listener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ public void doUnifiedCompletionInfer(

var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest());
var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX);
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage);

action.execute(inputs, timeout, listener);
Expand Down
Loading