Skip to content

Commit 239d4a2

Browse files
authored
[8.18] [Inference API] Remove second calculator instance as component and update tests (#121284) (#121530)
1 parent 8550a53 commit 239d4a2

File tree

2 files changed

+66
-31
lines changed

2 files changed

+66
-31
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,6 @@ public Collection<?> createComponents(PluginServices services) {
331331

332332
// Add binding for interface -> implementation
333333
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
334-
components.add(calculator);
335334

336335
return components;
337336
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java

Lines changed: 66 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -10,55 +10,55 @@
1010
import org.elasticsearch.inference.TaskType;
1111
import org.elasticsearch.plugins.Plugin;
1212
import org.elasticsearch.test.ESIntegTestCase;
13+
import org.elasticsearch.test.InternalTestCluster;
1314
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1415
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
1516
import org.elasticsearch.xpack.inference.services.SenderService;
1617
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
1718

18-
import java.io.IOException;
1919
import java.util.Arrays;
2020
import java.util.Collection;
2121
import java.util.Set;
22+
import java.util.concurrent.TimeUnit;
2223

2324
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
2425
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
2526
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.instanceOf;
2628

2729
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
2830
public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
2931

32+
private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15;
33+
3034
// Per-test setup hook: it only delegates to the parent class's setUp().
// NOTE(review): this redefines an inherited setUp() (it calls super.setUp()) but carries no
// @Override annotation in the lines shown, and it adds no behaviour of its own.
public void setUp() throws Exception {
3135
super.setUp();
3236
}
3337

34-
public void testInitialClusterGrouping_Correct() {
38+
public void testInitialClusterGrouping_Correct() throws Exception {
3539
// Start with 2-5 nodes
3640
var numNodes = randomIntBetween(2, 5);
3741
var nodeNames = internalCluster().startNodes(numNodes);
3842
ensureStableCluster(numNodes);
3943

40-
RateLimitAssignment firstAssignment = null;
44+
var firstCalculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
45+
waitForRateLimitingAssignments(firstCalculator);
4146

42-
for (String nodeName : nodeNames) {
43-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
44-
45-
// Check first node's assignments
46-
if (firstAssignment == null) {
47-
// Get assignment for a specific service (e.g., EIS)
48-
firstAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
49-
50-
assertNotNull(firstAssignment);
51-
// Verify there are assignments for this service
52-
assertFalse(firstAssignment.responsibleNodes().isEmpty());
53-
} else {
54-
// Verify other nodes see the same assignment
55-
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
56-
assertEquals(firstAssignment, currentAssignment);
57-
}
47+
RateLimitAssignment firstAssignment = firstCalculator.getRateLimitAssignment(
48+
ElasticInferenceService.NAME,
49+
TaskType.SPARSE_EMBEDDING
50+
);
51+
52+
// Verify that all other nodes land on the same assignment
53+
for (String nodeName : nodeNames.subList(1, nodeNames.size())) {
54+
var calculator = getCalculatorInstance(internalCluster(), nodeName);
55+
waitForRateLimitingAssignments(calculator);
56+
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
57+
assertEquals(firstAssignment, currentAssignment);
5858
}
5959
}
6060

61-
public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws IOException {
61+
public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws Exception {
6262
// Start with 3-5 nodes
6363
var numNodes = randomIntBetween(3, 5);
6464
var nodeNames = internalCluster().startNodes(numNodes);
@@ -77,7 +77,8 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7777
ensureStableCluster(currentNumberOfNodes);
7878
}
7979

80-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
80+
var calculator = getCalculatorInstance(internalCluster(), nodeLeftInCluster);
81+
waitForRateLimitingAssignments(calculator);
8182

8283
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
8384

@@ -93,13 +94,14 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
9394
}
9495
}
9596

96-
public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
97+
public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
9798
// Start with more nodes than are allowed per grouping
9899
var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3);
99100
var nodeNames = internalCluster().startNodes(numNodes);
100101
ensureStableCluster(numNodes);
101102

102-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
103+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
104+
waitForRateLimitingAssignments(calculator);
103105

104106
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
105107

@@ -111,13 +113,14 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
111113
}
112114
}
113115

