4848import org .elasticsearch .xpack .inference .telemetry .InferenceTimer ;
4949
5050import java .io .IOException ;
51+ import java .util .HashMap ;
52+ import java .util .Map ;
5153import java .util .Random ;
5254import java .util .concurrent .Executor ;
5355import java .util .concurrent .Flow ;
5961import static org .elasticsearch .xpack .inference .InferencePlugin .INFERENCE_API_FEATURE ;
6062import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .modelAttributes ;
6163import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .responseAttributes ;
64+ import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .routingAttributes ;
6265
6366/**
6467 * Base class for transport actions that handle inference requests.
@@ -145,7 +148,8 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
145148 }
146149
147150 var service = serviceRegistry .getService (serviceName ).get ();
148- var routingDecision = determineRouting (serviceName , request , unparsedModel );
151+ var localNodeId = nodeClient .getLocalNodeId ();
152+ var routingDecision = determineRouting (serviceName , request , unparsedModel , localNodeId );
149153
150154 if (routingDecision .currentNodeShouldHandleRequest ()) {
151155 var model = service .parsePersistedConfigWithSecrets (
@@ -154,7 +158,7 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
154158 unparsedModel .settings (),
155159 unparsedModel .secrets ()
156160 );
157- inferOnServiceWithMetrics (model , request , service , timer , listener );
161+ inferOnServiceWithMetrics (model , request , service , timer , localNodeId , listener );
158162 } else {
159163 // Reroute request
160164 request .setHasBeenRerouted (true );
@@ -188,7 +192,7 @@ private void validateRequest(Request request, UnparsedModel unparsedModel) {
188192 );
189193 }
190194
191- private NodeRoutingDecision determineRouting (String serviceName , Request request , UnparsedModel unparsedModel ) {
195+ private NodeRoutingDecision determineRouting (String serviceName , Request request , UnparsedModel unparsedModel , String localNodeId ) {
192196 var modelTaskType = unparsedModel .taskType ();
193197
194198 // Rerouting not supported or request was already rerouted
@@ -212,7 +216,6 @@ private NodeRoutingDecision determineRouting(String serviceName, Request request
212216 }
213217
214218 var nodeToHandleRequest = responsibleNodes .get (random .nextInt (responsibleNodes .size ()));
215- String localNodeId = nodeClient .getLocalNodeId ();
216219
217220 // The drawn node is the current node
218221 if (nodeToHandleRequest .getId ().equals (localNodeId )) {
@@ -260,7 +263,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
260263
261264 private void recordMetrics (UnparsedModel model , InferenceTimer timer , @ Nullable Throwable t ) {
262265 try {
263- inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), responseAttributes (model , t ));
266+ Map <String , Object > metricAttributes = new HashMap <>();
267+ metricAttributes .putAll (modelAttributes (model ));
268+ metricAttributes .putAll (responseAttributes (unwrapCause (t )));
269+
270+ inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), metricAttributes );
264271 } catch (Exception e ) {
265272 log .atDebug ().withThrowable (e ).log ("Failed to record metrics with an unparsed model, dropping metrics" );
266273 }
@@ -271,6 +278,7 @@ private void inferOnServiceWithMetrics(
271278 Request request ,
272279 InferenceService service ,
273280 InferenceTimer timer ,
281+ String localNodeId ,
274282 ActionListener <InferenceAction .Response > listener
275283 ) {
276284 inferenceStats .requestCount ().incrementBy (1 , modelAttributes (model ));
@@ -279,18 +287,18 @@ private void inferOnServiceWithMetrics(
279287 var taskProcessor = streamingTaskManager .<ChunkedToXContent >create (STREAMING_INFERENCE_TASK_TYPE , STREAMING_TASK_ACTION );
280288 inferenceResults .publisher ().subscribe (taskProcessor );
281289
282- var instrumentedStream = new PublisherWithMetrics (timer , model );
290+ var instrumentedStream = new PublisherWithMetrics (timer , model , request , localNodeId );
283291 taskProcessor .subscribe (instrumentedStream );
284292
285293 var streamErrorHandler = streamErrorHandler (instrumentedStream );
286294
287295 listener .onResponse (new InferenceAction .Response (inferenceResults , streamErrorHandler ));
288296 } else {
289- recordMetrics (model , timer , null );
297+ recordMetrics (model , timer , request , localNodeId , null );
290298 listener .onResponse (new InferenceAction .Response (inferenceResults ));
291299 }
292300 }, e -> {
293- recordMetrics (model , timer , e );
301+ recordMetrics (model , timer , request , localNodeId , e );
294302 listener .onFailure (e );
295303 }));
296304 }
@@ -299,9 +307,14 @@ protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<Ch
299307 return upstream ;
300308 }
301309
302- private void recordMetrics (Model model , InferenceTimer timer , @ Nullable Throwable t ) {
310+ private void recordMetrics (Model model , InferenceTimer timer , Request request , String localNodeId , @ Nullable Throwable t ) {
303311 try {
304- inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), responseAttributes (model , unwrapCause (t )));
312+ Map <String , Object > metricAttributes = new HashMap <>();
313+ metricAttributes .putAll (modelAttributes (model ));
314+ metricAttributes .putAll (routingAttributes (request , localNodeId ));
315+ metricAttributes .putAll (responseAttributes (unwrapCause (t )));
316+
317+ inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), metricAttributes );
305318 } catch (Exception e ) {
306319 log .atDebug ().withThrowable (e ).log ("Failed to record metrics with a parsed model, dropping metrics" );
307320 }
@@ -353,10 +366,14 @@ private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent
353366
354367 private final InferenceTimer timer ;
355368 private final Model model ;
369+ private final Request request ;
370+ private final String localNodeId ;
356371
/**
 * Creates a metrics-recording processor and stores the values it later passes to
 * {@code recordMetrics} from its {@code onError}, {@code onCancel} and {@code onComplete} callbacks.
 *
 * @param timer       timer handed to {@code recordMetrics}
 * @param model       model handed to {@code recordMetrics}
 * @param request     request handed to {@code recordMetrics}
 * @param localNodeId node id handed to {@code recordMetrics} together with the request
 */
