1616import org .elasticsearch .client .internal .node .NodeClient ;
1717import org .elasticsearch .cluster .node .DiscoveryNode ;
1818import org .elasticsearch .common .Randomness ;
19- import org .elasticsearch .common .io .stream .StreamInput ;
2019import org .elasticsearch .common .io .stream .Writeable ;
2120import org .elasticsearch .common .util .concurrent .EsExecutors ;
2221import org .elasticsearch .core .Nullable ;
3130import org .elasticsearch .rest .RestStatus ;
3231import org .elasticsearch .tasks .Task ;
3332import org .elasticsearch .threadpool .ThreadPool ;
34- import org .elasticsearch .transport .TransportException ;
35- import org .elasticsearch .transport .TransportResponseHandler ;
3633import org .elasticsearch .transport .TransportService ;
3734import org .elasticsearch .xpack .core .XPackField ;
3835import org .elasticsearch .xpack .core .inference .action .BaseInferenceActionRequest ;
3936import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
4037import org .elasticsearch .xpack .inference .InferencePlugin ;
4138import org .elasticsearch .xpack .inference .action .task .StreamingTaskManager ;
42- import org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator ;
43- import org .elasticsearch .xpack .inference .common .InferenceServiceRateLimitCalculator ;
4439import org .elasticsearch .xpack .inference .registry .InferenceEndpointRegistry ;
4540import org .elasticsearch .xpack .inference .telemetry .InferenceTimer ;
4641
47- import java .io .IOException ;
4842import java .util .HashMap ;
4943import java .util .Map ;
5044import java .util .Objects ;
5145import java .util .Random ;
52- import java .util .concurrent .Executor ;
5346import java .util .concurrent .Flow ;
5447import java .util .function .Supplier ;
5548import java .util .stream .Collectors ;
6457
6558/**
6659 * Base class for transport actions that handle inference requests.
67- * Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to
68- * route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail
69- * in {@link InferenceServiceNodeLocalRateLimitCalculator}.
70- *
7160 * @param <Request> The specific type of inference request being handled
7261 */
7362public abstract class BaseTransportInferenceAction <Request extends BaseInferenceActionRequest > extends HandledTransportAction <
@@ -82,7 +71,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
8271 private final InferenceServiceRegistry serviceRegistry ;
8372 private final InferenceStats inferenceStats ;
8473 private final StreamingTaskManager streamingTaskManager ;
85- private final InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator ;
8674 private final NodeClient nodeClient ;
8775 private final ThreadPool threadPool ;
8876 private final TransportService transportService ;
@@ -98,7 +86,6 @@ public BaseTransportInferenceAction(
9886 InferenceStats inferenceStats ,
9987 StreamingTaskManager streamingTaskManager ,
10088 Writeable .Reader <Request > requestReader ,
101- InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator ,
10289 NodeClient nodeClient ,
10390 ThreadPool threadPool
10491 ) {
@@ -108,7 +95,6 @@ public BaseTransportInferenceAction(
10895 this .serviceRegistry = serviceRegistry ;
10996 this .inferenceStats = inferenceStats ;
11097 this .streamingTaskManager = streamingTaskManager ;
111- this .inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator ;
11298 this .nodeClient = nodeClient ;
11399 this .threadPool = threadPool ;
114100 this .transportService = transportService ;
@@ -161,15 +147,8 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
161147
162148 var service = serviceRegistry .getService (serviceName ).get ();
163149 var localNodeId = nodeClient .getLocalNodeId ();
164- var routingDecision = determineRouting ( serviceName , request , model . getTaskType (), localNodeId );
150+ inferOnServiceWithMetrics ( model , request , service , timer , localNodeId , listener );
165151
166- if (routingDecision .currentNodeShouldHandleRequest ()) {
167- inferOnServiceWithMetrics (model , request , service , timer , localNodeId , listener );
168- } else {
169- // Reroute request
170- request .setHasBeenRerouted (true );
171- rerouteRequest (request , listener , routingDecision .targetNode );
172- }
173152 }, e -> {
174153 try {
175154 inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), responseAttributes (e ));
@@ -195,73 +174,12 @@ private void validateRequest(Request request, Model model) {
195174 validationHelper (() -> isInvalidTaskTypeForInferenceEndpoint (request , model ), () -> createInvalidTaskTypeException (request , model ));
196175 }
197176
198- private NodeRoutingDecision determineRouting (String serviceName , Request request , TaskType modelTaskType , String localNodeId ) {
199- // Rerouting not supported or request was already rerouted
200- if (inferenceServiceRateLimitCalculator .isTaskTypeReroutingSupported (serviceName , modelTaskType ) == false
201- || request .hasBeenRerouted ()) {
202- return NodeRoutingDecision .handleLocally ();
203- }
204-
205- var rateLimitAssignment = inferenceServiceRateLimitCalculator .getRateLimitAssignment (serviceName , modelTaskType );
206-
207- // No assignment yet
208- if (rateLimitAssignment == null ) {
209- return NodeRoutingDecision .handleLocally ();
210- }
211-
212- var responsibleNodes = rateLimitAssignment .responsibleNodes ();
213-
214- // Empty assignment
215- if (responsibleNodes == null || responsibleNodes .isEmpty ()) {
216- return NodeRoutingDecision .handleLocally ();
217- }
218-
219- var nodeToHandleRequest = responsibleNodes .get (random .nextInt (responsibleNodes .size ()));
220-
221- // The drawn node is the current node
222- if (nodeToHandleRequest .getId ().equals (localNodeId )) {
223- return NodeRoutingDecision .handleLocally ();
224- }
225-
226- // Reroute request
227- return NodeRoutingDecision .routeTo (nodeToHandleRequest );
228- }
229-
230177 private static void validationHelper (Supplier <Boolean > validationFailure , Supplier <ElasticsearchStatusException > exceptionCreator ) {
231178 if (validationFailure .get ()) {
232179 throw exceptionCreator .get ();
233180 }
234181 }
235182
236- private void rerouteRequest (Request request , ActionListener <InferenceAction .Response > listener , DiscoveryNode nodeToHandleRequest ) {
237- transportService .sendRequest (
238- nodeToHandleRequest ,
239- InferenceAction .NAME ,
240- request ,
241- new TransportResponseHandler <InferenceAction .Response >() {
242- @ Override
243- public Executor executor () {
244- return threadPool .executor (InferencePlugin .UTILITY_THREAD_POOL_NAME );
245- }
246-
247- @ Override
248- public void handleResponse (InferenceAction .Response response ) {
249- listener .onResponse (response );
250- }
251-
252- @ Override
253- public void handleException (TransportException exp ) {
254- listener .onFailure (exp );
255- }
256-
257- @ Override
258- public InferenceAction .Response read (StreamInput in ) throws IOException {
259- return new InferenceAction .Response (in );
260- }
261- }
262- );
263- }
264-
265183 private void recordRequestDurationMetrics (Model model , InferenceTimer timer , @ Nullable Throwable t ) {
266184 Map <String , Object > metricAttributes = new HashMap <>();
267185 metricAttributes .putAll (modelAttributes (model ));
0 commit comments