[TransportVersions]
@@ -208,6 +208,7 @@ static TransportVersion def(int id)
     public static final TransportVersion PROJECT_ID_IN_SNAPSHOT = def(9_040_0_00);
     public static final TransportVersion INDEX_STATS_AND_METADATA_INCLUDE_PEAK_WRITE_LOAD = def(9_041_0_00);
     public static final TransportVersion REPOSITORIES_METADATA_AS_PROJECT_CUSTOM = def(9_042_0_00);
+    public static final TransportVersion INFERENCE_REQUEST_SERVICE_TASK_TYPE_RATE_LIMITING = def(9_043_0_00);

     /*
      * STOP! READ THIS FIRST! No, really,
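Note: this constant gates the new wire fields on mixed-version clusters. A minimal sketch of the usual gating pattern follows; the surrounding Writeable and its serviceName/taskType fields are hypothetical, not code from this PR:

    // Hypothetical Writeable sketch: write the new fields only when the
    // receiving node understands this transport version.
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(inferenceEntityId);
        if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_SERVICE_TASK_TYPE_RATE_LIMITING)) {
            out.writeString(serviceName);   // assumed field
            out.writeEnum(taskType);        // assumed field
        }
    }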

(Large diff not rendered by default.)

[AmazonBedrockRequestSender]
@@ -14,6 +14,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService;
@@ -89,8 +90,8 @@ protected AmazonBedrockRequestSender(
     }

     @Override
-    public void updateRateLimitDivisor(int rateLimitDivisor) {
-        executorService.updateRateLimitDivisor(rateLimitDivisor);
+    public void updateRateLimitDivisor(String serviceName, TaskType taskType, int rateLimitDivisor) {
+        executorService.updateRateLimitDivisor(serviceName, taskType, rateLimitDivisor);
     }

     @Override
[RequestExecutor]
@@ -11,6 +11,7 @@
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;

@@ -21,7 +22,7 @@ public interface RequestExecutor {

     void shutdown();

-    void updateRateLimitDivisor(int newDivisor);
+    void updateRateLimitDivisor(String serviceName, TaskType taskType, int newDivisor);

     boolean isShutdown();

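Note: the divisor is used to split an endpoint's configured rate limit across the nodes serving it; the added serviceName/taskType parameters let callers rescale a single service and task-type combination instead of every endpoint on the node at once. A hedged sketch of a caller; the component, service name, and task type are assumptions, not code from this PR:

    // Hypothetical caller: rescale one service/task-type pair when the
    // number of nodes serving that endpoint changes.
    void onNodeCountChange(RequestExecutor executor, int nodesServingEndpoint) {
        executor.updateRateLimitDivisor("elastic", TaskType.SPARSE_EMBEDDING, nodesServingEndpoint);
    }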
[AlibabaCloudSearchRequestManager]
@@ -15,7 +15,14 @@
 abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager {

     protected AlibabaCloudSearchRequestManager(ThreadPool threadPool, AlibabaCloudSearchModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            AlibabaCloudSearchRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int apiKeyHash) {
[AmazonBedrockRequestManager]
@@ -15,13 +15,21 @@

 import java.util.Objects;

-public abstract class AmazonBedrockRequestManager implements RequestManager {
+public abstract class AmazonBedrockRequestManager extends BaseRequestManager {

     protected final ThreadPool threadPool;
     protected final TimeValue timeout;
     private final AmazonBedrockModel baseModel;

     protected AmazonBedrockRequestManager(AmazonBedrockModel baseModel, ThreadPool threadPool, @Nullable TimeValue timeout) {
+        super(
+            threadPool,
+            baseModel.getInferenceEntityId(),
+            AmazonBedrockRequestManager.RateLimitGrouping.of(baseModel),
+            baseModel.rateLimitSettings(),
+            baseModel.getConfigurations().getService(),
+            baseModel.getConfigurations().getTaskType()
+        );
         this.baseModel = Objects.requireNonNull(baseModel);
         this.threadPool = Objects.requireNonNull(threadPool);
         this.timeout = timeout;
[AnthropicRequestManager]
@@ -16,7 +16,14 @@
 abstract class AnthropicRequestManager extends BaseRequestManager {

     protected AnthropicRequestManager(ThreadPool threadPool, AnthropicModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            AnthropicRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int accountHash, int modelIdHash) {
[AzureAiStudioRequestManager]
@@ -15,7 +15,14 @@
 public abstract class AzureAiStudioRequestManager extends BaseRequestManager {

     protected AzureAiStudioRequestManager(ThreadPool threadPool, AzureAiStudioModel model) {
-        super(threadPool, model.getInferenceEntityId(), AzureAiStudioRequestManager.RateLimitGrouping.of(model), model.rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            AzureAiStudioRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int targetHashcode) {
[AzureOpenAiRequestManager]
@@ -14,7 +14,14 @@

 public abstract class AzureOpenAiRequestManager extends BaseRequestManager {
     protected AzureOpenAiRequestManager(ThreadPool threadPool, AzureOpenAiModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int resourceNameHash, int deploymentIdHash) {
[BaseRequestManager]
@@ -7,6 +7,7 @@

 package org.elasticsearch.xpack.inference.external.http.sender;

+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -23,17 +24,28 @@ abstract class BaseRequestManager implements RequestManager {
     // the rate and the other inference endpoint's rate will be ignored
     private final EndpointGrouping endpointGrouping;
     private final RateLimitSettings rateLimitSettings;
+    private final String service;
+    private final TaskType taskType;

-    BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) {
+    BaseRequestManager(
+        ThreadPool threadPool,
+        String inferenceEntityId,
+        Object rateLimitGroup,
+        RateLimitSettings rateLimitSettings,
+        String service,
+        TaskType taskType
+    ) {
         this.threadPool = Objects.requireNonNull(threadPool);
         this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);

         Objects.requireNonNull(rateLimitSettings);
         this.endpointGrouping = new EndpointGrouping(Objects.requireNonNull(rateLimitGroup).hashCode(), rateLimitSettings);
         this.rateLimitSettings = rateLimitSettings;
+        this.service = service;
+        this.taskType = taskType;
     }

-    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
+    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel, String service, TaskType taskType) {
         this.threadPool = Objects.requireNonNull(threadPool);
         Objects.requireNonNull(rateLimitGroupingModel);

@@ -43,6 +55,8 @@ abstract class BaseRequestManager implements RequestManager {
             rateLimitGroupingModel.rateLimitSettings()
         );
         this.rateLimitSettings = rateLimitGroupingModel.rateLimitSettings();
+        this.service = service;
+        this.taskType = taskType;
     }

     protected void execute(Runnable runnable) {
@@ -64,5 +78,15 @@ public RateLimitSettings rateLimitSettings() {
         return rateLimitSettings;
     }

+    @Override
+    public String service() {
+        return this.service;
+    }
+
+    @Override
+    public TaskType taskType() {
+        return this.taskType;
+    }
+
     private record EndpointGrouping(int group, RateLimitSettings settings) {}
 }
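Note: service() and taskType() carry @Override here, so the RequestManager interface itself presumably gains matching declarations in a part of the diff not rendered above; roughly:

    // Presumed RequestManager additions (implied by the @Override annotations;
    // the interface's own diff is not shown in this view).
    public interface RequestManager {
        String service();
        TaskType taskType();
        // ... existing members (inferenceEntityId(), rateLimitSettings(), execute(...), etc.)
    }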
[CohereRequestManager]
@@ -15,7 +15,14 @@
 abstract class CohereRequestManager extends BaseRequestManager {

     protected CohereRequestManager(ThreadPool threadPool, CohereModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            CohereRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int apiKeyHash) {
[DeepSeekRequestManager]
@@ -35,10 +35,25 @@ public class DeepSeekRequestManager extends BaseRequestManager {
     private final DeepSeekChatCompletionModel model;

     public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
-        super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            DeepSeekRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
         this.model = Objects.requireNonNull(model);
     }

+    record RateLimitGrouping(int apiKeyHash) {
+        public static DeepSeekRequestManager.RateLimitGrouping of(DeepSeekChatCompletionModel model) {
+            Objects.requireNonNull(model);
+
+            return new DeepSeekRequestManager.RateLimitGrouping(model.apiKey().hashCode());
+        }
+    }
+
     @Override
     public void execute(
         InferenceInputs inferenceInputs,

Review thread on RateLimitGrouping.of:

Contributor: I believe it was intentional to limit the rate limit to the max allowed, so I think we should revert the changes around the API key here. cc: @prwhelan

Member: This is correct - there is effectively no rate limiting for DeepSeek.
[ElasticInferenceServiceRequestManager]
@@ -20,7 +20,14 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestManager {
     private final ElasticInferenceServiceRequestMetadata requestMetadata;

     protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            ElasticInferenceServiceRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
         this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
     }
[GenericRequestManager]
@@ -39,7 +39,12 @@ public GenericRequestManager(
         Function<T, Request> requestCreator,
         Class<T> inputType
     ) {
-        super(threadPool, rateLimitGroupingModel);
+        super(
+            threadPool,
+            rateLimitGroupingModel,
+            rateLimitGroupingModel.getConfigurations().getService(),
+            rateLimitGroupingModel.getConfigurations().getTaskType()
+        );
         this.responseHandler = Objects.requireNonNull(responseHandler);
         this.requestCreator = Objects.requireNonNull(requestCreator);
         this.inputType = Objects.requireNonNull(inputType);
[GoogleAiStudioRequestManager]
@@ -14,7 +14,14 @@

 public abstract class GoogleAiStudioRequestManager extends BaseRequestManager {
     GoogleAiStudioRequestManager(ThreadPool threadPool, GoogleAiStudioModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            GoogleAiStudioRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int modelIdHash) {
[GoogleVertexAiRequestManager]
@@ -10,9 +10,26 @@

 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;

+import java.util.Objects;
+
 public abstract class GoogleVertexAiRequestManager extends BaseRequestManager {

     GoogleVertexAiRequestManager(ThreadPool threadPool, GoogleVertexAiModel model, Object rateLimitGroup) {
-        super(threadPool, model.getInferenceEntityId(), rateLimitGroup, model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            GoogleVertexAiRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

+    record RateLimitGrouping(int modelIdHash) {
+        public static GoogleVertexAiRequestManager.RateLimitGrouping of(GoogleVertexAiModel model) {
+            Objects.requireNonNull(model);
+
+            return new GoogleVertexAiRequestManager.RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
+        }
+    }
 }
[HttpRequestSender]
@@ -15,6 +15,7 @@
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
@@ -111,8 +112,8 @@ public void start() {
         }
     }

-    public void updateRateLimitDivisor(int rateLimitDivisor) {
-        service.updateRateLimitDivisor(rateLimitDivisor);
+    public void updateRateLimitDivisor(String serviceName, TaskType taskType, int rateLimitDivisor) {
+        service.updateRateLimitDivisor(serviceName, taskType, rateLimitDivisor);
     }

     private void waitForStartToComplete() {
[HuggingFaceRequestManager]
@@ -47,7 +47,14 @@ public static HuggingFaceRequestManager of(
     private final Truncator truncator;

     private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            HuggingFaceRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
         this.model = model;
         this.responseHandler = responseHandler;
         this.truncator = truncator;
[IbmWatsonxRequestManager]
@@ -14,7 +14,14 @@

 public abstract class IbmWatsonxRequestManager extends BaseRequestManager {
     IbmWatsonxRequestManager(ThreadPool threadPool, IbmWatsonxModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            IbmWatsonxRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int modelIdHash) {
[JinaAIRequestManager]
@@ -15,7 +15,14 @@
 abstract class JinaAIRequestManager extends BaseRequestManager {

     protected JinaAIRequestManager(ThreadPool threadPool, JinaAIModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            JinaAIRequestManager.RateLimitGrouping.of(model),
+            model.rateLimitServiceSettings().rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
     }

     record RateLimitGrouping(int apiKeyHash) {
[MistralEmbeddingsRequestManager]
@@ -44,7 +44,14 @@ private static ResponseHandler createEmbeddingsHandler() {
     }

     public MistralEmbeddingsRequestManager(MistralEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings());
+        super(
+            threadPool,
+            model.getInferenceEntityId(),
+            RateLimitGrouping.of(model),
+            model.rateLimitSettings(),
+            model.getConfigurations().getService(),
+            model.getConfigurations().getTaskType()
+        );
         this.model = Objects.requireNonNull(model);
         this.truncator = Objects.requireNonNull(truncator);
