 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
@@ -64,17 +69,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
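+        // start the timer up front so the recorded duration covers model resolution as well as the inference call itself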
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
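+        // ActionListener.wrap replaces delegateFailureAndWrap so the failure branch can record metrics before propagating the error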
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
@@ -85,20 +95,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
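+            // the model never resolved, so only the error is available as a metric attribute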
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
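+    // records the inference duration for failures that happen before the stored model configuration is parsed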
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
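+    // counts the request, then delegates to inferOnService with a listener that records the duration on every outcome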
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
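+                // for streaming responses the duration is recorded by PublisherWithMetrics once the stream terminates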
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
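+    // same as the UnparsedModel overload above, but with attributes taken from the parsed model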
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -107,7 +166,7 @@ private void inferOnService(
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
             );
         } else {
             listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -135,20 +194,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
         }
     }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
@@ -162,4 +207,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
         );
     }
 
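+    // a pass-through processor that records the duration metric exactly once, when the stream errors, is cancelled, or completes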
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }