1010import org .elasticsearch .inference .TaskType ;
1111import org .elasticsearch .plugins .Plugin ;
1212import org .elasticsearch .test .ESIntegTestCase ;
13+ import org .elasticsearch .test .InternalTestCluster ;
1314import org .elasticsearch .xpack .inference .LocalStateInferencePlugin ;
1415import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
1516import org .elasticsearch .xpack .inference .services .SenderService ;
1617import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceService ;
1718
18- import java .io .IOException ;
1919import java .util .Arrays ;
2020import java .util .Collection ;
2121import java .util .Set ;
22+ import java .util .concurrent .TimeUnit ;
2223
2324import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .DEFAULT_MAX_NODES_PER_GROUPING ;
2425import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS ;
2526import static org .hamcrest .Matchers .equalTo ;
27+ import static org .hamcrest .Matchers .instanceOf ;
2628
2729@ ESIntegTestCase .ClusterScope (scope = ESIntegTestCase .Scope .SUITE , numDataNodes = 0 )
2830public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
2931
32+ private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15 ;
33+
3034 public void setUp () throws Exception {
3135 super .setUp ();
3236 }
3337
34- public void testInitialClusterGrouping_Correct () {
38+ public void testInitialClusterGrouping_Correct () throws Exception {
3539 // Start with 2-5 nodes
3640 var numNodes = randomIntBetween (2 , 5 );
3741 var nodeNames = internalCluster ().startNodes (numNodes );
3842 ensureStableCluster (numNodes );
3943
40- RateLimitAssignment firstAssignment = null ;
44+ var firstCalculator = getCalculatorInstance (internalCluster (), nodeNames .get (0 ));
45+ waitForRateLimitingAssignments (firstCalculator );
4146
42- for (String nodeName : nodeNames ) {
43- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeName );
44-
45- // Check first node's assignments
46- if (firstAssignment == null ) {
47- // Get assignment for a specific service (e.g., EIS)
48- firstAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
49-
50- assertNotNull (firstAssignment );
51- // Verify there are assignments for this service
52- assertFalse (firstAssignment .responsibleNodes ().isEmpty ());
53- } else {
54- // Verify other nodes see the same assignment
55- var currentAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
56- assertEquals (firstAssignment , currentAssignment );
57- }
47+ RateLimitAssignment firstAssignment = firstCalculator .getRateLimitAssignment (
48+ ElasticInferenceService .NAME ,
49+ TaskType .SPARSE_EMBEDDING
50+ );
51+
52+ // Verify that all other nodes land on the same assignment
53+ for (String nodeName : nodeNames .subList (1 , nodeNames .size ())) {
54+ var calculator = getCalculatorInstance (internalCluster (), nodeName );
55+ waitForRateLimitingAssignments (calculator );
56+ var currentAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
57+ assertEquals (firstAssignment , currentAssignment );
5858 }
5959 }
6060
61- public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster () throws IOException {
61+ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster () throws Exception {
6262 // Start with 3-5 nodes
6363 var numNodes = randomIntBetween (3 , 5 );
6464 var nodeNames = internalCluster ().startNodes (numNodes );
@@ -77,7 +77,8 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7777 ensureStableCluster (currentNumberOfNodes );
7878 }
7979
80- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeLeftInCluster );
80+ var calculator = getCalculatorInstance (internalCluster (), nodeLeftInCluster );
81+ waitForRateLimitingAssignments (calculator );
8182
8283 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
8384
@@ -93,13 +94,14 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
9394 }
9495 }
9596
96- public void testGrouping_RespectsMaxNodesPerGroupingLimit () {
97+ public void testGrouping_RespectsMaxNodesPerGroupingLimit () throws Exception {
9798 // Start with more nodes possible per grouping
9899 var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween (1 , 3 );
99100 var nodeNames = internalCluster ().startNodes (numNodes );
100101 ensureStableCluster (numNodes );
101102
102- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .get (0 ));
103+ var calculator = getCalculatorInstance (internalCluster (), nodeNames .get (0 ));
104+ waitForRateLimitingAssignments (calculator );
103105
104106 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
105107
@@ -111,13 +113,14 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
111113 }
112114 }
113115
114- public void testInitialRateLimitsCalculation_Correct () throws IOException {
116+ public void testInitialRateLimitsCalculation_Correct () throws Exception {
115117 // Start with max nodes per grouping (=3)
116118 int numNodes = DEFAULT_MAX_NODES_PER_GROUPING ;
117119 var nodeNames = internalCluster ().startNodes (numNodes );
118120 ensureStableCluster (numNodes );
119121
120- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .get (0 ));
122+ var calculator = getCalculatorInstance (internalCluster (), nodeNames .get (0 ));
123+ waitForRateLimitingAssignments (calculator );
121124
122125 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
123126
@@ -129,7 +132,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129132
130133 if ((service instanceof SenderService senderService )) {
131134 var sender = senderService .getSender ();
132- if (sender instanceof HttpRequestSender httpSender ) {
135+ if (sender instanceof HttpRequestSender ) {
133136 var assignment = calculator .getRateLimitAssignment (service .name (), TaskType .SPARSE_EMBEDDING );
134137
135138 assertNotNull (assignment );
@@ -141,13 +144,14 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
141144 }
142145 }
143146
144- public void testRateLimits_Decrease_OnNodeJoin () {
147+ public void testRateLimits_Decrease_OnNodeJoin () throws Exception {
145148 // Start with 2 nodes
146149 var initialNodes = 2 ;
147150 var nodeNames = internalCluster ().startNodes (initialNodes );
148151 ensureStableCluster (initialNodes );
149152
150- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .get (0 ));
153+ var calculator = getCalculatorInstance (internalCluster (), nodeNames .get (0 ));
154+ waitForRateLimitingAssignments (calculator );
151155
152156 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
153157 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -159,6 +163,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
159163 // Add a new node
160164 internalCluster ().startNode ();
161165 ensureStableCluster (initialNodes + 1 );
166+ waitForRateLimitingAssignments (calculator );
162167
163168 // Get updated assignments
164169 var updatedAssignment = calculator .getRateLimitAssignment (serviceName , config .taskType ());
@@ -169,13 +174,14 @@ public void testRateLimits_Decrease_OnNodeJoin() {
169174 }
170175 }
171176
172- public void testRateLimits_Increase_OnNodeLeave () throws IOException {
177+ public void testRateLimits_Increase_OnNodeLeave () throws Exception {
173178 // Start with max nodes per grouping (=3)
174179 int numNodes = DEFAULT_MAX_NODES_PER_GROUPING ;
175180 var nodeNames = internalCluster ().startNodes (numNodes );
176181 ensureStableCluster (numNodes );
177182
178- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .get (0 ));
183+ var calculator = getCalculatorInstance (internalCluster (), nodeNames .get (0 ));
184+ waitForRateLimitingAssignments (calculator );
179185
180186 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
181187 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -188,6 +194,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
188194 var nodeToRemove = nodeNames .get (numNodes - 1 );
189195 internalCluster ().stopNode (nodeToRemove );
190196 ensureStableCluster (numNodes - 1 );
197+ waitForRateLimitingAssignments (calculator );
191198
192199 // Get updated assignments
193200 var updatedAssignment = calculator .getRateLimitAssignment (serviceName , config .taskType ());
@@ -202,4 +209,33 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
202209 protected Collection <Class <? extends Plugin >> nodePlugins () {
203210 return Arrays .asList (LocalStateInferencePlugin .class );
204211 }
212+
213+ private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance (InternalTestCluster internalTestCluster , String nodeName ) {
214+ InferenceServiceRateLimitCalculator calculatorInstance = internalTestCluster .getInstance (
215+ InferenceServiceRateLimitCalculator .class ,
216+ nodeName
217+ );
218+ assertThat (
219+ "["
220+ + InferenceServiceNodeLocalRateLimitCalculatorTests .class .getName ()
221+ + "] should use ["
222+ + InferenceServiceNodeLocalRateLimitCalculator .class .getName ()
223+ + "] as implementation for ["
224+ + InferenceServiceRateLimitCalculator .class .getName ()
225+ + "]. Provided implementation was ["
226+ + calculatorInstance .getClass ().getName ()
227+ + "]." ,
228+ calculatorInstance ,
229+ instanceOf (InferenceServiceNodeLocalRateLimitCalculator .class )
230+ );
231+ return (InferenceServiceNodeLocalRateLimitCalculator ) calculatorInstance ;
232+ }
233+
234+ private void waitForRateLimitingAssignments (InferenceServiceNodeLocalRateLimitCalculator calculator ) throws Exception {
235+ assertBusy (() -> {
236+ var assignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
237+ assertNotNull (assignment );
238+ assertFalse (assignment .responsibleNodes ().isEmpty ());
239+ }, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS , TimeUnit .SECONDS );
240+ }
205241}
0 commit comments