3939import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
4040import org .elasticsearch .xpack .inference .InferencePlugin ;
4141import org .elasticsearch .xpack .inference .action .task .StreamingTaskManager ;
42- import org .elasticsearch .xpack .inference .common .DelegatingProcessor ;
4342import org .elasticsearch .xpack .inference .common .InferenceServiceNodeLocalRateLimitCalculator ;
4443import org .elasticsearch .xpack .inference .common .InferenceServiceRateLimitCalculator ;
4544import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
@@ -297,8 +296,7 @@ private void inferOnServiceWithMetrics(
297296 );
298297 inferenceResults .publisher ().subscribe (taskProcessor );
299298
300- var instrumentedStream = new PublisherWithMetrics (timer , model , request , localNodeId );
301- taskProcessor .subscribe (instrumentedStream );
299+ var instrumentedStream = publisherWithMetrics (timer , model , request , localNodeId , taskProcessor );
302300
303301 var streamErrorHandler = streamErrorHandler (instrumentedStream );
304302
@@ -313,7 +311,52 @@ private void inferOnServiceWithMetrics(
313311 }));
314312 }
315313
316- protected <T > Flow .Publisher <T > streamErrorHandler (Flow .Processor <T , T > upstream ) {
314+ private <T > Flow .Publisher <T > publisherWithMetrics (
315+ InferenceTimer timer ,
316+ Model model ,
317+ Request request ,
318+ String localNodeId ,
319+ Flow .Processor <T , T > upstream
320+ ) {
321+ return downstream -> {
322+ upstream .subscribe (new Flow .Subscriber <>() {
323+ @ Override
324+ public void onSubscribe (Flow .Subscription subscription ) {
325+ downstream .onSubscribe (new Flow .Subscription () {
326+ @ Override
327+ public void request (long n ) {
328+ subscription .request (n );
329+ }
330+
331+ @ Override
332+ public void cancel () {
333+ recordRequestDurationMetrics (model , timer , request , localNodeId , null );
334+ subscription .cancel ();
335+ }
336+ });
337+ }
338+
339+ @ Override
340+ public void onNext (T item ) {
341+ downstream .onNext (item );
342+ }
343+
344+ @ Override
345+ public void onError (Throwable throwable ) {
346+ recordRequestDurationMetrics (model , timer , request , localNodeId , throwable );
347+ downstream .onError (throwable );
348+ }
349+
350+ @ Override
351+ public void onComplete () {
352+ recordRequestDurationMetrics (model , timer , request , localNodeId , null );
353+ downstream .onComplete ();
354+ }
355+ });
356+ };
357+ }
358+
359+ protected <T > Flow .Publisher <T > streamErrorHandler (Flow .Publisher <T > upstream ) {
317360 return upstream ;
318361 }
319362
@@ -386,44 +429,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
386429 );
387430 }
388431
389- private class PublisherWithMetrics extends DelegatingProcessor <InferenceServiceResults .Result , InferenceServiceResults .Result > {
390-
391- private final InferenceTimer timer ;
392- private final Model model ;
393- private final Request request ;
394- private final String localNodeId ;
395-
396- private PublisherWithMetrics (InferenceTimer timer , Model model , Request request , String localNodeId ) {
397- this .timer = timer ;
398- this .model = model ;
399- this .request = request ;
400- this .localNodeId = localNodeId ;
401- }
402-
403- @ Override
404- protected void next (InferenceServiceResults .Result item ) {
405- downstream ().onNext (item );
406- }
407-
408- @ Override
409- public void onError (Throwable throwable ) {
410- recordRequestDurationMetrics (model , timer , request , localNodeId , throwable );
411- super .onError (throwable );
412- }
413-
414- @ Override
415- protected void onCancel () {
416- recordRequestDurationMetrics (model , timer , request , localNodeId , null );
417- super .onCancel ();
418- }
419-
420- @ Override
421- public void onComplete () {
422- recordRequestDurationMetrics (model , timer , request , localNodeId , null );
423- super .onComplete ();
424- }
425- }
426-
427432 private record NodeRoutingDecision (boolean currentNodeShouldHandleRequest , DiscoveryNode targetNode ) {
428433 static NodeRoutingDecision handleLocally () {
429434 return new NodeRoutingDecision (true , null );
0 commit comments