Skip to content
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
3b418ae
wip
ymao1 Feb 11, 2025
5d020a7
wip
ymao1 Feb 21, 2025
cceb308
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 3, 2025
b515e84
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 3, 2025
625a5a2
[CI] Auto commit changes from spotless
Mar 3, 2025
c064e26
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 4, 2025
f6dbf51
Adding internal input types
ymao1 Mar 4, 2025
6a25d08
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 4, 2025
b7c2481
[CI] Auto commit changes from spotless
Mar 4, 2025
2d2b6db
Throwing validation exception for services that don't support input type
ymao1 Mar 4, 2025
3dc38f7
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 4, 2025
356d546
linting
ymao1 Mar 4, 2025
a143af1
hugging face
ymao1 Mar 4, 2025
7f20d32
voyage ai
ymao1 Mar 4, 2025
d6c2464
google ai studio
ymao1 Mar 5, 2025
f63e852
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 5, 2025
f9962ce
bedrock updates
ymao1 Mar 5, 2025
2905643
Fixing tests
ymao1 Mar 5, 2025
f47538c
Fixing tests
ymao1 Mar 6, 2025
5b4ee68
Fixing tests
ymao1 Mar 6, 2025
be17339
bedrock updates
ymao1 Mar 6, 2025
bba8ac5
elasticsearch
ymao1 Mar 6, 2025
bf32efd
azure openai
ymao1 Mar 6, 2025
3dbeaff
Merge
ymao1 Mar 6, 2025
c3c9cef
[CI] Auto commit changes from spotless
Mar 6, 2025
cc96e9a
Refactoring all the things
ymao1 Mar 8, 2025
fbc8791
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 8, 2025
e6e877e
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 8, 2025
215b4b7
[CI] Auto commit changes from spotless
Mar 8, 2025
de1f8fc
Everything compiles
ymao1 Mar 8, 2025
4838727
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 8, 2025
1946d49
spotless
ymao1 Mar 8, 2025
f1bbcc6
external actions tests
ymao1 Mar 9, 2025
ea4ad64
external request tests
ymao1 Mar 10, 2025
3a3e946
service tests
ymao1 Mar 10, 2025
20c28b0
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 10, 2025
84bc649
Fixing integration tests
ymao1 Mar 10, 2025
4ad3214
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 11, 2025
9f744db
Cleanup
ymao1 Mar 11, 2025
4564680
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 11, 2025
aef7da8
Update docs/changelog/122638.yaml
ymao1 Mar 11, 2025
6cbd41e
Merge
ymao1 Mar 12, 2025
f364ad3
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 12, 2025
113cf3d
Merging in main
ymao1 Mar 17, 2025
cb8d337
Cleanup
ymao1 Mar 17, 2025
ef6074b
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
ymao1 Mar 17, 2025
61c8d7b
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
ymao1 Mar 17, 2025
61091c3
PR feedback
ymao1 Mar 17, 2025
588c6fb
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion server/src/main/java/org/elasticsearch/inference/InputType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import java.util.Locale;

import static org.elasticsearch.core.Strings.format;

