4747import org .elasticsearch .xpack .inference .telemetry .InferenceTimer ;
4848
4949import java .io .IOException ;
50+ import java .util .HashMap ;
51+ import java .util .Map ;
5052import java .util .Random ;
5153import java .util .concurrent .Executor ;
5254import java .util .concurrent .Flow ;
5860import static org .elasticsearch .xpack .inference .InferencePlugin .INFERENCE_API_FEATURE ;
5961import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .modelAttributes ;
6062import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .responseAttributes ;
63+ import static org .elasticsearch .xpack .inference .telemetry .InferenceStats .routingAttributes ;
6164
6265/**
6366 * Base class for transport actions that handle inference requests.
@@ -138,13 +141,14 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
138141 try {
139142 validateRequest (request , unparsedModel );
140143 } catch (Exception e ) {
141- recordMetrics (unparsedModel , timer , e );
144+ recordRequestDurationMetrics (unparsedModel , timer , e );
142145 listener .onFailure (e );
143146 return ;
144147 }
145148
146149 var service = serviceRegistry .getService (serviceName ).get ();
147- var routingDecision = determineRouting (serviceName , request , unparsedModel );
150+ var localNodeId = nodeClient .getLocalNodeId ();
151+ var routingDecision = determineRouting (serviceName , request , unparsedModel , localNodeId );
148152
149153 if (routingDecision .currentNodeShouldHandleRequest ()) {
150154 var model = service .parsePersistedConfigWithSecrets (
@@ -153,7 +157,7 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
153157 unparsedModel .settings (),
154158 unparsedModel .secrets ()
155159 );
156- inferOnServiceWithMetrics (model , request , service , timer , listener );
160+ inferOnServiceWithMetrics (model , request , service , timer , localNodeId , listener );
157161 } else {
158162 // Reroute request
159163 request .setHasBeenRerouted (true );
@@ -187,7 +191,7 @@ private void validateRequest(Request request, UnparsedModel unparsedModel) {
187191 );
188192 }
189193
190- private NodeRoutingDecision determineRouting (String serviceName , Request request , UnparsedModel unparsedModel ) {
194+ private NodeRoutingDecision determineRouting (String serviceName , Request request , UnparsedModel unparsedModel , String localNodeId ) {
191195 var modelTaskType = unparsedModel .taskType ();
192196
193197 // Rerouting not supported or request was already rerouted
@@ -211,7 +215,6 @@ private NodeRoutingDecision determineRouting(String serviceName, Request request
211215 }
212216
213217 var nodeToHandleRequest = responsibleNodes .get (random .nextInt (responsibleNodes .size ()));
214- String localNodeId = nodeClient .getLocalNodeId ();
215218
216219 // The drawn node is the current node
217220 if (nodeToHandleRequest .getId ().equals (localNodeId )) {
@@ -257,9 +260,13 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
257260 );
258261 }
259262
260- private void recordMetrics (UnparsedModel model , InferenceTimer timer , @ Nullable Throwable t ) {
263+ private void recordRequestDurationMetrics (UnparsedModel model , InferenceTimer timer , @ Nullable Throwable t ) {
261264 try {
262- inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), responseAttributes (model , t ));
265+ Map <String , Object > metricAttributes = new HashMap <>();
266+ metricAttributes .putAll (modelAttributes (model ));
267+ metricAttributes .putAll (responseAttributes (unwrapCause (t )));
268+
269+ inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), metricAttributes );
263270 } catch (Exception e ) {
264271 log .atDebug ().withThrowable (e ).log ("Failed to record metrics with an unparsed model, dropping metrics" );
265272 }
@@ -270,9 +277,10 @@ private void inferOnServiceWithMetrics(
270277 Request request ,
271278 InferenceService service ,
272279 InferenceTimer timer ,
280+ String localNodeId ,
273281 ActionListener <InferenceAction .Response > listener
274282 ) {
275- inferenceStats . requestCount (). incrementBy ( 1 , modelAttributes ( model ) );
283+ recordRequestCountMetrics ( model , request , localNodeId );
276284 inferOnService (model , request , service , ActionListener .wrap (inferenceResults -> {
277285 if (request .isStreaming ()) {
278286 var taskProcessor = streamingTaskManager .<InferenceServiceResults .Result >create (
@@ -281,18 +289,18 @@ private void inferOnServiceWithMetrics(
281289 );
282290 inferenceResults .publisher ().subscribe (taskProcessor );
283291
284- var instrumentedStream = new PublisherWithMetrics (timer , model );
292+ var instrumentedStream = new PublisherWithMetrics (timer , model , request , localNodeId );
285293 taskProcessor .subscribe (instrumentedStream );
286294
287295 var streamErrorHandler = streamErrorHandler (instrumentedStream );
288296
289297 listener .onResponse (new InferenceAction .Response (inferenceResults , streamErrorHandler ));
290298 } else {
291- recordMetrics (model , timer , null );
299+ recordRequestDurationMetrics (model , timer , request , localNodeId , null );
292300 listener .onResponse (new InferenceAction .Response (inferenceResults ));
293301 }
294302 }, e -> {
295- recordMetrics (model , timer , e );
303+ recordRequestDurationMetrics (model , timer , request , localNodeId , e );
296304 listener .onFailure (e );
297305 }));
298306 }
@@ -301,9 +309,28 @@ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream
301309 return upstream ;
302310 }
303311
304- private void recordMetrics (Model model , InferenceTimer timer , @ Nullable Throwable t ) {
312+ private void recordRequestCountMetrics (Model model , Request request , String localNodeId ) {
313+ Map <String , Object > requestCountAttributes = new HashMap <>();
314+ requestCountAttributes .putAll (modelAttributes (model ));
315+ requestCountAttributes .putAll (routingAttributes (request , localNodeId ));
316+
317+ inferenceStats .requestCount ().incrementBy (1 , requestCountAttributes );
318+ }
319+
320+ private void recordRequestDurationMetrics (
321+ Model model ,
322+ InferenceTimer timer ,
323+ Request request ,
324+ String localNodeId ,
325+ @ Nullable Throwable t
326+ ) {
305327 try {
306- inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), responseAttributes (model , unwrapCause (t )));
328+ Map <String , Object > metricAttributes = new HashMap <>();
329+ metricAttributes .putAll (modelAttributes (model ));
330+ metricAttributes .putAll (routingAttributes (request , localNodeId ));
331+ metricAttributes .putAll (responseAttributes (unwrapCause (t )));
332+
333+ inferenceStats .inferenceDuration ().record (timer .elapsedMillis (), metricAttributes );
307334 } catch (Exception e ) {
308335 log .atDebug ().withThrowable (e ).log ("Failed to record metrics with a parsed model, dropping metrics" );
309336 }
@@ -355,10 +382,14 @@ private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceR
355382
356383 private final InferenceTimer timer ;
357384 private final Model model ;
385+ private final Request request ;
386+ private final String localNodeId ;
358387
359- private PublisherWithMetrics (InferenceTimer timer , Model model ) {
388+ private PublisherWithMetrics (InferenceTimer timer , Model model , Request request , String localNodeId ) {
360389 this .timer = timer ;
361390 this .model = model ;
391+ this .request = request ;
392+ this .localNodeId = localNodeId ;
362393 }
363394
364395 @ Override
@@ -368,19 +399,19 @@ protected void next(InferenceServiceResults.Result item) {
368399
369400 @ Override
370401 public void onError (Throwable throwable ) {
371- recordMetrics (model , timer , throwable );
402+ recordRequestDurationMetrics (model , timer , request , localNodeId , throwable );
372403 super .onError (throwable );
373404 }
374405
375406 @ Override
376407 protected void onCancel () {
377- recordMetrics (model , timer , null );
408+ recordRequestDurationMetrics (model , timer , request , localNodeId , null );
378409 super .onCancel ();
379410 }
380411
381412 @ Override
382413 public void onComplete () {
383- recordMetrics (model , timer , null );
414+ recordRequestDurationMetrics (model , timer , request , localNodeId , null );
384415 super .onComplete ();
385416 }
386417 }
0 commit comments