1010import org .elasticsearch .inference .TaskType ;
1111import org .elasticsearch .plugins .Plugin ;
1212import org .elasticsearch .test .ESIntegTestCase ;
13+ import org .elasticsearch .test .InternalTestCluster ;
1314import org .elasticsearch .xpack .inference .LocalStateInferencePlugin ;
1415import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
1516import org .elasticsearch .xpack .inference .services .SenderService ;
2324import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .DEFAULT_MAX_NODES_PER_GROUPING ;
2425import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS ;
2526import static org .hamcrest .Matchers .equalTo ;
27+ import static org .hamcrest .Matchers .instanceOf ;
2628
2729@ ESIntegTestCase .ClusterScope (scope = ESIntegTestCase .Scope .SUITE , numDataNodes = 0 )
2830public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
@@ -39,7 +41,7 @@ public void testInitialClusterGrouping_Correct() throws Exception {
3941 var nodeNames = internalCluster ().startNodes (numNodes );
4042 ensureStableCluster (numNodes );
4143
42- var firstCalculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
44+ var firstCalculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
4345 waitForRateLimitingAssignments (firstCalculator );
4446
4547 RateLimitAssignment firstAssignment = firstCalculator .getRateLimitAssignment (
@@ -49,7 +51,7 @@ public void testInitialClusterGrouping_Correct() throws Exception {
4951
5052 // Verify that all other nodes land on the same assignment
5153 for (String nodeName : nodeNames .subList (1 , nodeNames .size ())) {
52- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeName );
54+ var calculator = getCalculatorInstance ( internalCluster (), nodeName );
5355 waitForRateLimitingAssignments (calculator );
5456 var currentAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
5557 assertEquals (firstAssignment , currentAssignment );
@@ -75,7 +77,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7577 ensureStableCluster (currentNumberOfNodes );
7678 }
7779
78- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeLeftInCluster );
80+ var calculator = getCalculatorInstance ( internalCluster (), nodeLeftInCluster );
7981 waitForRateLimitingAssignments (calculator );
8082
8183 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
@@ -98,7 +100,7 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
98100 var nodeNames = internalCluster ().startNodes (numNodes );
99101 ensureStableCluster (numNodes );
100102
101- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
103+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
102104 waitForRateLimitingAssignments (calculator );
103105
104106 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
@@ -117,7 +119,7 @@ public void testInitialRateLimitsCalculation_Correct() throws Exception {
117119 var nodeNames = internalCluster ().startNodes (numNodes );
118120 ensureStableCluster (numNodes );
119121
120- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
122+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
121123 waitForRateLimitingAssignments (calculator );
122124
123125 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
@@ -148,7 +150,7 @@ public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
148150 var nodeNames = internalCluster ().startNodes (initialNodes );
149151 ensureStableCluster (initialNodes );
150152
151- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
153+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
152154 waitForRateLimitingAssignments (calculator );
153155
154156 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
@@ -178,7 +180,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws Exception {
178180 var nodeNames = internalCluster ().startNodes (numNodes );
179181 ensureStableCluster (numNodes );
180182
181- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
183+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
182184 waitForRateLimitingAssignments (calculator );
183185
184186 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
@@ -208,6 +210,27 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
208210 return Arrays .asList (LocalStateInferencePlugin .class );
209211 }
210212
213+ private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance (InternalTestCluster internalTestCluster , String nodeName ) {
214+ InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster .getInstance (
215+ InferenceServiceRateLimitCalculator .class ,
216+ nodeName
217+ );
218+ assertThat (
219+ "["
220+ + InferenceServiceNodeLocalRateLimitCalculatorTests .class .getName ()
221+ + "] should use ["
222+ + InferenceServiceNodeLocalRateLimitCalculator .class .getName ()
223+ + "] as implementation for ["
224+ + InferenceServiceRateLimitCalculator .class .getName ()
225+ + "]. Provided implementation was ["
226+ + calculatorInstance .getClass ().getName ()
227+ + "]." ,
228+ calculatorInstance ,
229+ instanceOf (InferenceServiceNodeLocalRateLimitCalculator .class )
230+ );
231+ return (InferenceServiceNodeLocalRateLimitCalculator ) calculatorInstance ;
232+ }
233+
211234 private void waitForRateLimitingAssignments (InferenceServiceNodeLocalRateLimitCalculator calculator ) throws Exception {
212235 assertBusy (() -> {
213236 var assignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
0 commit comments