1919import  org .elasticsearch .transport .TransportService ;
2020import  org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2121import  org .elasticsearch .xpack .inference .action .task .StreamingTaskManager ;
22- import  org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator ;
22+ import  org .elasticsearch .xpack .inference .common .InferenceServiceRateLimitCalculator ;
2323import  org .elasticsearch .xpack .inference .common .RateLimitAssignment ;
2424import  org .elasticsearch .xpack .inference .registry .ModelRegistry ;
2525import  org .elasticsearch .xpack .inference .telemetry .InferenceStats ;
@@ -50,7 +50,7 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
5050        InferenceServiceRegistry  serviceRegistry ,
5151        InferenceStats  inferenceStats ,
5252        StreamingTaskManager  streamingTaskManager ,
53-         InferenceServiceNodeLocalRateLimitCalculator  inferenceServiceNodeLocalRateLimitCalculator ,
53+         InferenceServiceRateLimitCalculator  inferenceServiceNodeLocalRateLimitCalculator ,
5454        NodeClient  nodeClient ,
5555        ThreadPool  threadPool 
5656    ) {
@@ -77,7 +77,7 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
7777        TaskType  unsupportedTaskType  = TaskType .COMPLETION ;
7878        mockService (listener  -> listener .onResponse (mock ()));
7979
80-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , unsupportedTaskType )).thenReturn (false );
80+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , unsupportedTaskType )).thenReturn (false );
8181
8282        var  listener  = doExecute (unsupportedTaskType );
8383
@@ -89,8 +89,8 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
8989    public  void  testNoRerouting_WhenNoGroupingCalculatedYet () {
9090        mockService (listener  -> listener .onResponse (mock ()));
9191
92-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
93-         when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (null );
92+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
93+         when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (null );
9494
9595        var  listener  = doExecute (taskType );
9696
@@ -102,8 +102,8 @@ public void testNoRerouting_WhenNoGroupingCalculatedYet() {
102102    public  void  testNoRerouting_WhenEmptyNodeList () {
103103        mockService (listener  -> listener .onResponse (mock ()));
104104
105-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
106-         when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (
105+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
106+         when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (
107107            new  RateLimitAssignment (List .of ())
108108        );
109109
@@ -120,10 +120,10 @@ public void testRerouting_ToOtherNode() {
120120
121121        // The local node is different to the "other-node" responsible for serviceId 
122122        when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
123-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
123+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
124124        // Requests for serviceId are always routed to "other-node" 
125125        var  assignment  = new  RateLimitAssignment (List .of (otherNode ));
126-         when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
126+         when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
127127
128128        mockService (listener  -> listener .onResponse (mock ()));
129129        var  listener  = doExecute (taskType );
@@ -141,9 +141,9 @@ public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() {
141141
142142        // The local node is the only one responsible for serviceId 
143143        when (nodeClient .getLocalNodeId ()).thenReturn (localNodeId );
144-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
144+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
145145        var  assignment  = new  RateLimitAssignment (List .of (localNode ));
146-         when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
146+         when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
147147
148148        mockService (listener  -> listener .onResponse (mock ()));
149149        var  listener  = doExecute (taskType );
@@ -158,9 +158,9 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
158158        when (otherNode .getId ()).thenReturn ("other-node" );
159159
160160        when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
161-         when (inferenceServiceNodeLocalRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
161+         when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
162162        var  assignment  = new  RateLimitAssignment (List .of (otherNode ));
163-         when (inferenceServiceNodeLocalRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
163+         when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
164164
165165        mockService (listener  -> listener .onResponse (mock ()));
166166
@@ -173,6 +173,10 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
173173
174174        var  listener  = doExecute (taskType );
175175
176+         // Verify request was rerouted 
177+         verify (transportService ).sendRequest (same (otherNode ), eq (InferenceAction .NAME ), any (), any ());
178+         // Verify local execution didn't happen 
179+         verify (listener , never ()).onResponse (any ());
176180        // Verify exception was propagated from "other-node" to "local-node" 
177181        verify (listener ).onFailure (same (expectedException ));
178182    }
0 commit comments