Skip to content

Commit f0a5e25

Browse files
timgreinelasticsearchmachinejonathan-buttnerdemjened
authored
[8.x] [Inference API] Add node-local rate limiting for the inference API (#120400) (#121251)
* [Inference API] Add node-local rate limiting for the inference API (#120400) * Add node-local rate limiting for the inference API * Fix integration tests by using new LocalStateInferencePlugin instead of InferencePlugin and adjust formatting. * Correct feature flag name * Add more docs, reorganize methods and make some methods package private * Clarify comment in BaseInferenceActionRequest * Fix wrong merge * Fix checkstyle * Fix checkstyle in tests * Check that the service we want to the read the rate limit config for actually exists * [CI] Auto commit changes from spotless * checkStyle apply * Update docs/changelog/120400.yaml * Move rate limit division logic to RequestExecutorService * Spotless apply * Remove debug sout * Adding a few suggestions * Adam feedback * Fix compilation error * [CI] Auto commit changes from spotless * Add BWC test case to InferenceActionRequestTests * Add BWC test case to UnifiedCompletionActionRequestTests * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java Co-authored-by: Adam Demjen <[email protected]> * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java Co-authored-by: Adam Demjen <[email protected]> * Remove addressed TODO * Spotless apply * Only use new rate limit specific feature flag * Use ThreadLocalRandom * [CI] Auto commit changes from spotless * Use Randomness.get() * [CI] Auto commit changes from spotless * Fix import * Use ConcurrentHashMap in InferenceServiceNodeLocalRateLimitCalculator * Check for null value in getRateLimitAssignment and remove AtomicReference * Remove newAssignments * Up the default rate limit for completions * Put deprecated feature flag back in * Check feature flag in BaseTransportInferenceAction * spotlessApply * Export inference.common * Do not export inference.common * Provide noop rate limit calculator, if feature flag is disabled * Add proper dependency injection --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]> Co-authored-by: Adam Demjen <[email protected]> * Use .get(0) as getFirst() doesn't exist in 8.18 (probably JDK difference?) --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]> Co-authored-by: Adam Demjen <[email protected]>
1 parent 1261557 commit f0a5e25

29 files changed

+1015
-49
lines changed

docs/changelog/120400.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 120400
2+
summary: "[Inference API] Add node-local rate limiting for the inference API"
3+
area: Machine Learning
4+
type: feature
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ static TransportVersion def(int id) {
176176
public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
177177
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
178178
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
179+
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
179180

180181
/*
181182
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,56 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.action.ActionRequest;
1112
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
1214
import org.elasticsearch.inference.TaskType;
1315

1416
import java.io.IOException;
1517

18+
/**
19+
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
20+
* and supports both streaming and non-streaming inference operations.
21+
*/
1622
public abstract class BaseInferenceActionRequest extends ActionRequest {
1723

24+
private boolean hasBeenRerouted;
25+
1826
public BaseInferenceActionRequest() {
1927
super();
2028
}
2129

2230
public BaseInferenceActionRequest(StreamInput in) throws IOException {
2331
super(in);
32+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
33+
this.hasBeenRerouted = in.readBoolean();
34+
} else {
35+
// For backwards compatibility, we treat all inference requests coming from ES nodes having
36+
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
37+
this.hasBeenRerouted = true;
38+
}
2439
}
2540

2641
public abstract boolean isStreaming();
2742

2843
public abstract TaskType getTaskType();
2944

3045
public abstract String getInferenceEntityId();
46+
47+
public void setHasBeenRerouted(boolean hasBeenRerouted) {
48+
this.hasBeenRerouted = hasBeenRerouted;
49+
}
50+
51+
public boolean hasBeenRerouted() {
52+
return hasBeenRerouted;
53+
}
54+
55+
@Override
56+
public void writeTo(StreamOutput out) throws IOException {
57+
super.writeTo(out);
58+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
59+
out.writeBoolean(hasBeenRerouted);
60+
}
61+
}
3162
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,29 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
386386
assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
387387
}
388388

389+
public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
390+
var instance = new InferenceAction.Request(
391+
TaskType.TEXT_EMBEDDING,
392+
"model",
393+
null,
394+
List.of("input"),
395+
Map.of(),
396+
InputType.UNSPECIFIED,
397+
InferenceAction.Request.DEFAULT_TIMEOUT,
398+
false
399+
);
400+
401+
InferenceAction.Request deserializedInstance = copyWriteable(
402+
instance,
403+
getNamedWriteableRegistry(),
404+
instanceReader(),
405+
TransportVersions.V_8_13_0
406+
);
407+
408+
// Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
409+
assertTrue(deserializedInstance.hasBeenRerouted());
410+
}
411+
389412
public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() {
390413
assertThat(getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.V_8_12_1), is(InputType.INGEST));
391414
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.core.TimeValue;
@@ -65,6 +66,25 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() {
6566
assertNull(request.validate());
6667
}
6768

69+
public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
70+
var instance = new UnifiedCompletionAction.Request(
71+
"model",
72+
TaskType.ANY,
73+
UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
74+
TimeValue.timeValueSeconds(10)
75+
);
76+
77+
UnifiedCompletionAction.Request deserializedInstance = copyWriteable(
78+
instance,
79+
getNamedWriteableRegistry(),
80+
instanceReader(),
81+
TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION
82+
);
83+
84+
// Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
85+
assertTrue(deserializedInstance.hasBeenRerouted());
86+
}
87+
6888
@Override
6989
protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) {
7090
return instance;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@
7272
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
7373
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
7474
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
75+
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
76+
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
77+
import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator;
7578
import org.elasticsearch.xpack.inference.common.Truncator;
7679
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
7780
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -133,6 +136,7 @@
133136
import java.util.function.Supplier;
134137

135138
import static java.util.Collections.singletonList;
139+
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
136140

137141
public class InferencePlugin extends Plugin
138142
implements
@@ -229,6 +233,7 @@ public List<RestHandler> getRestHandlers(
229233

230234
@Override
231235
public Collection<?> createComponents(PluginServices services) {
236+
var components = new ArrayList<>();
232237
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
233238
var truncator = new Truncator(settings, services.clusterService());
234239
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
@@ -297,20 +302,38 @@ public Collection<?> createComponents(PluginServices services) {
297302

298303
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
299304
// reference correctly
300-
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
301-
registry.init(services.client());
302-
for (var service : registry.getServices().values()) {
305+
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
306+
serviceRegistry.init(services.client());
307+
for (var service : serviceRegistry.getServices().values()) {
303308
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
304309
}
305-
inferenceServiceRegistry.set(registry);
310+
inferenceServiceRegistry.set(serviceRegistry);
306311

307-
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry);
312+
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
308313
shardBulkInferenceActionFilter.set(actionFilter);
309314

310315
var meterRegistry = services.telemetryProvider().getMeterRegistry();
311-
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
316+
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
317+
318+
components.add(serviceRegistry);
319+
components.add(modelRegistry);
320+
components.add(httpClientManager);
321+
components.add(inferenceStats);
322+
323+
// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
324+
// if the rate limiting feature flags are enabled, otherwise provide noop implementation
325+
InferenceServiceRateLimitCalculator calculator;
326+
if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) {
327+
calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry);
328+
} else {
329+
calculator = new NoopNodeLocalRateLimitCalculator();
330+
}
331+
332+
// Add binding for interface -> implementation
333+
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
334+
components.add(calculator);
312335

313-
return List.of(modelRegistry, registry, httpClientManager, stats);
336+
return components;
314337
}
315338

316339
@Override

0 commit comments

Comments
 (0)