357- private PublisherWithMetrics (InferenceTimer timer , Model model ) {
372+ private PublisherWithMetrics (InferenceTimer timer , Model model , Request request , String localNodeId ) {
358373 this .timer = timer ;
359374 this .model = model ;
375+ this .request = request ;
376+ this .localNodeId = localNodeId ;
360377 }
361378
362379 @ Override
@@ -366,19 +383,19 @@ protected void next(ChunkedToXContent item) {
366383
/**
 * Records metrics for this stream with the supplied failure, then forwards the error
 * to the superclass implementation.
 *
 * @param throwable the failure passed to {@code recordMetrics} and on to {@code super.onError}
 */
367384 @ Override
368385 public void onError (Throwable throwable ) {
369- recordMetrics (model , timer , throwable );
386+ recordMetrics (model , timer , request , localNodeId , throwable );
370387 super .onError (throwable );
371388 }
372389
/**
 * Records metrics for this stream with a {@code null} throwable, then delegates to
 * the superclass cancel handling.
 */
373390 @ Override
374391 protected void onCancel () {
375- recordMetrics (model , timer , null );
392+ recordMetrics (model , timer , request , localNodeId , null );
376393 super .onCancel ();
377394 }
378395
/**
 * Records metrics for this stream with a {@code null} throwable, then forwards the
 * completion signal to the superclass implementation.
 */
379396 @ Override
380397 public void onComplete () {
381- recordMetrics (model , timer , null );
398+ recordMetrics (model , timer , request , localNodeId , null );
382399 super .onComplete ();
383400 }
384401 }
0 commit comments