Skip to content

Commit 28e31da

Browse files
committed
Request interface in InferenceServiceNodeLocalRateLimitCalculatorTests and cast to concrete instance and make sure it's the correct one for the test.
1 parent 2f3053d commit 28e31da

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,6 @@ public Collection<?> createComponents(PluginServices services) {
331331

332332
// Add binding for interface -> implementation
333333
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
334-
components.add(calculator);
335334

336335
return components;
337336
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.inference.TaskType;
1111
import org.elasticsearch.plugins.Plugin;
1212
import org.elasticsearch.test.ESIntegTestCase;
13+
import org.elasticsearch.test.InternalTestCluster;
1314
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1415
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
1516
import org.elasticsearch.xpack.inference.services.SenderService;
@@ -23,6 +24,7 @@
2324
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
2425
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
2526
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.instanceOf;
2628

2729
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
2830
public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
@@ -40,7 +42,7 @@ public void testInitialClusterGrouping_Correct() {
4042
RateLimitAssignment firstAssignment = null;
4143

4244
for (String nodeName : nodeNames) {
43-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
45+
var calculator = getCalculatorInstance(internalCluster(), nodeName);
4446

4547
// Check first node's assignments
4648
if (firstAssignment == null) {
@@ -77,7 +79,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7779
ensureStableCluster(currentNumberOfNodes);
7880
}
7981

80-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
82+
var calculator = getCalculatorInstance(internalCluster(), nodeLeftInCluster);
8183

8284
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
8385

@@ -99,7 +101,7 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
99101
var nodeNames = internalCluster().startNodes(numNodes);
100102
ensureStableCluster(numNodes);
101103

102-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
104+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
103105

104106
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
105107

@@ -117,7 +119,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
117119
var nodeNames = internalCluster().startNodes(numNodes);
118120
ensureStableCluster(numNodes);
119121

120-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
122+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
121123

122124
Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
123125

@@ -129,7 +131,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129131

130132
if ((service instanceof SenderService senderService)) {
131133
var sender = senderService.getSender();
132-
if (sender instanceof HttpRequestSender httpSender) {
134+
if (sender instanceof HttpRequestSender) {
133135
var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);
134136

135137
assertNotNull(assignment);
@@ -147,7 +149,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
147149
var nodeNames = internalCluster().startNodes(initialNodes);
148150
ensureStableCluster(initialNodes);
149151

150-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
152+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
151153

152154
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
153155
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -175,7 +177,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
175177
var nodeNames = internalCluster().startNodes(numNodes);
176178
ensureStableCluster(numNodes);
177179

178-
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.getFirst());
180+
var calculator = getCalculatorInstance(internalCluster(), nodeNames.getFirst());
179181

180182
for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
181183
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
@@ -202,4 +204,25 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
202204
protected Collection<Class<? extends Plugin>> nodePlugins() {
203205
return Arrays.asList(LocalStateInferencePlugin.class);
204206
}
207+
208+
private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(InternalTestCluster internalTestCluster, String nodeName) {
209+
InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster.getInstance(
210+
InferenceServiceRateLimitCalculator.class,
211+
nodeName
212+
);
213+
assertThat(
214+
"["
215+
+ InferenceServiceNodeLocalRateLimitCalculatorTests.class.getName()
216+
+ "] should use ["
217+
+ InferenceServiceNodeLocalRateLimitCalculator.class.getName()
218+
+ "] as implementation for ["
219+
+ InferenceServiceRateLimitCalculator.class.getName()
220+
+ "]. Provided implementation was ["
221+
+ calculatorInstance.getClass().getName()
222+
+ "].",
223+
calculatorInstance,
224+
instanceOf(InferenceServiceNodeLocalRateLimitCalculator.class)
225+
);
226+
return (InferenceServiceNodeLocalRateLimitCalculator) calculatorInstance;
227+
}
205228
}

0 commit comments

Comments
 (0)