Skip to content

Commit 6b7ad7c

Browse files
[Inference API] add service and task type aware rate limiting
1 parent 9f4db73 commit 6b7ad7c

31 files changed

+527
-130
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ static TransportVersion def(int id) {
208208
public static final TransportVersion PROJECT_ID_IN_SNAPSHOT = def(9_040_0_00);
209209
public static final TransportVersion INDEX_STATS_AND_METADATA_INCLUDE_PEAK_WRITE_LOAD = def(9_041_0_00);
210210
public static final TransportVersion REPOSITORIES_METADATA_AS_PROJECT_CUSTOM = def(9_042_0_00);
211+
public static final TransportVersion INFERENCE_REQUEST_SERVICE_TASK_TYPE_RATE_LIMITING = def(9_043_0_00);
211212

212213
/*
213214
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Lines changed: 211 additions & 27 deletions
Large diffs are not rendered by default.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.settings.Settings;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.InferenceServiceResults;
17+
import org.elasticsearch.inference.TaskType;
1718
import org.elasticsearch.threadpool.ThreadPool;
1819
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
1920
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService;
@@ -89,8 +90,8 @@ protected AmazonBedrockRequestSender(
8990
}
9091

9192
@Override
92-
public void updateRateLimitDivisor(int rateLimitDivisor) {
93-
executorService.updateRateLimitDivisor(rateLimitDivisor);
93+
public void updateRateLimitDivisor(String serviceName, TaskType taskType, int rateLimitDivisor) {
94+
executorService.updateRateLimitDivisor(serviceName, taskType, rateLimitDivisor);
9495
}
9596

9697
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.core.Nullable;
1212
import org.elasticsearch.core.TimeValue;
1313
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.inference.TaskType;
1415
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1516
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
1617

@@ -21,7 +22,7 @@ public interface RequestExecutor {
2122

2223
void shutdown();
2324

24-
void updateRateLimitDivisor(int newDivisor);
25+
void updateRateLimitDivisor(String serviceName, TaskType taskType, int newDivisor);
2526

2627
boolean isShutdown();
2728

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager {
1616

1717
protected AlibabaCloudSearchRequestManager(ThreadPool threadPool, AlibabaCloudSearchModel model) {
18-
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
18+
super(
19+
threadPool,
20+
model.getInferenceEntityId(),
21+
AlibabaCloudSearchRequestManager.RateLimitGrouping.of(model),
22+
model.rateLimitServiceSettings().rateLimitSettings(),
23+
model.getConfigurations().getService(),
24+
model.getConfigurations().getTaskType()
25+
);
1926
}
2027

2128
record RateLimitGrouping(int apiKeyHash) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,21 @@
1515

1616
import java.util.Objects;
1717

18-
public abstract class AmazonBedrockRequestManager implements RequestManager {
18+
public abstract class AmazonBedrockRequestManager extends BaseRequestManager {
1919

2020
protected final ThreadPool threadPool;
2121
protected final TimeValue timeout;
2222
private final AmazonBedrockModel baseModel;
2323

2424
protected AmazonBedrockRequestManager(AmazonBedrockModel baseModel, ThreadPool threadPool, @Nullable TimeValue timeout) {
25+
super(
26+
threadPool,
27+
baseModel.getInferenceEntityId(),
28+
AmazonBedrockRequestManager.RateLimitGrouping.of(baseModel),
29+
baseModel.rateLimitSettings(),
30+
baseModel.getConfigurations().getService(),
31+
baseModel.getConfigurations().getTaskType()
32+
);
2533
this.baseModel = Objects.requireNonNull(baseModel);
2634
this.threadPool = Objects.requireNonNull(threadPool);
2735
this.timeout = timeout;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicRequestManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
abstract class AnthropicRequestManager extends BaseRequestManager {
1717

1818
protected AnthropicRequestManager(ThreadPool threadPool, AnthropicModel model) {
19-
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
19+
super(
20+
threadPool,
21+
model.getInferenceEntityId(),
22+
AnthropicRequestManager.RateLimitGrouping.of(model),
23+
model.rateLimitServiceSettings().rateLimitSettings(),
24+
model.getConfigurations().getService(),
25+
model.getConfigurations().getTaskType()
26+
);
2027
}
2128

2229
record RateLimitGrouping(int accountHash, int modelIdHash) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioRequestManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
public abstract class AzureAiStudioRequestManager extends BaseRequestManager {
1616

1717
protected AzureAiStudioRequestManager(ThreadPool threadPool, AzureAiStudioModel model) {
18-
super(threadPool, model.getInferenceEntityId(), AzureAiStudioRequestManager.RateLimitGrouping.of(model), model.rateLimitSettings());
18+
super(
19+
threadPool,
20+
model.getInferenceEntityId(),
21+
AzureAiStudioRequestManager.RateLimitGrouping.of(model),
22+
model.rateLimitSettings(),
23+
model.getConfigurations().getService(),
24+
model.getConfigurations().getTaskType()
25+
);
1926
}
2027

2128
record RateLimitGrouping(int targetHashcode) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiRequestManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414

1515
public abstract class AzureOpenAiRequestManager extends BaseRequestManager {
1616
protected AzureOpenAiRequestManager(ThreadPool threadPool, AzureOpenAiModel model) {
17-
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
17+
super(
18+
threadPool,
19+
model.getInferenceEntityId(),
20+
RateLimitGrouping.of(model),
21+
model.rateLimitServiceSettings().rateLimitSettings(),
22+
model.getConfigurations().getService(),
23+
model.getConfigurations().getTaskType()
24+
);
1825
}
1926

2027
record RateLimitGrouping(int resourceNameHash, int deploymentIdHash) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.http.sender;
99

10+
import org.elasticsearch.inference.TaskType;
1011
import org.elasticsearch.threadpool.ThreadPool;
1112
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1213
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -23,17 +24,28 @@ abstract class BaseRequestManager implements RequestManager {
2324
// the rate and the other inference endpoint's rate will be ignored
2425
private final EndpointGrouping endpointGrouping;
2526
private final RateLimitSettings rateLimitSettings;
27+
private final String service;
28+
private final TaskType taskType;
2629

27-
BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) {
30+
BaseRequestManager(
31+
ThreadPool threadPool,
32+
String inferenceEntityId,
33+
Object rateLimitGroup,
34+
RateLimitSettings rateLimitSettings,
35+
String service,
36+
TaskType taskType
37+
) {
2838
this.threadPool = Objects.requireNonNull(threadPool);
2939
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
3040

3141
Objects.requireNonNull(rateLimitSettings);
3242
this.endpointGrouping = new EndpointGrouping(Objects.requireNonNull(rateLimitGroup).hashCode(), rateLimitSettings);
3343
this.rateLimitSettings = rateLimitSettings;
44+
this.service = service;
45+
this.taskType = taskType;
3446
}
3547

36-
BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
48+
BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel, String service, TaskType taskType) {
3749
this.threadPool = Objects.requireNonNull(threadPool);
3850
Objects.requireNonNull(rateLimitGroupingModel);
3951

@@ -43,6 +55,8 @@ abstract class BaseRequestManager implements RequestManager {
4355
rateLimitGroupingModel.rateLimitSettings()
4456
);
4557
this.rateLimitSettings = rateLimitGroupingModel.rateLimitSettings();
58+
this.service = service;
59+
this.taskType = taskType;
4660
}
4761

4862
protected void execute(Runnable runnable) {
@@ -64,5 +78,15 @@ public RateLimitSettings rateLimitSettings() {
6478
return rateLimitSettings;
6579
}
6680

81+
@Override
82+
public String service() {
83+
return this.service;
84+
}
85+
86+
@Override
87+
public TaskType taskType() {
88+
return this.taskType;
89+
}
90+
6791
private record EndpointGrouping(int group, RateLimitSettings settings) {}
6892
}

0 commit comments

Comments
 (0)