/**
* Defines the type of request, whether the request is to ingest a document or search for a document.
*/
Expand All @@ -19,7 +21,11 @@ public enum InputType {
SEARCH,
UNSPECIFIED,
CLASSIFICATION,
CLUSTERING;
CLUSTERING,

// Use the following enums when calling the inference API internally
INTERNAL_SEARCH,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

INTERNAL_INGEST;

@Override
public String toString() {
Expand All @@ -29,4 +35,20 @@ public String toString() {
/**
 * Resolves an {@link InputType} from its textual name.
 * Surrounding whitespace is trimmed and the text is upper-cased with {@link Locale#ROOT} before the constant lookup.
 *
 * @param name the textual name of the input type
 * @return the constant whose name equals the normalized text
 * @throws IllegalArgumentException if no constant has the normalized name
 */
public static InputType fromString(String name) {
    String normalized = name.trim().toUpperCase(Locale.ROOT);
    return InputType.valueOf(normalized);
}

public static InputType fromRestString(String name) {
var inputType = InputType.fromString(name);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that the values passed from the REST API do not include the internal values

if (inputType == InputType.INTERNAL_INGEST || inputType == InputType.INTERNAL_SEARCH) {
throw new IllegalArgumentException(format("Unrecognized input_type [%s]", inputType));
}
return inputType;
}

/**
 * Reports whether the given input type is one of the internal-only values or {@link InputType#UNSPECIFIED}.
 *
 * @param inputType the input type to check
 * @return {@code true} for {@code INTERNAL_INGEST}, {@code INTERNAL_SEARCH} or {@code UNSPECIFIED}; {@code false} otherwise
 */
public static boolean isInternalTypeOrUnspecified(InputType inputType) {
    if (inputType == InputType.UNSPECIFIED) {
        return true;
    }
    return inputType == InputType.INTERNAL_SEARCH || inputType == InputType.INTERNAL_INGEST;
}

/**
 * Reports whether an input type was explicitly provided.
 *
 * @param inputType the input type to check
 * @return {@code false} when the value is {@code null} or {@link InputType#UNSPECIFIED}; {@code true} otherwise
 */
public static boolean isSpecified(InputType inputType) {
    if (inputType == null) {
        return false;
    }
    return inputType != InputType.UNSPECIFIED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ public static class Request extends BaseInferenceActionRequest {

public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30);
public static final ParseField INPUT = new ParseField("input");
public static final ParseField INPUT_TYPE = new ParseField("input_type");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField TIMEOUT = new ParseField("timeout");

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
Expand All @@ -78,8 +80,6 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
Request.Builder builder = PARSER.apply(parser, null);
builder.setInferenceEntityId(inferenceEntityId);
builder.setTaskType(taskType);
// For rest requests we won't know what the input type is
builder.setInputType(InputType.UNSPECIFIED);
return builder;
}

Expand Down Expand Up @@ -199,6 +199,14 @@ public ActionRequestValidationException validate() {
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throws validation exception if input_type is used with task types other than text-embedding or if task type is not specified

&& taskType.equals(TaskType.ANY) == false
&& (inputType != null && InputType.isInternalTypeOrUnspecified(inputType) == false)) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [input_type] cannot be specified for task type [%s]", taskType));
return e;
}

return null;
}

Expand Down Expand Up @@ -294,6 +302,11 @@ public Builder setInputType(InputType inputType) {
return this;
}

/**
 * Sets the input type from its string form; this overload is the one bound to the {@code input_type} field of the request parser.
 * The string is resolved with {@link InputType#fromRestString(String)}.
 *
 * @param inputType the textual input type
 * @return this builder
 */
public Builder setInputType(String inputType) {
    InputType parsed = InputType.fromRestString(inputType);
    this.inputType = parsed;
    return this;
}

public Builder setTaskSettings(Map<String, Object> taskSettings) {
this.taskSettings = taskSettings;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,76 @@ public void testValidation_Rerank_Empty() {
assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];"));
}

public void testValidation_Rerank_WithInputType() {
    // A rerank request that carries an explicit input type is expected to fail validation.
    var rerankRequest = new InferenceAction.Request(
        TaskType.RERANK,
        "model",
        "query",
        List.of("input"),
        null,
        InputType.SEARCH,
        null,
        false
    );

    ActionRequestValidationException validationError = rerankRequest.validate();

    assertNotNull(validationError);
    String expectedMessage = "Validation Failed: 1: Field [input_type] cannot be specified for task type [rerank];";
    assertThat(validationError.getMessage(), is(expectedMessage));
}

public void testValidation_SparseEmbedding_WithInputType() {
    // A sparse embedding request that carries an explicit input type is expected to fail validation.
    var sparseRequest = new InferenceAction.Request(
        TaskType.SPARSE_EMBEDDING,
        "model",
        "",
        List.of("input"),
        null,
        InputType.SEARCH,
        null,
        false
    );

    ActionRequestValidationException validationError = sparseRequest.validate();

    assertNotNull(validationError);
    String expectedMessage = "Validation Failed: 1: Field [input_type] cannot be specified for task type [sparse_embedding];";
    assertThat(validationError.getMessage(), is(expectedMessage));
}

public void testValidation_Completion_WithInputType() {
    // A completion request that carries an explicit input type is expected to fail validation.
    var completionRequest = new InferenceAction.Request(
        TaskType.COMPLETION,
        "model",
        "",
        List.of("input"),
        null,
        InputType.SEARCH,
        null,
        false
    );

    ActionRequestValidationException validationError = completionRequest.validate();

    assertNotNull(validationError);
    String expectedMessage = "Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];";
    assertThat(validationError.getMessage(), is(expectedMessage));
}

public void testValidation_ChatCompletion_WithInputType() {
    // A chat completion request that carries an explicit input type is expected to fail validation.
    var chatCompletionRequest = new InferenceAction.Request(
        TaskType.CHAT_COMPLETION,
        "model",
        "",
        List.of("input"),
        null,
        InputType.SEARCH,
        null,
        false
    );

    ActionRequestValidationException validationError = chatCompletionRequest.validate();

    assertNotNull(validationError);
    String expectedMessage = "Validation Failed: 1: Field [input_type] cannot be specified for task type [chat_completion];";
    assertThat(validationError.getMessage(), is(expectedMessage));
}

public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
String singleInputRequest = """
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
Expand All @@ -32,15 +31,15 @@ public AlibabaCloudSearchActionCreator(Sender sender, ServiceComponents serviceC
}

@Override
public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings, inputType);
public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings);

return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents);
}

@Override
public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings, inputType);
public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings) {
var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings);

return new AlibabaCloudSearchSparseAction(sender, overriddenModel, serviceComponents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
Expand All @@ -17,9 +16,9 @@
import java.util.Map;

public interface AlibabaCloudSearchActionVisitor {
ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings);

ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
Expand Down Expand Up @@ -51,7 +51,7 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
if (inferenceInputs instanceof EmbeddingsInput == false) {
listener.onFailure(
new ElasticsearchStatusException(
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
Expand All @@ -61,7 +61,7 @@ public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionLi
return;
}

var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
var docsOnlyInput = (EmbeddingsInput) inferenceInputs;
if (docsOnlyInput.getInputs().size() % 2 == 0) {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.cohere;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
Expand Down Expand Up @@ -40,8 +39,8 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
}

@Override
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings, inputType);
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings");
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.cohere;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
Expand All @@ -16,7 +15,7 @@
import java.util.Map;

public interface CohereActionVisitor {
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(CohereRerankModel model, Map<String, Object> taskSettings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.elastic;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
Expand All @@ -30,23 +29,15 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer

private final TraceContext traceContext;

private final InputType inputType;

public ElasticInferenceServiceActionCreator(
Sender sender,
ServiceComponents serviceComponents,
TraceContext traceContext,
InputType inputType
) {
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.traceContext = traceContext;
this.inputType = inputType;
}

@Override
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, inputType);
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
var errorMessage = constructFailedToSendRequestMessage(
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.googlevertexai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
Expand All @@ -34,8 +33,8 @@ public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceCompo
}

@Override
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings, inputType);
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings);
var requestManager = new GoogleVertexAiEmbeddingsRequestManager(
overriddenModel,
serviceComponents.truncator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.googlevertexai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
Expand All @@ -16,7 +15,7 @@

public interface GoogleVertexAiActionVisitor {

ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;

public interface HuggingFaceActionVisitor {
ExecutableAction create(HuggingFaceEmbeddingsModel mode);
ExecutableAction create(HuggingFaceEmbeddingsModel model);

ExecutableAction create(HuggingFaceElserModel mode);
ExecutableAction create(HuggingFaceElserModel model);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.action.jinaai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIEmbeddingsRequestManager;
Expand Down Expand Up @@ -35,8 +34,8 @@ public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
}

@Override
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI embeddings");
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

package org.elasticsearch.xpack.inference.external.action.jinaai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;

import java.util.Map;

public interface JinaAIActionVisitor {
ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings);
}
Loading