Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
3b418ae
wip
ymao1 Feb 11, 2025
5d020a7
wip
ymao1 Feb 21, 2025
cceb308
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 3, 2025
b515e84
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 3, 2025
625a5a2
[CI] Auto commit changes from spotless
Mar 3, 2025
c064e26
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 4, 2025
f6dbf51
Adding internal input types
ymao1 Mar 4, 2025
6a25d08
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 4, 2025
b7c2481
[CI] Auto commit changes from spotless
Mar 4, 2025
2d2b6db
Throwing validation exception for services that don't support input type
ymao1 Mar 4, 2025
3dc38f7
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 4, 2025
356d546
linting
ymao1 Mar 4, 2025
a143af1
hugging face
ymao1 Mar 4, 2025
7f20d32
voyage ai
ymao1 Mar 4, 2025
d6c2464
google ai studio
ymao1 Mar 5, 2025
f63e852
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 5, 2025
f9962ce
bedrock updates
ymao1 Mar 5, 2025
2905643
Fixing tests
ymao1 Mar 5, 2025
f47538c
Fixing tests
ymao1 Mar 6, 2025
5b4ee68
Fixing tests
ymao1 Mar 6, 2025
be17339
bedrock updates
ymao1 Mar 6, 2025
bba8ac5
elasticsearch
ymao1 Mar 6, 2025
bf32efd
azure openai
ymao1 Mar 6, 2025
3dbeaff
Merge
ymao1 Mar 6, 2025
c3c9cef
[CI] Auto commit changes from spotless
Mar 6, 2025
cc96e9a
Refactoring all the things
ymao1 Mar 8, 2025
fbc8791
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 8, 2025
e6e877e
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 8, 2025
215b4b7
[CI] Auto commit changes from spotless
Mar 8, 2025
de1f8fc
Everything compiles
ymao1 Mar 8, 2025
4838727
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 8, 2025
1946d49
spotless
ymao1 Mar 8, 2025
f1bbcc6
external actions tests
ymao1 Mar 9, 2025
ea4ad64
external request tests
ymao1 Mar 10, 2025
3a3e946
service tests
ymao1 Mar 10, 2025
20c28b0
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 10, 2025
84bc649
Fixing integration tests
ymao1 Mar 10, 2025
4ad3214
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 11, 2025
9f744db
Cleanup
ymao1 Mar 11, 2025
4564680
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 11, 2025
aef7da8
Update docs/changelog/122638.yaml
ymao1 Mar 11, 2025
6cbd41e
Merge
ymao1 Mar 12, 2025
f364ad3
Merge branch 'es-117856' of github.com:ymao1/elasticsearch into es-11…
ymao1 Mar 12, 2025
113cf3d
Merging in main
ymao1 Mar 17, 2025
cb8d337
Cleanup
ymao1 Mar 17, 2025
ef6074b
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
ymao1 Mar 17, 2025
61c8d7b
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
ymao1 Mar 17, 2025
61091c3
PR feedback
ymao1 Mar 17, 2025
588c6fb
Merge branch 'main' of github.com:elastic/elasticsearch into es-117856
ymao1 Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion server/src/main/java/org/elasticsearch/inference/InputType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import java.util.Locale;

import static org.elasticsearch.core.Strings.format;

/**
* Defines the type of request, whether the request is to ingest a document or search for a document.
*/
Expand All @@ -19,7 +21,9 @@ public enum InputType {
SEARCH,
UNSPECIFIED,
CLASSIFICATION,
CLUSTERING;
CLUSTERING,
INTERNAL_SEARCH,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ‘

INTERNAL_INGEST;

@Override
public String toString() {
Expand All @@ -29,4 +33,12 @@ public String toString() {
public static InputType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

public static InputType fromRestString(String name) {
var inputType = InputType.fromString(name);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that the values passed from the REST API do not include the internal values

if (inputType == InputType.INTERNAL_INGEST || inputType == InputType.INTERNAL_SEARCH) {
throw new IllegalArgumentException(format("Unrecognized input_type [%s]", inputType));
}
return inputType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ public static class Request extends BaseInferenceActionRequest {

public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30);
public static final ParseField INPUT = new ParseField("input");
public static final ParseField INPUT_TYPE = new ParseField("input_type");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField TIMEOUT = new ParseField("timeout");

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
Expand All @@ -78,8 +80,6 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
Request.Builder builder = PARSER.apply(parser, null);
builder.setInferenceEntityId(inferenceEntityId);
builder.setTaskType(taskType);
// For rest requests we won't know what the input type is
builder.setInputType(InputType.UNSPECIFIED);
return builder;
}

Expand Down Expand Up @@ -199,6 +199,12 @@ public ActionRequestValidationException validate() {
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) == false && inputType != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [input_type] cannot be specified for task type [%s]", taskType));
return e;
}

return null;
}

Expand Down Expand Up @@ -294,6 +300,11 @@ public Builder setInputType(InputType inputType) {
return this;
}

public Builder setInputType(String inputType) {
this.inputType = InputType.fromRestString(inputType);
return this;
}

public Builder setTaskSettings(Map<String, Object> taskSettings) {
this.taskSettings = taskSettings;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,76 @@ public void testValidation_Rerank_Empty() {
assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];"));
}

public void testValidation_Rerank_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.RERANK,
"model",
"",
List.of("input"),
null,
InputType.SEARCH,
null,
false
);
ActionRequestValidationException queryError = queryRequest.validate();
assertNotNull(queryError);
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [rerank];"));
}

