1515import org .elasticsearch .xpack .inference .services .SenderService ;
1616import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceService ;
1717
18- import java .io .IOException ;
1918import java .util .Arrays ;
2019import java .util .Collection ;
2120import java .util .Set ;
21+ import java .util .concurrent .TimeUnit ;
2222
2323import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .DEFAULT_MAX_NODES_PER_GROUPING ;
2424import static org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator .SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS ;
2727@ ESIntegTestCase .ClusterScope (scope = ESIntegTestCase .Scope .SUITE , numDataNodes = 0 )
2828public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
2929
// Upper bound, in seconds, that waitForRateLimitingAssignments passes to assertBusy
// while polling a calculator for a non-empty rate limit assignment.
30+ private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15 ;
31+
// Per-test setup hook. It adds no behavior of its own and only delegates to the
// superclass implementation.
3032 public void setUp () throws Exception {
3133 super .setUp ();
3234 }
3335
// Verifies that every node started by the test reports the same rate limit assignment
// for the Elastic Inference Service / SPARSE_EMBEDDING pair.
//
// NOTE(review): this span is a diff rendering. Lines marked '-' appear to be the previous
// version (a single loop that read the assignment immediately on each node) and lines
// marked '+' the replacement, which first polls each node's calculator via
// waitForRateLimitingAssignments before reading and comparing assignments.
34- public void testInitialClusterGrouping_Correct () {
36+ public void testInitialClusterGrouping_Correct () throws Exception {
3537 // Start with 2-5 nodes
3638 var numNodes = randomIntBetween (2 , 5 );
3739 var nodeNames = internalCluster ().startNodes (numNodes );
3840 ensureStableCluster (numNodes );
3941
40- RateLimitAssignment firstAssignment = null ;
// New version: take the first node's calculator as the reference and block until it
// exposes a non-empty assignment.
42+ var firstCalculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .getFirst ());
43+ waitForRateLimitingAssignments (firstCalculator );
4144
42- for (String nodeName : nodeNames ) {
43- var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeName );
45+ RateLimitAssignment firstAssignment = firstCalculator .getRateLimitAssignment (
46+ ElasticInferenceService .NAME ,
47+ TaskType .SPARSE_EMBEDDING
48+ );
4449
45- // Check first node's assignments
46- if (firstAssignment == null ) {
47- // Get assignment for a specific service (e.g., EIS)
48- firstAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
49-
50- assertNotNull (firstAssignment );
51- // Verify there are assignments for this service
52- assertFalse (firstAssignment .responsibleNodes ().isEmpty ());
53- } else {
54- // Verify other nodes see the same assignment
55- var currentAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
56- assertEquals (firstAssignment , currentAssignment );
57- }
50+ // Verify that all other nodes land on the same assignment
// subList(1, size) skips the first node, which already served as the reference above.
51+ for (String nodeName : nodeNames .subList (1 , nodeNames .size ())) {
52+ var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeName );
53+ waitForRateLimitingAssignments (calculator );
54+ var currentAssignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
55+ assertEquals (firstAssignment , currentAssignment );
5856 }
5957 }
6058
61- public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster () throws IOException {
59+ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster () throws Exception {
6260 // Start with 3-5 nodes
6361 var numNodes = randomIntBetween (3 , 5 );
6462 var nodeNames = internalCluster ().startNodes (numNodes );
@@ -78,6 +76,7 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
7876 }
7977
8078 var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeLeftInCluster );
79+ waitForRateLimitingAssignments (calculator );
8180
8281 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
8382
@@ -93,13 +92,14 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
9392 }
9493 }
9594
96- public void testGrouping_RespectsMaxNodesPerGroupingLimit () {
95+ public void testGrouping_RespectsMaxNodesPerGroupingLimit () throws Exception {
9796 // Start with more nodes possible per grouping
9897 var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween (1 , 3 );
9998 var nodeNames = internalCluster ().startNodes (numNodes );
10099 ensureStableCluster (numNodes );
101100
102101 var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .getFirst ());
102+ waitForRateLimitingAssignments (calculator );
103103
104104 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
105105
@@ -111,13 +111,14 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
111111 }
112112 }
113113
114- public void testInitialRateLimitsCalculation_Correct () throws IOException {
114+ public void testInitialRateLimitsCalculation_Correct () throws Exception {
115115 // Start with max nodes per grouping (=3)
116116 int numNodes = DEFAULT_MAX_NODES_PER_GROUPING ;
117117 var nodeNames = internalCluster ().startNodes (numNodes );
118118 ensureStableCluster (numNodes );
119119
120120 var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .getFirst ());
121+ waitForRateLimitingAssignments (calculator );
121122
122123 Set <String > supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ();
123124
@@ -129,7 +130,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
129130
130131 if ((service instanceof SenderService senderService )) {
131132 var sender = senderService .getSender ();
132- if (sender instanceof HttpRequestSender httpSender ) {
133+ if (sender instanceof HttpRequestSender ) {
133134 var assignment = calculator .getRateLimitAssignment (service .name (), TaskType .SPARSE_EMBEDDING );
134135
135136 assertNotNull (assignment );
@@ -141,13 +142,14 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
141142 }
142143 }
143144
144- public void testRateLimits_Decrease_OnNodeJoin () {
145+ public void testRateLimits_Decrease_OnNodeJoin () throws Exception {
145146 // Start with 2 nodes
146147 var initialNodes = 2 ;
147148 var nodeNames = internalCluster ().startNodes (initialNodes );
148149 ensureStableCluster (initialNodes );
149150
150151 var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .getFirst ());
152+ waitForRateLimitingAssignments (calculator );
151153
152154 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
153155 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -159,6 +161,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
159161 // Add a new node
160162 internalCluster ().startNode ();
161163 ensureStableCluster (initialNodes + 1 );
164+ waitForRateLimitingAssignments (calculator );
162165
163166 // Get updated assignments
164167 var updatedAssignment = calculator .getRateLimitAssignment (serviceName , config .taskType ());
@@ -169,13 +172,14 @@ public void testRateLimits_Decrease_OnNodeJoin() {
169172 }
170173 }
171174
172- public void testRateLimits_Increase_OnNodeLeave () throws IOException {
175+ public void testRateLimits_Increase_OnNodeLeave () throws Exception {
173176 // Start with max nodes per grouping (=3)
174177 int numNodes = DEFAULT_MAX_NODES_PER_GROUPING ;
175178 var nodeNames = internalCluster ().startNodes (numNodes );
176179 ensureStableCluster (numNodes );
177180
178181 var calculator = internalCluster ().getInstance (InferenceServiceNodeLocalRateLimitCalculator .class , nodeNames .getFirst ());
182+ waitForRateLimitingAssignments (calculator );
179183
180184 for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .keySet ()) {
181185 var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS .get (serviceName );
@@ -188,6 +192,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
188192 var nodeToRemove = nodeNames .get (numNodes - 1 );
189193 internalCluster ().stopNode (nodeToRemove );
190194 ensureStableCluster (numNodes - 1 );
195+ waitForRateLimitingAssignments (calculator );
191196
192197 // Get updated assignments
193198 var updatedAssignment = calculator .getRateLimitAssignment (serviceName , config .taskType ());
@@ -202,4 +207,12 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
// Returns the plugin classes for this test: a single-element list containing
// LocalStateInferencePlugin. Presumably the test framework consults this when it
// starts cluster nodes -- confirm against the base class.
202207 protected Collection <Class <? extends Plugin >> nodePlugins () {
203208 return Arrays .asList (LocalStateInferencePlugin .class );
204209 }
210+
// Polls the given calculator via assertBusy until it returns a non-null assignment with
// at least one responsible node for ElasticInferenceService.NAME / SPARSE_EMBEDDING, or
// until RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS seconds have elapsed.
//
// NOTE(review): the condition only checks that an assignment is present. When this is
// called again after a node joins or leaves, an earlier assignment may already satisfy
// it, so the call may return before any recalculation has happened -- confirm whether
// the callers need a stronger condition.
211+ private void waitForRateLimitingAssignments (InferenceServiceNodeLocalRateLimitCalculator calculator ) throws Exception {
212+ assertBusy (() -> {
213+ var assignment = calculator .getRateLimitAssignment (ElasticInferenceService .NAME , TaskType .SPARSE_EMBEDDING );
214+ assertNotNull (assignment );
215+ assertFalse (assignment .responsibleNodes ().isEmpty ());
216+ }, RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS , TimeUnit .SECONDS );
217+ }
205218}
0 commit comments