114-
public void testInitialRateLimitsCalculation_Correct() throws IOException {
116+
public void testInitialRateLimitsCalculation_Correct() throws Exception {
115117
// Start with max nodes per grouping (=3)
116118
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
117119
var nodeNames = internalCluster().startNodes(numNodes);
118120
ensureStableCluster(numNodes);
119121

120-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
122+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
123+
waitForRateLimitingAssignments(calculator);
121124

122125
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
123126

@@ -129,7 +132,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129132

130133
if ((service instanceof SenderService senderService)) {
131134
var sender = senderService.getSender();
132-
if (sender instanceof HttpRequestSender httpSender) {
135+
if (sender instanceof HttpRequestSender) {
133136
var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);
134137

135138
assertNotNull(assignment);
@@ -141,13 +144,14 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
141144
}
142145
}
143146

144-
public void testRateLimits_Decrease_OnNodeJoin() {
147+
public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
145148
// Start with 2 nodes
146149
var initialNodes = 2;
147150
var nodeNames = internalCluster().startNodes(initialNodes);
148151
ensureStableCluster(initialNodes);
149152

150-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
153+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
154+
waitForRateLimitingAssignments(calculator);
151155

152156
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
153157
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -159,6 +163,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
159163
// Add a new node
160164
internalCluster().startNode();
161165
ensureStableCluster(initialNodes + 1);
166+
waitForRateLimitingAssignments(calculator);
162167

163168
// Get updated assignments
164169
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
@@ -169,13 +174,14 @@ public void testRateLimits_Decrease_OnNodeJoin() {
169174
}
170175
}
171176

172-
public void testRateLimits_Increase_OnNodeLeave() throws IOException {
177+
public void testRateLimits_Increase_OnNodeLeave() throws Exception {
173178
// Start with max nodes per grouping (=3)
174179
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
175180
var nodeNames = internalCluster().startNodes(numNodes);
176181
ensureStableCluster(numNodes);
177182

178-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
183+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
184+
waitForRateLimitingAssignments(calculator);
179185

180186
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
181187
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -188,6 +194,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
188194
var nodeToRemove = nodeNames.get(numNodes - 1);
189195
internalCluster().stopNode(nodeToRemove);
190196
ensureStableCluster(numNodes - 1);
197+
waitForRateLimitingAssignments(calculator);
191198

192199
// Get updated assignments
193200
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
@@ -202,4 +209,33 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
202209
// Returns the plugin classes used by this suite: only LocalStateInferencePlugin.
// Presumably this is the ESIntegTestCase hook that decides which plugins every test node
// loads - confirm against the base class.
protected Collection<Class<? extends Plugin>> nodePlugins() {
203210
return Arrays.asList(LocalStateInferencePlugin.class);
204211
}
212+
213+
/**
 * Fetches the rate limit calculator of the given node and returns it as the node-local implementation.
 * <p>
 * The instance is looked up through the {@link InferenceServiceRateLimitCalculator} interface, because this
 * commit removes the separate registration of the concrete calculator in {@code InferencePlugin#createComponents}
 * and keeps only the interface -> implementation binding. The test fails with a descriptive message if the
 * resolved instance is not an {@link InferenceServiceNodeLocalRateLimitCalculator}.
 *
 * @param internalTestCluster the test cluster to resolve the instance from
 * @param nodeName            name of the node whose calculator should be returned
 * @return the calculator of {@code nodeName}, downcast to the node-local implementation
 */
private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(InternalTestCluster internalTestCluster, String nodeName) {
214+
InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster.getInstance(
215+
InferenceServiceRateLimitCalculator.class,
216+
nodeName
217+
);
218+
// Guard the downcast below: report which implementation was actually bound instead of a bare ClassCastException.
assertThat(
219+
"["
220+
+ InferenceServiceNodeLocalRateLimitCalculatorTests.class.getName()
221+
+ "] should use ["
222+
+ InferenceServiceNodeLocalRateLimitCalculator.class.getName()
223+
+ "] as implementation for ["
224+
+ InferenceServiceRateLimitCalculator.class.getName()
225+
+ "]. Provided implementation was ["
226+
+ calculatorInstance.getClass().getName()
227+
+ "].",
228+
calculatorInstance,
229+
instanceOf(InferenceServiceNodeLocalRateLimitCalculator.class)
230+
);
231+
return (InferenceServiceNodeLocalRateLimitCalculator) calculatorInstance;
232+
}
233+
234+
/**
 * Blocks until the given calculator reports a rate limit assignment, or fails the test on timeout.
 * <p>
 * The check is wrapped in {@code assertBusy} with a limit of {@code RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS}
 * seconds. It passes once the assignment for {@code ElasticInferenceService.NAME} /
 * {@code TaskType.SPARSE_EMBEDDING} is non-null and lists at least one responsible node.
 * <p>
 * NOTE(review): only the existence of an assignment is checked, and only for that one service/task type pair.
 * When this is called again after a node joined or left, an assignment already exists, so the call presumably
 * returns immediately and does not prove the assignment reflects the new topology - confirm.
 *
 * @param calculator the node-local calculator to poll
 * @throws Exception declared for the {@code assertBusy} call
 */
private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception {
235+
assertBusy(() -> {
236+
var assignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
237+
assertNotNull(assignment);
238+
assertFalse(assignment.responsibleNodes().isEmpty());
239+
}, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS, TimeUnit.SECONDS);
240+
}
205241
}

0 commit comments

Comments
 (0)