Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand Down Expand Up @@ -258,7 +257,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
List.of(query),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand Down Expand Up @@ -162,7 +161,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
List.of(modelText),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
Expand Down Expand Up @@ -116,7 +115,7 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
List.of(modelText),
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
null
);

inferRequest.setHighPriority(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.function.Supplier;

Expand Down Expand Up @@ -179,6 +180,13 @@ public class InferencePlugin extends Plugin
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
/**
 * Default value of {@link #QUERY_INFERENCE_TIMEOUT}: ten seconds.
 */
public static final TimeValue DEFAULT_QUERY_INFERENCE_TIMEOUT = new TimeValue(10, TimeUnit.SECONDS);
/**
 * Time setting registered under {@code xpack.inference.semantic_text.inference_timeout}.
 * It is node-scoped and dynamically updatable, and falls back to
 * {@link #DEFAULT_QUERY_INFERENCE_TIMEOUT} when not configured.
 */
public static final Setting<TimeValue> QUERY_INFERENCE_TIMEOUT = Setting.timeSetting(
    "xpack.inference.semantic_text.inference_timeout",
    DEFAULT_QUERY_INFERENCE_TIMEOUT,
    Setting.Property.NodeScope,
    Setting.Property.Dynamic
);

public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
"inference",
Expand Down Expand Up @@ -311,7 +319,8 @@ public Collection<?> createComponents(PluginServices services) {
serviceComponents.get(),
inferenceServiceSettings,
modelRegistry.get(),
authorizationHandler
authorizationHandler,
context
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
Expand All @@ -321,7 +330,8 @@ public Collection<?> createComponents(PluginServices services) {
),
sageMakerSchemas,
services.threadPool(),
sageMakerConfigurations::getOrCompute
sageMakerConfigurations::getOrCompute,
context
)
)
);
Expand Down Expand Up @@ -383,24 +393,24 @@ public void loadExtensions(ExtensionLoader loader) {

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
ElasticsearchInternalService::new,
context -> new CustomService(httpFactory.get(), serviceComponents.get())
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
);
}

Expand Down Expand Up @@ -495,7 +505,7 @@ public List<Setting<?>> getSettings() {
settings.add(SKIP_VALIDATE_AND_START);
settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());

settings.add(QUERY_INFERENCE_TIMEOUT);
return settings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -237,7 +236,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
List.of(query),
Map.of(),
InputType.INTERNAL_SEARCH,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API,
null,
false
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
Expand All @@ -17,12 +18,14 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -42,11 +45,17 @@ public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
private final Sender sender;
private final ServiceComponents serviceComponents;
private final ClusterService clusterService;

public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
/**
 * Builds the service from its collaborators.
 *
 * @param factory used here, once, to create the {@link Sender} held by this service
 * @param serviceComponents retained by this service
 * @param context supplies the {@link ClusterService} retained by this service
 * @throws NullPointerException if {@code factory}, {@code serviceComponents}, {@code context},
 *     or the cluster service returned by {@code context} is null
 */
public SenderService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
// Reject a null factory before it is dereferenced on the next line.
Objects.requireNonNull(factory);
sender = factory.createSender();
this.serviceComponents = Objects.requireNonNull(serviceComponents);
// The cluster service is what infer() reads InferencePlugin.QUERY_INFERENCE_TIMEOUT from
// when the caller passes no timeout.
this.clusterService = Objects.requireNonNull(context.clusterService());
}

public Sender getSender() {
Expand All @@ -73,6 +82,9 @@ public void infer(
init();
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
if (timeout == null) {
timeout = clusterService.getClusterSettings().get(InferencePlugin.QUERY_INFERENCE_TIMEOUT);
}
doInfer(model, inferenceInput, taskSettings, timeout, listener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -85,8 +86,12 @@ public class AlibabaCloudSearchService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public AlibabaCloudSearchService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -93,9 +94,10 @@ public class AmazonBedrockService extends SenderService {
public AmazonBedrockService(
HttpRequestSender.Factory httpSenderFactory,
AmazonBedrockRequestSender.Factory amazonBedrockFactory,
ServiceComponents serviceComponents
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(httpSenderFactory, serviceComponents);
super(httpSenderFactory, serviceComponents, context);
this.amazonBedrockSender = amazonBedrockFactory.createSender();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -58,8 +59,12 @@ public class AnthropicService extends SenderService {

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);

public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public AnthropicService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -83,8 +84,12 @@ public class AzureAiStudioService extends SenderService {
InputType.INTERNAL_SEARCH
);

public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public AzureAiStudioService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -69,8 +70,12 @@ public class AzureOpenAiService extends SenderService {
private static final String SERVICE_NAME = "Azure OpenAI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public AzureOpenAiService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -84,8 +85,12 @@ public class CohereService extends SenderService {
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
// on every request

public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public CohereService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -74,8 +75,12 @@ public class CustomService extends SenderService {
TaskType.COMPLETION
);

public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public CustomService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -58,8 +59,12 @@ public class DeepSeekService extends SenderService {
);
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
/**
 * Creates the service by passing all three arguments straight through to the
 * {@link SenderService} constructor.
 *
 * @param factory forwarded to the superclass
 * @param serviceComponents forwarded to the superclass
 * @param context forwarded to the superclass
 */
public DeepSeekService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
InferenceServiceExtension.InferenceServiceFactoryContext context
) {
super(factory, serviceComponents, context);
}

@Override
Expand Down
Loading