1919import org .elasticsearch .transport .TransportService ;
2020import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2121import org .elasticsearch .xpack .inference .action .task .StreamingTaskManager ;
22- import org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator ;
22+ import org .elasticsearch .xpack .inference .common .InferenceServiceRateLimitCalculator ;
2323import org .elasticsearch .xpack .inference .common .RateLimitAssignment ;
2424import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
2525import org .elasticsearch .xpack .inference .telemetry .InferenceStats ;
@@ -50,7 +50,7 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
5050 InferenceServiceRegistry serviceRegistry ,
5151 InferenceStats inferenceStats ,
5252 StreamingTaskManager streamingTaskManager ,
53- InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator ,
53+ InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator ,
5454 NodeClient nodeClient ,
5555 ThreadPool threadPool
5656 ) {
@@ -77,7 +77,7 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
7777 TaskType unsupportedTaskType = TaskType .COMPLETION ;
7878 mockService (listener -> listener .onResponse (mock ()));
7979
80- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , unsupportedTaskType )).thenReturn (false );
80+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , unsupportedTaskType )).thenReturn (false );
8181
8282 var listener = doExecute (unsupportedTaskType );
8383
@@ -89,8 +89,8 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
8989 public void testNoRerouting_WhenNoGroupingCalculatedYet () {
9090 mockService (listener -> listener .onResponse (mock ()));
9191
92- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
93- when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (null );
92+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
93+ when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (null );
9494
9595 var listener = doExecute (taskType );
9696
@@ -102,8 +102,8 @@ public void testNoRerouting_WhenNoGroupingCalculatedYet() {
102102 public void testNoRerouting_WhenEmptyNodeList () {
103103 mockService (listener -> listener .onResponse (mock ()));
104104
105- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
106- when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (
105+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
106+ when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (
107107 new RateLimitAssignment (List .of ())
108108 );
109109
@@ -120,10 +120,10 @@ public void testRerouting_ToOtherNode() {
120120
121121 // The local node is different to the "other-node" responsible for serviceId
122122 when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
123- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
123+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
124124 // Requests for serviceId are always routed to "other-node"
125125 var assignment = new RateLimitAssignment (List .of (otherNode ));
126- when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
126+ when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
127127
128128 mockService (listener -> listener .onResponse (mock ()));
129129 var listener = doExecute (taskType );
@@ -141,9 +141,9 @@ public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() {
141141
142142 // The local node is the only one responsible for serviceId
143143 when (nodeClient .getLocalNodeId ()).thenReturn (localNodeId );
144- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
144+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
145145 var assignment = new RateLimitAssignment (List .of (localNode ));
146- when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
146+ when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
147147
148148 mockService (listener -> listener .onResponse (mock ()));
149149 var listener = doExecute (taskType );
@@ -158,9 +158,9 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
158158 when (otherNode .getId ()).thenReturn ("other-node" );
159159
160160 when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
161- when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
161+ when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
162162 var assignment = new RateLimitAssignment (List .of (otherNode ));
163- when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
163+ when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
164164
165165 mockService (listener -> listener .onResponse (mock ()));
166166
@@ -173,6 +173,10 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
173173
174174 var listener = doExecute (taskType );
175175
176+ // Verify request was rerouted
177+ verify (transportService ).sendRequest (same (otherNode ), eq (InferenceAction .NAME ), any (), any ());
178+ // Verify local execution didn't happen
179+ verify (listener , never ()).onResponse (any ());
176180 // Verify exception was propagated from "other-node" to "local-node"
177181 verify (listener ).onFailure (same (expectedException ));
178182 }
0 commit comments