public void testValidation_SparseEmbedding_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.SPARSE_EMBEDDING,
"model",
"",
List.of("input"),
null,
InputType.SEARCH,
null,
false
);
ActionRequestValidationException queryError = queryRequest.validate();
assertNotNull(queryError);
assertThat(
queryError.getMessage(),
is("Validation Failed: 1: Field [input_type] cannot be specified for task type [sparse_embedding];")
);
}

public void testValidation_Completion_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.COMPLETION,
"model",
"",
List.of("input"),
null,
InputType.SEARCH,
null,
false
);
ActionRequestValidationException queryError = queryRequest.validate();
assertNotNull(queryError);
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
}

public void testValidation_ChatCompletion_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.CHAT_COMPLETION,
"model",
"",
List.of("input"),
null,
InputType.SEARCH,
null,
false
);
ActionRequestValidationException queryError = queryRequest.validate();
assertNotNull(queryError);
assertThat(
queryError.getMessage(),
is("Validation Failed: 1: Field [input_type] cannot be specified for task type [chat_completion];")
);
}

public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
String singleInputRequest = """
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager;
Expand All @@ -35,8 +36,8 @@ public AmazonBedrockActionCreator(Sender sender, ServiceComponents serviceCompon
}

@Override
public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
var overriddenModel = AmazonBedrockEmbeddingsModel.of(embeddingsModel, taskSettings);
public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AmazonBedrockEmbeddingsModel.of(embeddingsModel, taskSettings, inputType);
var requestManager = new AmazonBedrockEmbeddingsRequestManager(
overriddenModel,
serviceComponents.truncator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

package org.elasticsearch.xpack.inference.external.action.amazonbedrock;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;

import java.util.Map;

public interface AmazonBedrockActionVisitor {
ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.azureaistudio;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.AzureAiStudioChatCompletionRequestManager;
Expand Down Expand Up @@ -39,8 +40,8 @@ public ExecutableAction create(AzureAiStudioChatCompletionModel completionModel,
}

@Override
public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
var overriddenModel = AzureAiStudioEmbeddingsModel.of(embeddingsModel, taskSettings);
public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AzureAiStudioEmbeddingsModel.of(embeddingsModel, taskSettings, inputType);
var requestManager = new AzureAiStudioEmbeddingsRequestManager(
overriddenModel,
serviceComponents.truncator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

package org.elasticsearch.xpack.inference.external.action.azureaistudio;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;

import java.util.Map;

public interface AzureAiStudioActionVisitor {
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
Expand All @@ -33,11 +34,12 @@ public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponent
}

@Override
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = IbmWatsonxEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
return new SenderExecutableAction(
sender,
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
getEmbeddingsRequestManager(overriddenModel, serviceComponents.truncator(), serviceComponents.threadPool()),
failedToSendRequestErrorMessage
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;

public interface IbmWatsonxActionVisitor {
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.mistral;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.MistralEmbeddingsRequestManager;
Expand All @@ -29,9 +30,10 @@ public MistralActionCreator(Sender sender, ServiceComponents serviceComponents)
}

@Override
public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = MistralEmbeddingsModel.of(embeddingsModel, taskSettings, inputType);
var requestManager = new MistralEmbeddingsRequestManager(
embeddingsModel,
overriddenModel,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

package org.elasticsearch.xpack.inference.external.action.mistral;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;

import java.util.Map;

public interface MistralActionVisitor {
ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings, InputType inputType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.openai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
Expand Down Expand Up @@ -36,8 +37,8 @@ public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
}

@Override
public ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
var overriddenModel = OpenAiEmbeddingsModel.of(model, taskSettings);
public ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = OpenAiEmbeddingsModel.of(model, taskSettings, inputType);
var requestCreator = OpenAiEmbeddingsRequestManager.of(
overriddenModel,
serviceComponents.truncator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

package org.elasticsearch.xpack.inference.external.action.openai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;

import java.util.Map;

public interface OpenAiActionVisitor {
ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object> taskSettings);
ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(OpenAiChatCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ static String convertToString(InputType inputType) {
}

return switch (inputType) {
case INGEST -> SEARCH_DOCUMENT;
case SEARCH -> SEARCH_QUERY;
case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT;
case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY;
default -> {
assert false : invalidInputTypeMessage(inputType);
yield null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
// default for testing
static String convertToString(InputType inputType) {
return switch (inputType) {
case INGEST -> SEARCH_DOCUMENT;
case SEARCH -> SEARCH_QUERY;
case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT;
case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY;
case CLASSIFICATION -> CLASSIFICATION;
case CLUSTERING -> CLUSTERING;
default -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

static String convertToString(InputType inputType) {
return switch (inputType) {
case INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
case SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
case INGEST, INTERNAL_INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
case SEARCH, INTERNAL_SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
case CLASSIFICATION -> CLASSIFICATION_TASK_TYPE;
case CLUSTERING -> CLUSTERING_TASK_TYPE;
default -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
// default for testing
static String convertToString(InputType inputType) {
return switch (inputType) {
case INGEST -> SEARCH_DOCUMENT;
case SEARCH -> SEARCH_QUERY;
case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT;
case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY;
case CLASSIFICATION -> CLASSIFICATION;
case CLUSTERING -> CLUSTERING;
default -> {
Expand Down
Loading