1010import org .elasticsearch .inference .TaskType ;
1111import org .elasticsearch .plugins .Plugin ;
1212import org .elasticsearch .test .ESIntegTestCase ;
13+ import org .elasticsearch .test .InternalTestCluster ;
1314import org .elasticsearch .xpack .inference .LocalStateInferencePlugin ;
1415import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
1516import org .elasticsearch .xpack .inference .services .SenderService ;
2324import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .DEFAULT_MAX_NODES_PER_GROUPING ;
2425import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS ;
2526import static org .hamcrest .Matchers .equalTo ;
27+ import static org .hamcrest .Matchers .instanceOf ;
2628
2729@ ESIntegTestCase .ClusterScope (scope = ESIntegTestCase .Scope .SUITE , numDataNodes = 0 )
2830public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
@@ -40,7 +42,7 @@ public void testInitialClusterGrouping_Correct() {
4042 RateLimitAssignment firstAssignment = null ;
4143
4244 for (String nodeName : nodeNames ) {
43- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeName );
45+ var calculator = getCalculatorInstance ( internalCluster (), nodeName );
4446
4547 // Check first node's assignments
4648 if (firstAssignment == null ) {
@@ -77,7 +79,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7779 ensureStableCluster (currentNumberOfNodes );
7880 }
7981
80- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeLeftInCluster );
82+ var calculator = getCalculatorInstance ( internalCluster (), nodeLeftInCluster );
8183
8284 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
8385
@@ -99,7 +101,7 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
99101 var nodeNames = internalCluster ().startNodes (numNodes );
100102 ensureStableCluster (numNodes );
101103
102- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
104+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
103105
104106 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
105107
@@ -117,7 +119,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
117119 var nodeNames = internalCluster ().startNodes (numNodes );
118120 ensureStableCluster (numNodes );
119121
120- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
122+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
121123
122124 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
123125
@@ -129,7 +131,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129131
130132 if ((service instanceof SenderService senderService )) {
131133 var sender = senderService .getSender ();
132- if (sender instanceof HttpRequestSender httpSender ) {
134+ if (sender instanceof HttpRequestSender ) {
133135 var assignment = calculator .getRateLimitAssignment (service .name (), TaskType .SPARSE_EMBEDDING );
134136
135137 assertNotNull (assignment );
@@ -147,7 +149,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
147149 var nodeNames = internalCluster ().startNodes (initialNodes );
148150 ensureStableCluster (initialNodes );
149151
150- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
152+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
151153
152154 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
153155 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -175,7 +177,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
175177 var nodeNames = internalCluster ().startNodes (numNodes );
176178 ensureStableCluster (numNodes );
177179
178- var calculator = internalCluster (). getInstance ( InferenceServiceNodeLocalRateLimitCalculator . class , nodeNames .getFirst ());
180+ var calculator = getCalculatorInstance ( internalCluster (), nodeNames .getFirst ());
179181
180182 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
181183 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -202,4 +204,25 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
202204 protected Collection <Class <? extends Plugin >> nodePlugins () {
203205 return Arrays .asList (LocalStateInferencePlugin .class );
204206 }
207+
208+ private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance (InternalTestCluster internalTestCluster , String nodeName ) {
209+ InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster .getInstance (
210+ InferenceServiceRateLimitCalculator .class ,
211+ nodeName
212+ );
213+ assertThat (
214+ "["
215+ + InferenceServiceNodeLocalRateLimitCalculatorTests .class .getName ()
216+ + "] should use ["
217+ + InferenceServiceNodeLocalRateLimitCalculator .class .getName ()
218+ + "] as implementation for ["
219+ + InferenceServiceRateLimitCalculator .class .getName ()
220+ + "]. Provided implementation was ["
221+ + calculatorInstance .getClass ().getName ()
222+ + "]." ,
223+ calculatorInstance ,
224+ instanceOf (InferenceServiceNodeLocalRateLimitCalculator .class )
225+ );
226+ return (InferenceServiceNodeLocalRateLimitCalculator ) calculatorInstance ;
227+ }
205228}
0 commit comments