99
1010import org .elasticsearch .action .support .ActionFilters ;
1111import org .elasticsearch .client .internal .node .NodeClient ;
12- import org .elasticsearch .cluster .node .DiscoveryNode ;
1312import org .elasticsearch .inference .InferenceServiceRegistry ;
1413import org .elasticsearch .inference .TaskType ;
1514import org .elasticsearch .inference .telemetry .InferenceStats ;
1615import org .elasticsearch .license .MockLicenseState ;
1716import org .elasticsearch .threadpool .ThreadPool ;
18- import org .elasticsearch .transport .TransportException ;
19- import org .elasticsearch .transport .TransportResponseHandler ;
2017import org .elasticsearch .transport .TransportService ;
2118import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
2219import org .elasticsearch .xpack .inference .action .task .StreamingTaskManager ;
23- import org .elasticsearch .xpack .inference .common .InferenceServiceRateLimitCalculator ;
24- import org .elasticsearch .xpack .inference .common .RateLimitAssignment ;
2520import org .elasticsearch .xpack .inference .registry .InferenceEndpointRegistry ;
2621
27- import java .util .List ;
28-
29- import static org .hamcrest .Matchers .is ;
30- import static org .mockito .ArgumentMatchers .any ;
31- import static org .mockito .ArgumentMatchers .anyLong ;
32- import static org .mockito .ArgumentMatchers .assertArg ;
33- import static org .mockito .ArgumentMatchers .eq ;
34- import static org .mockito .ArgumentMatchers .same ;
35- import static org .mockito .Mockito .doAnswer ;
3622import static org .mockito .Mockito .mock ;
37- import static org .mockito .Mockito .never ;
38- import static org .mockito .Mockito .verify ;
39- import static org .mockito .Mockito .when ;
4023
4124public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase <InferenceAction .Request > {
4225
@@ -53,7 +36,6 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
5336 InferenceServiceRegistry serviceRegistry ,
5437 InferenceStats inferenceStats ,
5538 StreamingTaskManager streamingTaskManager ,
56- InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator ,
5739 NodeClient nodeClient ,
5840 ThreadPool threadPool
5941 ) {
@@ -65,7 +47,6 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
6547 serviceRegistry ,
6648 inferenceStats ,
6749 streamingTaskManager ,
68- inferenceServiceNodeLocalRateLimitCalculator ,
6950 nodeClient ,
7051 threadPool
7152 );
@@ -75,136 +56,4 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
7556 protected InferenceAction .Request createRequest () {
7657 return mock (InferenceAction .Request .class );
7758 }
78-
79- public void testNoRerouting_WhenTaskTypeNotSupported () {
80- TaskType unsupportedTaskType = TaskType .COMPLETION ;
81- mockService (listener -> listener .onResponse (mock ()));
82-
83- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , unsupportedTaskType )).thenReturn (false );
84-
85- var listener = doExecute (unsupportedTaskType );
86-
87- verify (listener ).onResponse (any ());
88- // Verify request was handled locally (not rerouted using TransportService)
89- verify (transportService , never ()).sendRequest (any (), any (), any (), any ());
90- // Verify request metric attributes were recorded on the node performing inference
91- verify (inferenceStats .inferenceDuration ()).record (anyLong (), assertArg (attributes -> {
92- assertThat (attributes .get ("rerouted" ), is (Boolean .FALSE ));
93- assertThat (attributes .get ("node_id" ), is (localNodeId ));
94- }));
95- }
96-
97- public void testNoRerouting_WhenNoGroupingCalculatedYet () {
98- mockService (listener -> listener .onResponse (mock ()));
99-
100- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
101- when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (null );
102-
103- var listener = doExecute (taskType );
104-
105- verify (listener ).onResponse (any ());
106- // Verify request was handled locally (not rerouted using TransportService)
107- verify (transportService , never ()).sendRequest (any (), any (), any (), any ());
108- // Verify request metric attributes were recorded on the node performing inference
109- verify (inferenceStats .inferenceDuration ()).record (anyLong (), assertArg (attributes -> {
110- assertThat (attributes .get ("rerouted" ), is (Boolean .FALSE ));
111- assertThat (attributes .get ("node_id" ), is (localNodeId ));
112- }));
113- }
114-
115- public void testNoRerouting_WhenEmptyNodeList () {
116- mockService (listener -> listener .onResponse (mock ()));
117-
118- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
119- when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (
120- new RateLimitAssignment (List .of ())
121- );
122-
123- var listener = doExecute (taskType );
124-
125- verify (listener ).onResponse (any ());
126- // Verify request was handled locally (not rerouted using TransportService)
127- verify (transportService , never ()).sendRequest (any (), any (), any (), any ());
128- // Verify request metric attributes were recorded on the node performing inference
129- verify (inferenceStats .inferenceDuration ()).record (anyLong (), assertArg (attributes -> {
130- assertThat (attributes .get ("rerouted" ), is (Boolean .FALSE ));
131- assertThat (attributes .get ("node_id" ), is (localNodeId ));
132- }));
133- }
134-
135- public void testRerouting_ToOtherNode () {
136- DiscoveryNode otherNode = mock (DiscoveryNode .class );
137- when (otherNode .getId ()).thenReturn ("other-node" );
138-
139- // The local node is different to the "other-node" responsible for serviceId
140- when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
141- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
142- // Requests for serviceId are always routed to "other-node"
143- var assignment = new RateLimitAssignment (List .of (otherNode ));
144- when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
145-
146- mockService (listener -> listener .onResponse (mock ()));
147- var listener = doExecute (taskType );
148-
149- // Verify request was rerouted
150- verify (transportService ).sendRequest (same (otherNode ), eq (InferenceAction .NAME ), any (), any ());
151- // Verify local execution didn't happen
152- verify (listener , never ()).onResponse (any ());
153- // Verify that request metric attributes were NOT recorded on the node rerouting the request to another node
154- verify (inferenceStats .inferenceDuration (), never ()).record (anyLong (), any ());
155- }
156-
157- public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain () {
158- DiscoveryNode localNode = mock (DiscoveryNode .class );
159- String localNodeId = "local-node" ;
160- when (localNode .getId ()).thenReturn (localNodeId );
161-
162- // The local node is the only one responsible for serviceId
163- when (nodeClient .getLocalNodeId ()).thenReturn (localNodeId );
164- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
165- var assignment = new RateLimitAssignment (List .of (localNode ));
166- when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
167-
168- mockService (listener -> listener .onResponse (mock ()));
169- var listener = doExecute (taskType );
170-
171- verify (listener ).onResponse (any ());
172- // Verify request was handled locally (not rerouted using TransportService)
173- verify (transportService , never ()).sendRequest (any (), any (), any (), any ());
174- // Verify request metric attributes were recorded on the node performing inference
175- verify (inferenceStats .inferenceDuration ()).record (anyLong (), assertArg (attributes -> {
176- assertThat (attributes .get ("rerouted" ), is (Boolean .FALSE ));
177- assertThat (attributes .get ("node_id" ), is (localNodeId ));
178- }));
179- }
180-
181- public void testRerouting_HandlesTransportException_FromOtherNode () {
182- DiscoveryNode otherNode = mock (DiscoveryNode .class );
183- when (otherNode .getId ()).thenReturn ("other-node" );
184-
185- when (nodeClient .getLocalNodeId ()).thenReturn ("local-node" );
186- when (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceId , taskType )).thenReturn (true );
187- var assignment = new RateLimitAssignment (List .of (otherNode ));
188- when (inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceId , taskType )).thenReturn (assignment );
189-
190- mockService (listener -> listener .onResponse (mock ()));
191-
192- TransportException expectedException = new TransportException ("Failed to route" );
193- doAnswer (invocation -> {
194- TransportResponseHandler <?> handler = invocation .getArgument (3 );
195- handler .handleException (expectedException );
196- return null ;
197- }).when (transportService ).sendRequest (any (), any (), any (), any ());
198-
199- var listener = doExecute (taskType );
200-
201- // Verify request was rerouted
202- verify (transportService ).sendRequest (same (otherNode ), eq (InferenceAction .NAME ), any (), any ());
203- // Verify local execution didn't happen
204- verify (listener , never ()).onResponse (any ());
205- // Verify exception was propagated from "other-node" to "local-node"
206- verify (listener ).onFailure (same (expectedException ));
207- // Verify that request metric attributes were NOT recorded on the node rerouting the request to another node
208- verify (inferenceStats .inferenceDuration (), never ()).record (anyLong (), any ());
209- }
21059}
0 commit comments