Skip to content

Commit 1d86ce8

Browse files
committed
the rest
1 parent e5b96c6 commit 1d86ce8

File tree

6 files changed

+0
-412
lines changed

6 files changed

+0
-412
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
2828
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2929
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
30-
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
3130
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
3231

3332
import java.util.concurrent.Flow;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,6 @@ interface RateLimiterCreator {
9797
private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
9898

9999
private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
100-
// TODO: add one atomic integer (number of nodes); also explain the assumption and why this works
101-
// TODO: document that this impacts chat completion (and increase the default rate limit)
102100
private final AtomicInteger rateLimitDivisor = new AtomicInteger(1);
103101
private final ThreadPool threadPool;
104102
private final CountDownLatch startupLatch;
@@ -404,10 +402,6 @@ public void init() {
404402
}
405403

406404
/**
407-
* This method is solely called by {@link InferenceServiceNodeLocalRateLimitCalculator} to update
408-
* rate limits, so they're "node-local".
409-
* The general idea is described in {@link InferenceServiceNodeLocalRateLimitCalculator} in more detail.
410-
*
411405
* @param divisor - divisor to divide the initial requests per time unit by
412406
*/
413407
private synchronized void updateTokensPerTimeUnit(Integer divisor) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3333
import org.elasticsearch.xpack.inference.InferencePlugin;
3434
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
35-
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
3635
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
3736
import org.junit.Before;
3837
import org.mockito.ArgumentCaptor;
@@ -68,7 +67,6 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
6867
protected static final String localNodeId = "local-node-id";
6968
protected InferenceServiceRegistry serviceRegistry;
7069
protected InferenceStats inferenceStats;
71-
protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
7270
protected TransportService transportService;
7371
protected NodeClient nodeClient;
7472

@@ -83,7 +81,6 @@ public void setUp() throws Exception {
8381
threadPool = mock();
8482
nodeClient = mock();
8583
transportService = mock();
86-
inferenceServiceRateLimitCalculator = mock();
8784
licenseState = mock();
8885
inferenceEndpointRegistry = mock();
8986
serviceRegistry = mock();
@@ -98,7 +95,6 @@ public void setUp() throws Exception {
9895
serviceRegistry,
9996
inferenceStats,
10097
streamingTaskManager,
101-
inferenceServiceRateLimitCalculator,
10298
nodeClient,
10399
threadPool
104100
);
@@ -115,7 +111,6 @@ protected abstract BaseTransportInferenceAction<Request> createAction(
115111
InferenceServiceRegistry serviceRegistry,
116112
InferenceStats inferenceStats,
117113
StreamingTaskManager streamingTaskManager,
118-
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
119114
NodeClient nodeClient,
120115
ThreadPool threadPool
121116
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java

Lines changed: 0 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,17 @@
99

1010
import org.elasticsearch.action.support.ActionFilters;
1111
import org.elasticsearch.client.internal.node.NodeClient;
12-
import org.elasticsearch.cluster.node.DiscoveryNode;
1312
import org.elasticsearch.inference.InferenceServiceRegistry;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.inference.telemetry.InferenceStats;
1615
import org.elasticsearch.license.MockLicenseState;
1716
import org.elasticsearch.threadpool.ThreadPool;
18-
import org.elasticsearch.transport.TransportException;
19-
import org.elasticsearch.transport.TransportResponseHandler;
2017
import org.elasticsearch.transport.TransportService;
2118
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2219
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
23-
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
24-
import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
2520
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
2621

27-
import java.util.List;
28-
29-
import static org.hamcrest.Matchers.is;
30-
import static org.mockito.ArgumentMatchers.any;
31-
import static org.mockito.ArgumentMatchers.anyLong;
32-
import static org.mockito.ArgumentMatchers.assertArg;
33-
import static org.mockito.ArgumentMatchers.eq;
34-
import static org.mockito.ArgumentMatchers.same;
35-
import static org.mockito.Mockito.doAnswer;
3622
import static org.mockito.Mockito.mock;
37-
import static org.mockito.Mockito.never;
38-
import static org.mockito.Mockito.verify;
39-
import static org.mockito.Mockito.when;
4023

4124
public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase<InferenceAction.Request> {
4225

@@ -53,7 +36,6 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
5336
InferenceServiceRegistry serviceRegistry,
5437
InferenceStats inferenceStats,
5538
StreamingTaskManager streamingTaskManager,
56-
InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
5739
NodeClient nodeClient,
5840
ThreadPool threadPool
5941
) {
@@ -65,7 +47,6 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
6547
serviceRegistry,
6648
inferenceStats,
6749
streamingTaskManager,
68-
inferenceServiceNodeLocalRateLimitCalculator,
6950
nodeClient,
7051
threadPool
7152
);
@@ -75,136 +56,4 @@ protected BaseTransportInferenceAction<InferenceAction.Request> createAction(
7556
protected InferenceAction.Request createRequest() {
7657
return mock(InferenceAction.Request.class);
7758
}
78-
79-
public void testNoRerouting_WhenTaskTypeNotSupported() {
80-
TaskType unsupportedTaskType = TaskType.COMPLETION;
81-
mockService(listener -> listener.onResponse(mock()));
82-
83-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
84-
85-
var listener = doExecute(unsupportedTaskType);
86-
87-
verify(listener).onResponse(any());
88-
// Verify request was handled locally (not rerouted using TransportService)
89-
verify(transportService, never()).sendRequest(any(), any(), any(), any());
90-
// Verify request metric attributes were recorded on the node performing inference
91-
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
92-
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
93-
assertThat(attributes.get("node_id"), is(localNodeId));
94-
}));
95-
}
96-
97-
public void testNoRerouting_WhenNoGroupingCalculatedYet() {
98-
mockService(listener -> listener.onResponse(mock()));
99-
100-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
101-
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
102-
103-
var listener = doExecute(taskType);
104-
105-
verify(listener).onResponse(any());
106-
// Verify request was handled locally (not rerouted using TransportService)
107-
verify(transportService, never()).sendRequest(any(), any(), any(), any());
108-
// Verify request metric attributes were recorded on the node performing inference
109-
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
110-
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
111-
assertThat(attributes.get("node_id"), is(localNodeId));
112-
}));
113-
}
114-
115-
public void testNoRerouting_WhenEmptyNodeList() {
116-
mockService(listener -> listener.onResponse(mock()));
117-
118-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
119-
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
120-
new RateLimitAssignment(List.of())
121-
);
122-
123-
var listener = doExecute(taskType);
124-
125-
verify(listener).onResponse(any());
126-
// Verify request was handled locally (not rerouted using TransportService)
127-
verify(transportService, never()).sendRequest(any(), any(), any(), any());
128-
// Verify request metric attributes were recorded on the node performing inference
129-
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
130-
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
131-
assertThat(attributes.get("node_id"), is(localNodeId));
132-
}));
133-
}
134-
135-
public void testRerouting_ToOtherNode() {
136-
DiscoveryNode otherNode = mock(DiscoveryNode.class);
137-
when(otherNode.getId()).thenReturn("other-node");
138-
139-
// The local node is different to the "other-node" responsible for serviceId
140-
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
141-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
142-
// Requests for serviceId are always routed to "other-node"
143-
var assignment = new RateLimitAssignment(List.of(otherNode));
144-
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
145-
146-
mockService(listener -> listener.onResponse(mock()));
147-
var listener = doExecute(taskType);
148-
149-
// Verify request was rerouted
150-
verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
151-
// Verify local execution didn't happen
152-
verify(listener, never()).onResponse(any());
153-
// Verify that request metric attributes were NOT recorded on the node rerouting the request to another node
154-
verify(inferenceStats.inferenceDuration(), never()).record(anyLong(), any());
155-
}
156-
157-
public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() {
158-
DiscoveryNode localNode = mock(DiscoveryNode.class);
159-
String localNodeId = "local-node";
160-
when(localNode.getId()).thenReturn(localNodeId);
161-
162-
// The local node is the only one responsible for serviceId
163-
when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
164-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
165-
var assignment = new RateLimitAssignment(List.of(localNode));
166-
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
167-
168-
mockService(listener -> listener.onResponse(mock()));
169-
var listener = doExecute(taskType);
170-
171-
verify(listener).onResponse(any());
172-
// Verify request was handled locally (not rerouted using TransportService)
173-
verify(transportService, never()).sendRequest(any(), any(), any(), any());
174-
// Verify request metric attributes were recorded on the node performing inference
175-
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
176-
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
177-
assertThat(attributes.get("node_id"), is(localNodeId));
178-
}));
179-
}
180-
181-
public void testRerouting_HandlesTransportException_FromOtherNode() {
182-
DiscoveryNode otherNode = mock(DiscoveryNode.class);
183-
when(otherNode.getId()).thenReturn("other-node");
184-
185-
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
186-
when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
187-
var assignment = new RateLimitAssignment(List.of(otherNode));
188-
when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
189-
190-
mockService(listener -> listener.onResponse(mock()));
191-
192-
TransportException expectedException = new TransportException("Failed to route");
193-
doAnswer(invocation -> {
194-
TransportResponseHandler<?> handler = invocation.getArgument(3);
195-
handler.handleException(expectedException);
196-
return null;
197-
}).when(transportService).sendRequest(any(), any(), any(), any());
198-
199-
var listener = doExecute(taskType);
200-
201-
// Verify request was rerouted
202-
verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
203-
// Verify local execution didn't happen
204-
verify(listener, never()).onResponse(any());
205-
// Verify exception was propagated from "other-node" to "local-node"
206-
verify(listener).onFailure(same(expectedException));
207-
// Verify that request metric attributes were NOT recorded on the node rerouting the request to another node
208-
verify(inferenceStats.inferenceDuration(), never()).record(anyLong(), any());
209-
}
21059
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
2020
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2121
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
22-
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2322
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
2423

2524
import java.util.Optional;
@@ -49,7 +48,6 @@ protected BaseTransportInferenceAction<UnifiedCompletionAction.Request> createAc
4948
InferenceServiceRegistry serviceRegistry,
5049
InferenceStats inferenceStats,
5150
StreamingTaskManager streamingTaskManager,
52-
InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator,
5351
NodeClient nodeClient,
5452
ThreadPool threadPool
5553
) {
@@ -61,7 +59,6 @@ protected BaseTransportInferenceAction<UnifiedCompletionAction.Request> createAc
6159
serviceRegistry,
6260
inferenceStats,
6361
streamingTaskManager,
64-
inferenceServiceRateLimitCalculator,
6562
nodeClient,
6663
threadPool
6764
);

0 commit comments

Comments
 (0)