Skip to content

Commit 5ca4cb7

Browse files
fix InferenceServiceNodeLocalRateLimitCalculatorTests
1 parent f14ab28 commit 5ca4cb7

File tree

3 files changed

+50
-174
lines changed

3 files changed

+50
-174
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ static TransportVersion def(int id) {
208208
public static final TransportVersion PROJECT_ID_IN_SNAPSHOT = def(9_040_0_00);
209209
public static final TransportVersion INDEX_STATS_AND_METADATA_INCLUDE_PEAK_WRITE_LOAD = def(9_041_0_00);
210210
public static final TransportVersion REPOSITORIES_METADATA_AS_PROJECT_CUSTOM = def(9_042_0_00);
211-
public static final TransportVersion INFERENCE_REQUEST_SERVICE_TASK_TYPE_RATE_LIMITING = def(9_043_0_00);
212211

213212
/*
214213
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Lines changed: 6 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -103,148 +103,14 @@ private SortedMap<String, SortedMap<TaskType, MaxNodesPerGroupingStrategy>> crea
103103

104104
MaxNodesPerGroupingStrategy defaultStrategy = (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING;
105105

106-
// Alibaba Cloud Search
107-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> alibabaCloudSearchConfigs = new TreeMap<>();
108-
var alibabaCloudSearchService = serviceRegistry.getService(AlibabaCloudSearchService.NAME);
109-
if (alibabaCloudSearchService.isPresent()) {
110-
var alibabaCloudSearchTaskTypes = alibabaCloudSearchService.get().supportedTaskTypes();
111-
for (TaskType taskType : alibabaCloudSearchTaskTypes) {
112-
alibabaCloudSearchConfigs.put(taskType, defaultStrategy);
106+
for (var service : serviceRegistry.getServices().values()) {
107+
TreeMap<TaskType, MaxNodesPerGroupingStrategy> serviceConfigs = new TreeMap<>();
108+
var taskTypes = service.supportedTaskTypes();
109+
for (TaskType taskType : taskTypes) {
110+
serviceConfigs.put(taskType, defaultStrategy);
113111
}
112+
serviceNodeLocalRateLimitConfigs.put(service.name(), serviceConfigs);
114113
}
115-
serviceNodeLocalRateLimitConfigs.put(AlibabaCloudSearchService.NAME, alibabaCloudSearchConfigs);
116-
117-
// Amazon Bedrock
118-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> amazonBedrockConfigs = new TreeMap<>();
119-
var amazonBedrockService = serviceRegistry.getService(AmazonBedrockService.NAME);
120-
if (amazonBedrockService.isPresent()) {
121-
var amazonBedrockTaskTypes = amazonBedrockService.get().supportedTaskTypes();
122-
for (TaskType taskType : amazonBedrockTaskTypes) {
123-
amazonBedrockConfigs.put(taskType, defaultStrategy);
124-
}
125-
}
126-
serviceNodeLocalRateLimitConfigs.put(AmazonBedrockService.NAME, amazonBedrockConfigs);
127-
128-
// Anthropic
129-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> anthropicConfigs = new TreeMap<>();
130-
var anthropicService = serviceRegistry.getService(AnthropicService.NAME);
131-
if (anthropicService.isPresent()) {
132-
var anthropicTaskTypes = anthropicService.get().supportedTaskTypes();
133-
for (TaskType taskType : anthropicTaskTypes) {
134-
anthropicConfigs.put(taskType, defaultStrategy);
135-
}
136-
}
137-
serviceNodeLocalRateLimitConfigs.put(AnthropicService.NAME, anthropicConfigs);
138-
139-
// Azure AI Studio
140-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> azureAiStudioConfigs = new TreeMap<>();
141-
var azureAiStudioService = serviceRegistry.getService(AzureAiStudioService.NAME);
142-
if (azureAiStudioService.isPresent()) {
143-
var azureAiStudioTaskTypes = azureAiStudioService.get().supportedTaskTypes();
144-
for (TaskType taskType : azureAiStudioTaskTypes) {
145-
azureAiStudioConfigs.put(taskType, defaultStrategy);
146-
}
147-
}
148-
serviceNodeLocalRateLimitConfigs.put(AzureAiStudioService.NAME, azureAiStudioConfigs);
149-
150-
// Cohere
151-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> cohereConfigs = new TreeMap<>();
152-
var cohereService = serviceRegistry.getService(CohereService.NAME);
153-
if (cohereService.isPresent()) {
154-
var cohereTaskTypes = cohereService.get().supportedTaskTypes();
155-
for (TaskType taskType : cohereTaskTypes) {
156-
cohereConfigs.put(taskType, defaultStrategy);
157-
}
158-
}
159-
serviceNodeLocalRateLimitConfigs.put(CohereService.NAME, cohereConfigs);
160-
161-
// DeepSeek
162-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> deepSeekConfigs = new TreeMap<>();
163-
var deepSeekService = serviceRegistry.getService(DeepSeekService.NAME);
164-
if (deepSeekService.isPresent()) {
165-
var deepSeekTaskTypes = deepSeekService.get().supportedTaskTypes();
166-
for (TaskType taskType : deepSeekTaskTypes) {
167-
deepSeekConfigs.put(taskType, defaultStrategy);
168-
}
169-
}
170-
serviceNodeLocalRateLimitConfigs.put(DeepSeekService.NAME, deepSeekConfigs);
171-
172-
// Elastic Inference Service (EIS)
173-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> elasticInferenceConfigs = new TreeMap<>();
174-
var elasticInferenceService = serviceRegistry.getService(ElasticInferenceService.NAME);
175-
if (elasticInferenceService.isPresent()) {
176-
var elasticInferenceTaskTypes = elasticInferenceService.get().supportedTaskTypes();
177-
for (TaskType taskType : elasticInferenceTaskTypes) {
178-
elasticInferenceConfigs.put(taskType, defaultStrategy);
179-
}
180-
}
181-
serviceNodeLocalRateLimitConfigs.put(ElasticInferenceService.NAME, elasticInferenceConfigs);
182-
183-
// Google AI Studio
184-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> googleAiStudioConfigs = new TreeMap<>();
185-
var googleAiStudioService = serviceRegistry.getService(GoogleAiStudioService.NAME);
186-
if (googleAiStudioService.isPresent()) {
187-
var googleAiStudioTaskTypes = googleAiStudioService.get().supportedTaskTypes();
188-
for (TaskType taskType : googleAiStudioTaskTypes) {
189-
googleAiStudioConfigs.put(taskType, defaultStrategy);
190-
}
191-
}
192-
serviceNodeLocalRateLimitConfigs.put(GoogleAiStudioService.NAME, googleAiStudioConfigs);
193-
194-
// Google Vertex AI
195-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> googleVertexAiConfigs = new TreeMap<>();
196-
var googleVertexAiService = serviceRegistry.getService(GoogleVertexAiService.NAME);
197-
if (googleVertexAiService.isPresent()) {
198-
var googleVertexAiTaskTypes = googleVertexAiService.get().supportedTaskTypes();
199-
for (TaskType taskType : googleVertexAiTaskTypes) {
200-
googleVertexAiConfigs.put(taskType, defaultStrategy);
201-
}
202-
}
203-
serviceNodeLocalRateLimitConfigs.put(GoogleVertexAiService.NAME, googleVertexAiConfigs);
204-
205-
// HuggingFace
206-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> huggingFaceConfigs = new TreeMap<>();
207-
var huggingFaceService = serviceRegistry.getService(HuggingFaceService.NAME);
208-
if (huggingFaceService.isPresent()) {
209-
var huggingFaceTaskTypes = huggingFaceService.get().supportedTaskTypes();
210-
for (TaskType taskType : huggingFaceTaskTypes) {
211-
huggingFaceConfigs.put(taskType, defaultStrategy);
212-
}
213-
}
214-
serviceNodeLocalRateLimitConfigs.put(HuggingFaceService.NAME, huggingFaceConfigs);
215-
216-
// IBM Watson X
217-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> ibmWatsonxConfigs = new TreeMap<>();
218-
var ibmWatsonxService = serviceRegistry.getService(IbmWatsonxService.NAME);
219-
if (ibmWatsonxService.isPresent()) {
220-
var ibmWatsonxTaskTypes = ibmWatsonxService.get().supportedTaskTypes();
221-
for (TaskType taskType : ibmWatsonxTaskTypes) {
222-
ibmWatsonxConfigs.put(taskType, defaultStrategy);
223-
}
224-
}
225-
serviceNodeLocalRateLimitConfigs.put(IbmWatsonxService.NAME, ibmWatsonxConfigs);
226-
227-
// Jina AI
228-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> jinaAiConfigs = new TreeMap<>();
229-
var jinaAiService = serviceRegistry.getService(JinaAIService.NAME);
230-
if (jinaAiService.isPresent()) {
231-
var jinaAiTaskTypes = jinaAiService.get().supportedTaskTypes();
232-
for (TaskType taskType : jinaAiTaskTypes) {
233-
jinaAiConfigs.put(taskType, defaultStrategy);
234-
}
235-
}
236-
serviceNodeLocalRateLimitConfigs.put(JinaAIService.NAME, jinaAiConfigs);
237-
238-
// Mistral
239-
TreeMap<TaskType, MaxNodesPerGroupingStrategy> mistralConfigs = new TreeMap<>();
240-
var mistralService = serviceRegistry.getService(MistralService.NAME);
241-
if (mistralService.isPresent()) {
242-
var mistralTaskTypes = mistralService.get().supportedTaskTypes();
243-
for (TaskType taskType : mistralTaskTypes) {
244-
mistralConfigs.put(taskType, defaultStrategy);
245-
}
246-
}
247-
serviceNodeLocalRateLimitConfigs.put(MistralService.NAME, mistralConfigs);
248114

249115
return Collections.unmodifiableSortedMap(serviceNodeLocalRateLimitConfigs);
250116
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.test.InternalTestCluster;
1414
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1515
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
16+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
1617
import org.elasticsearch.xpack.inference.services.SenderService;
1718
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
1819

@@ -88,13 +89,12 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
8889

8990
// Check assignments for each supported service
9091
for (var service : supportedServices) {
91-
var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING);
92-
93-
assertNotNull(assignment);
94-
// Should have exactly one responsible node
95-
assertEquals(1, assignment.responsibleNodes().size());
96-
// That node should be our remaining node
97-
assertEquals(nodeLeftInCluster, assignment.responsibleNodes().get(0).getName());
92+
for (var taskType : calculator.serviceNodeLocalRateLimitConfigs().get(service).keySet()) {
93+
var assignment = calculator.getRateLimitAssignment(service, taskType);
94+
assertNotNull(assignment);
95+
assertThat(1, equalTo(assignment.responsibleNodes().size()));
96+
assertEquals(nodeLeftInCluster, assignment.responsibleNodes().get(0).getName());
97+
}
9898
}
9999
}
100100

@@ -110,10 +110,12 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
110110
Set<String> supportedServices = calculator.serviceNodeLocalRateLimitConfigs().keySet();
111111

112112
for (var service : supportedServices) {
113-
var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING);
113+
for (var taskType : calculator.serviceNodeLocalRateLimitConfigs().get(service).keySet()) {
114+
var assignment = calculator.getRateLimitAssignment(service, taskType);
114115

115-
assertNotNull(assignment);
116-
assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
116+
assertNotNull(assignment);
117+
assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
118+
}
117119
}
118120
}
119121

@@ -133,14 +135,14 @@ public void testInitialRateLimitsCalculation_Correct() throws Exception {
133135
var serviceOptional = serviceRegistry.getService(serviceName);
134136
assertTrue(serviceOptional.isPresent());
135137
var service = serviceOptional.get();
136-
137138
if ((service instanceof SenderService senderService)) {
138139
var sender = senderService.getSender();
139-
if (sender instanceof HttpRequestSender) {
140-
var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);
141-
142-
assertNotNull(assignment);
143-
assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
140+
for (var taskType : calculator.serviceNodeLocalRateLimitConfigs().get(serviceName).keySet()) {
141+
if (sender instanceof HttpRequestSender) {
142+
var assignment = calculator.getRateLimitAssignment(service.name(), taskType);
143+
assertNotNull(assignment);
144+
assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
145+
}
144146
}
145147
}
146148
}
@@ -159,25 +161,30 @@ public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
159161

160162
var serviceNodeLocalRateLimitConfigs = calculator.serviceNodeLocalRateLimitConfigs();
161163

164+
// check initial node assignments
162165
for (var serviceName : serviceNodeLocalRateLimitConfigs.keySet()) {
163166
var configs = serviceNodeLocalRateLimitConfigs.get(serviceName);
164167
for (var taskType : configs.keySet()) {
165168
// Get initial assignments and rate limits
166169
var initialAssignment = calculator.getRateLimitAssignment(serviceName, taskType);
167170
assertEquals(2, initialAssignment.responsibleNodes().size());
171+
}
172+
}
168173

169-
// Add a new node
170-
internalCluster().startNode();
171-
ensureStableCluster(initialNodes + 1);
172-
waitForRateLimitingAssignments(calculator);
174+
// Add a node to update node assignments
175+
internalCluster().startNode();
176+
ensureStableCluster(initialNodes + 1);
177+
waitForRateLimitingAssignments(calculator);
173178

174-
// Get updated assignments
179+
// check updated node assignments
180+
for (var serviceName : serviceNodeLocalRateLimitConfigs.keySet()) {
181+
var configs = serviceNodeLocalRateLimitConfigs.get(serviceName);
182+
for (var taskType : configs.keySet()) {
175183
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, taskType);
176-
177-
// Verify number of responsible nodes increased
178184
assertEquals(3, updatedAssignment.responsibleNodes().size());
179185
}
180186
}
187+
181188
}
182189

183190
public void testRateLimits_Increase_OnNodeLeave() throws Exception {
@@ -191,23 +198,26 @@ public void testRateLimits_Increase_OnNodeLeave() throws Exception {
191198

192199
var serviceNodeLocalRateLimitConfigs = calculator.serviceNodeLocalRateLimitConfigs();
193200

201+
// check initial node assignments
194202
for (var serviceName : serviceNodeLocalRateLimitConfigs.keySet()) {
195203
var configs = serviceNodeLocalRateLimitConfigs.get(serviceName);
196204
for (var taskType : configs.keySet()) {
197-
// Get initial assignments and rate limits
198205
var initialAssignment = calculator.getRateLimitAssignment(serviceName, taskType);
199206
assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(initialAssignment.responsibleNodes().size()));
207+
}
208+
}
200209

201-
// Remove a node
202-
var nodeToRemove = nodeNames.get(numNodes - 1);
203-
internalCluster().stopNode(nodeToRemove);
204-
ensureStableCluster(numNodes - 1);
205-
waitForRateLimitingAssignments(calculator);
210+
// remove a node to update node assignments
211+
var nodeToRemove = nodeNames.get(numNodes - 1);
212+
internalCluster().stopNode(nodeToRemove);
213+
ensureStableCluster(numNodes - 1);
214+
waitForRateLimitingAssignments(calculator);
206215

207-
// Get updated assignments
216+
// check updated node assignments
217+
for (var serviceName : serviceNodeLocalRateLimitConfigs.keySet()) {
218+
var configs = serviceNodeLocalRateLimitConfigs.get(serviceName);
219+
for (var taskType : configs.keySet()) {
208220
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, taskType);
209-
210-
// Verify number of responsible nodes decreased
211221
assertThat(2, equalTo(updatedAssignment.responsibleNodes().size()));
212222
}
213223
}
@@ -241,7 +251,8 @@ private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(Inter
241251

242252
private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception {
243253
assertBusy(() -> {
244-
var assignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
254+
var assignment = calculator
255+
.getRateLimitAssignment(TestSparseInferenceServiceExtension.TestInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
245256
assertNotNull(assignment);
246257
assertFalse(assignment.responsibleNodes().isEmpty());
247258
}, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS, TimeUnit.SECONDS);

0 commit comments

Comments
 (0)