Skip to content

Commit 25223d2

Browse files
authored
[8.x] [Inference API] Fix tests in TransportInferenceActionTests (#121302) (#121404)
1 parent a070b6c commit 25223d2

File tree

3 files changed, with 25 additions and 21 deletions.

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2929
import org.elasticsearch.xpack.inference.InferencePlugin;
3030
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
31-
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
31+
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
3232
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3333
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
3434
import org.junit.Before;
@@ -64,7 +64,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
6464
protected static final String inferenceId = "inferenceEntityId";
6565
protected InferenceServiceRegistry serviceRegistry;
6666
protected InferenceStats inferenceStats;
67-
protected InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator;
67+
protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
6868
protected TransportService transportService;
6969
protected NodeClient nodeClient;
7070

@@ -79,7 +79,7 @@ public void setUp() throws Exception {
7979
ThreadPool threadPool = mock();
8080
nodeClient = mock();
8181
transportService = mock();
82-
inferenceServiceNodeLocalRateLimitCalculator = mock();
82+
inferenceServiceRateLimitCalculator = mock();
8383
licenseState = mock();
8484
modelRegistry = mock();
8585
serviceRegistry = mock();
@@ -94,7 +94,7 @@ public void setUp() throws Exception {
9494
serviceRegistry,
9595
inferenceStats,
9696
streamingTaskManager,
97-
inferenceServiceNodeLocalRateLimitCalculator,
97+
inferenceServiceRateLimitCalculator,
9898
nodeClient,
9999
threadPool
100100
);
@@ -110,7 +110,7 @@ protected abstract BaseTransportInferenceAction<Request> createAction(
110110
InferenceServiceRegistry serviceRegistry,
111111
InferenceStats inferenceStats,
112112
StreamingTaskManager streamingTaskManager,
113-
InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
113+
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
114114
NodeClient nodeClient,
115115
ThreadPool threadPool
116116
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import org.elasticsearch.transport.TransportService;
2020
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2121
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
22-
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
22+
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2323
import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
2424
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2525
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
@@ -50,7 +50,7 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
5050
InferenceServiceRegistry serviceRegistry,
5151
InferenceStats inferenceStats,
5252
StreamingTaskManager streamingTaskManager,
53-
InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
53+
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
5454
NodeClient nodeClient,
5555
ThreadPool threadPool
5656
) {
@@ -77,7 +77,7 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
7777
TaskType unsupportedTaskType = TaskType.COMPLETION;
7878
mockService(listener -> listener.onResponse(mock()));
7979

80-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
80+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
8181

8282
var listener = doExecute(unsupportedTaskType);
8383

@@ -89,8 +89,8 @@ public void testNoRerouting_WhenTaskTypeNotSupported() {
8989
public void testNoRerouting_WhenNoGroupingCalculatedYet() {
9090
mockService(listener -> listener.onResponse(mock()));
9191

92-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
93-
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
92+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
93+
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
9494

9595
var listener = doExecute(taskType);
9696

@@ -102,8 +102,8 @@ public void testNoRerouting_WhenNoGroupingCalculatedYet() {
102102
public void testNoRerouting_WhenEmptyNodeList() {
103103
mockService(listener -> listener.onResponse(mock()));
104104

105-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
106-
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
105+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
106+
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
107107
new RateLimitAssignment(List.of())
108108
);
109109

@@ -120,10 +120,10 @@ public void testRerouting_ToOtherNode() {
120120

121121
// The local node is different to the "other-node" responsible for serviceId
122122
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
123-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
123+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
124124
// Requests for serviceId are always routed to "other-node"
125125
var assignment = new RateLimitAssignment(List.of(otherNode));
126-
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
126+
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
127127

128128
mockService(listener -> listener.onResponse(mock()));
129129
var listener = doExecute(taskType);
@@ -141,9 +141,9 @@ public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() {
141141

142142
// The local node is the only one responsible for serviceId
143143
when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
144-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
144+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
145145
var assignment = new RateLimitAssignment(List.of(localNode));
146-
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
146+
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
147147

148148
mockService(listener -> listener.onResponse(mock()));
149149
var listener = doExecute(taskType);
@@ -158,9 +158,9 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
158158
when(otherNode.getId()).thenReturn("other-node");
159159

160160
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
161-
when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
161+
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
162162
var assignment = new RateLimitAssignment(List.of(otherNode));
163-
when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
163+
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
164164

165165
mockService(listener -> listener.onResponse(mock()));
166166

@@ -173,6 +173,10 @@ public void testRerouting_HandlesTransportException_FromOtherNode() {
173173

174174
var listener = doExecute(taskType);
175175

176+
// Verify request was rerouted
177+
verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
178+
// Verify local execution didn't happen
179+
verify(listener, never()).onResponse(any());
176180
// Verify exception was propagated from "other-node" to "local-node"
177181
verify(listener).onFailure(same(expectedException));
178182
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
import org.elasticsearch.transport.TransportService;
1919
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
2020
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
21-
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
21+
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2222
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2323
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
2424

@@ -49,7 +49,7 @@ protected BaseTransportInferenceAction<UnifiedCompletionAction.Request> createAc
4949
InferenceServiceRegistry serviceRegistry,
5050
InferenceStats inferenceStats,
5151
StreamingTaskManager streamingTaskManager,
52-
InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
52+
InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator,
5353
NodeClient nodeClient,
5454
ThreadPool threadPool
5555
) {
@@ -61,7 +61,7 @@ protected BaseTransportInferenceAction<UnifiedCompletionAction.Request> createAc
6161
serviceRegistry,
6262
inferenceStats,
6363
streamingTaskManager,
64-
inferenceServiceNodeLocalRateLimitCalculator,
64+
inferenceServiceRateLimitCalculator,
6565
nodeClient,
6666
threadPool
6767
);

0 commit comments

Comments
 (0)