Skip to content

Commit 4642f15

Browse files
authored
[Inference API] Wait for assignments to happen in InferenceServiceNodeLocalRateLimitCalculatorTests. (#121379)
1 parent 9f572a3 commit 4642f15

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

muted-tests.yml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -360,8 +360,6 @@ tests:
360360
- class: org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT
361361
method: test {yaml=indices.get_alias/10_basic/Get aliases via /*/_alias/}
362362
issue: https://github.com/elastic/elasticsearch/issues/121290
363-
- class: org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculatorTests
364-
issue: https://github.com/elastic/elasticsearch/issues/121294
365363
- class: org.elasticsearch.env.NodeEnvironmentTests
366364
method: testGetBestDowngradeVersion
367365
issue: https://github.com/elastic/elasticsearch/issues/121316

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java

Lines changed: 37 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,10 @@
1515
import org.elasticsearch.xpack.inference.services.SenderService;
1616
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
1717

18-
import java.io.IOException;
1918
import java.util.Arrays;
2019
import java.util.Collection;
2120
import java.util.Set;
21+
import java.util.concurrent.TimeUnit;
2222

2323
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
2424
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
@@ -27,38 +27,36 @@
2727
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
2828
public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
2929

30+
private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15;
31+
3032
public void setUp() throws Exception {
3133
super.setUp();
3234
}
3335

34-
public void testInitialClusterGrouping_Correct() {
36+
public void testInitialClusterGrouping_Correct() throws Exception {
3537
// Start with 2-5 nodes
3638
var numNodes = randomIntBetween(2, 5);
3739
var nodeNames = internalCluster().startNodes(numNodes);
3840
ensureStableCluster(numNodes);
3941

40-
RateLimitAssignment firstAssignment = null;
42+
var firstCalculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
43+
waitForRateLimitingAssignments(firstCalculator);
4144

42-
for (String nodeName : nodeNames) {
43-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
45+
RateLimitAssignment firstAssignment = firstCalculator.getRateLimitAssignment(
46+
ElasticInferenceService.NAME,
47+
TaskType.SPARSE_EMBEDDING
48+
);
4449

45-
// Check first node's assignments
46-
if (firstAssignment == null) {
47-
// Get assignment for a specific service (e.g., EIS)
48-
firstAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
49-
50-
assertNotNull(firstAssignment);
51-
// Verify there are assignments for this service
52-
assertFalse(firstAssignment.responsibleNodes().isEmpty());
53-
} else {
54-
// Verify other nodes see the same assignment
55-
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
56-
assertEquals(firstAssignment, currentAssignment);
57-
}
50+
// Verify that all other nodes land on the same assignment
51+
for (String nodeName : nodeNames.subList(1, nodeNames.size())) {
52+
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
53+
waitForRateLimitingAssignments(calculator);
54+
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
55+
assertEquals(firstAssignment, currentAssignment);
5856
}
5957
}
6058

61-
public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws IOException {
59+
public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws Exception {
6260
// Start with 3-5 nodes
6361
var numNodes = randomIntBetween(3, 5);
6462
var nodeNames = internalCluster().startNodes(numNodes);
@@ -78,6 +76,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7876
}
7977

8078
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
79+
waitForRateLimitingAssignments(calculator);
8180

8281
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
8382

@@ -93,13 +92,14 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
9392
}
9493
}
9594

96-
public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
95+
public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
9796
// Start with more nodes than the maximum allowed per grouping
9897
var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3);
9998
var nodeNames = internalCluster().startNodes(numNodes);
10099
ensureStableCluster(numNodes);
101100

102101
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
102+
waitForRateLimitingAssignments(calculator);
103103

104104
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
105105

@@ -111,13 +111,14 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
111111
}
112112
}
113113

114-
public void testInitialRateLimitsCalculation_Correct() throws IOException {
114+
public void testInitialRateLimitsCalculation_Correct() throws Exception {
115115
// Start with max nodes per grouping (=3)
116116
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
117117
var nodeNames = internalCluster().startNodes(numNodes);
118118
ensureStableCluster(numNodes);
119119

120120
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
121+
waitForRateLimitingAssignments(calculator);
121122

122123
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
123124

@@ -129,7 +130,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129130

130131
if ((service instanceof SenderService senderService)) {
131132
var sender = senderService.getSender();
132-
if (sender instanceof HttpRequestSender httpSender) {
133+
if (sender instanceof HttpRequestSender) {
133134
var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);
134135

135136
assertNotNull(assignment);
@@ -141,13 +142,14 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
141142
}
142143
}
143144

144-
public void testRateLimits_Decrease_OnNodeJoin() {
145+
public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
145146
// Start with 2 nodes
146147
var initialNodes = 2;
147148
var nodeNames = internalCluster().startNodes(initialNodes);
148149
ensureStableCluster(initialNodes);
149150

150151
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
152+
waitForRateLimitingAssignments(calculator);
151153

152154
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
153155
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -159,6 +161,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
159161
// Add a new node
160162
internalCluster().startNode();
161163
ensureStableCluster(initialNodes + 1);
164+
waitForRateLimitingAssignments(calculator);
162165

163166
// Get updated assignments
164167
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
@@ -169,13 +172,14 @@ public void testRateLimits_Decrease_OnNodeJoin() {
169172
}
170173
}
171174

172-
public void testRateLimits_Increase_OnNodeLeave() throws IOException {
175+
public void testRateLimits_Increase_OnNodeLeave() throws Exception {
173176
// Start with max nodes per grouping (=3)
174177
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
175178
var nodeNames = internalCluster().startNodes(numNodes);
176179
ensureStableCluster(numNodes);
177180

178181
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
182+
waitForRateLimitingAssignments(calculator);
179183

180184
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
181185
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -188,6 +192,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
188192
var nodeToRemove = nodeNames.get(numNodes - 1);
189193
internalCluster().stopNode(nodeToRemove);
190194
ensureStableCluster(numNodes - 1);
195+
waitForRateLimitingAssignments(calculator);
191196

192197
// Get updated assignments
193198
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
@@ -202,4 +207,12 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
202207
protected Collection<Class<? extends Plugin>> nodePlugins() {
203208
return Arrays.asList(LocalStateInferencePlugin.class);
204209
}
210+
211+
private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception {
212+
assertBusy(() -> {
213+
var assignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
214+
assertNotNull(assignment);
215+
assertFalse(assignment.responsibleNodes().isEmpty());
216+
}, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS, TimeUnit.SECONDS);
217+
}
205218
}

0 commit comments

Comments
 (0)