 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
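The hunks below time every request with `InferenceTimer`, which this change introduces but whose source is outside the diff. As a rough orientation, here is a minimal sketch of the surface the call sites imply (`start()` plus `elapsedMillis()` over a monotonic clock); the name `InferenceTimerSketch` and everything in its body are assumptions, not the real class:

import java.util.concurrent.TimeUnit;

// Illustrative only: the smallest shape that satisfies the call sites in
// this diff. The real org.elasticsearch.xpack.inference.telemetry.InferenceTimer
// may differ.
final class InferenceTimerSketch {
    private final long startNanos = System.nanoTime();

    static InferenceTimerSketch start() {
        return new InferenceTimerSketch();
    }

    long elapsedMillis() {
        // Monotonic elapsed time, immune to wall-clock adjustments.
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
    }
}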
@@ -62,17 +67,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
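The switch above from `listener.delegateFailureAndWrap(...)` to `ActionListener.wrap(...)` is what enables failure-path metrics: `delegateFailureAndWrap` forwards failures straight to the wrapped listener, so the old code had no hook to observe them, while `wrap` takes an explicit failure handler. A minimal sketch of the pattern, with `recordFailure` as a hypothetical metrics hook:

import org.elasticsearch.action.ActionListener;

class ListenerStyles {
    // Returns a listener whose failure path can be instrumented before the
    // original listener sees the exception.
    static ActionListener<String> instrumented(ActionListener<String> listener) {
        // listener.delegateFailureAndWrap((l, v) -> ...) would route failures
        // directly to `listener`, bypassing any code written here.
        return ActionListener.wrap(
            listener::onResponse,      // success path unchanged
            e -> {
                recordFailure(e);      // hypothetical metrics hook
                listener.onFailure(e);
            }
        );
    }

    static void recordFailure(Throwable t) { /* e.g. record a duration with error attributes */ }
}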
@@ -83,20 +93,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -105,7 +164,7 @@ private void inferOnService(
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
             );
         } else {
             listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -133,20 +192,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
         }
     }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
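All of the `recordMetrics` variants in this change share one defensive pattern: the telemetry call sits inside try/catch and failures are logged only at debug, so a broken metrics pipeline can never mask the real inference response or error. The pattern in isolation, with the class name `SafeTelemetry` chosen here purely for illustration:

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

final class SafeTelemetry {
    private static final Logger log = LogManager.getLogger(SafeTelemetry.class);

    // Run a metrics callback, swallowing any failure so the request path
    // is never affected; mirrors the recordMetrics(...) helpers above.
    static void record(Runnable recordMetric) {
        try {
            recordMetric.run();
        } catch (Exception e) {
            log.atDebug().withThrowable(e).log("Failed to record metrics, dropping metrics");
        }
    }
}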
@@ -160,4 +205,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
         );
     }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }
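`PublisherWithMetrics` leans on `DelegatingProcessor`, whose implementation is outside this diff. For orientation, here is a self-contained sketch of the underlying idea in terms of the JDK's `java.util.concurrent.Flow`: a pass-through processor that runs a callback on the terminal event. This is an assumption-laden simplification; the real `DelegatingProcessor` also mediates backpressure and supplies the `onCancel()` hook used above, which plain `Flow` lacks:

import java.util.concurrent.Flow;

// Illustrative only. Assumes subscribe(...) is called before the upstream
// starts signalling, and omits the cancellation and backpressure handling
// the real DelegatingProcessor provides.
final class MetricsProcessor<T> implements Flow.Processor<T, T> {
    private final Runnable onTerminal;
    private Flow.Subscriber<? super T> downstream;

    MetricsProcessor(Runnable onTerminal) {
        this.onTerminal = onTerminal;
    }

    @Override
    public void subscribe(Flow.Subscriber<? super T> subscriber) {
        this.downstream = subscriber;
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        downstream.onSubscribe(subscription); // backpressure passes straight through
    }

    @Override
    public void onNext(T item) {
        downstream.onNext(item); // items are forwarded untouched
    }

    @Override
    public void onError(Throwable throwable) {
        onTerminal.run(); // e.g. record duration with error attributes
        downstream.onError(throwable);
    }

    @Override
    public void onComplete() {
        onTerminal.run(); // e.g. record duration on success
        downstream.onComplete();
    }
}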