@@ -50,6 +50,8 @@
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.inference.telemetry.InferenceStats;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
@@ -92,6 +94,8 @@
  *
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
+    private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);
+
     private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
 
     /**
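Both hunks below apply the same refactor: each anonymous `ActionListener` is collapsed into `ActionListener.wrap(responseHandler, failureHandler)`, and the failure paths gain an ERROR log (through the new class-level logger) for server-side failures. As a rough sketch of what `wrap` does — with simplified `Consumer`-based signatures and a stand-in interface, not the real `org.elasticsearch.action.ActionListener`, whose factory takes a `CheckedConsumer` and is richer than this:

```java
import java.util.function.Consumer;

// Stand-in listener interface, for illustration only.
interface Listener<T> {
    void onResponse(T response);

    void onFailure(Exception e);

    // Sketch of the wrap(...) factory: two lambdas replace the anonymous
    // subclass, and an exception thrown by the response handler is routed
    // into the failure handler instead of escaping.
    static <T> Listener<T> wrap(Consumer<T> responseHandler, Consumer<Exception> failureHandler) {
        return new Listener<>() {
            @Override
            public void onResponse(T response) {
                try {
                    responseHandler.accept(response);
                } catch (Exception e) {
                    onFailure(e);
                }
            }

            @Override
            public void onFailure(Exception e) {
                failureHandler.accept(e);
            }
        };
    }
}
```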
@@ -325,61 +329,60 @@ private void executeChunkedInferenceAsync(
         final Releasable onFinish
     ) {
         if (inferenceProvider == null) {
-            ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
-                @Override
-                public void onResponse(UnparsedModel unparsedModel) {
-                    var service = inferenceServiceRegistry.getService(unparsedModel.service());
-                    if (service.isEmpty() == false) {
-                        var provider = new InferenceProvider(
-                            service.get(),
-                            service.get()
-                                .parsePersistedConfigWithSecrets(
-                                    inferenceId,
-                                    unparsedModel.taskType(),
-                                    unparsedModel.settings(),
-                                    unparsedModel.secrets()
+            ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
+                var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                if (service.isEmpty() == false) {
+                    var provider = new InferenceProvider(
+                        service.get(),
+                        service.get()
+                            .parsePersistedConfigWithSecrets(
+                                inferenceId,
+                                unparsedModel.taskType(),
+                                unparsedModel.settings(),
+                                unparsedModel.secrets()
+                            )
+                    );
+                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
+                } else {
+                    try (onFinish) {
+                        for (FieldInferenceRequest request : requests) {
+                            inferenceResults.get(request.bulkItemIndex).failures.add(
+                                new ResourceNotFoundException(
+                                    "Inference service [{}] not found for field [{}]",
+                                    unparsedModel.service(),
+                                    request.field
                                 )
-                        );
-                        executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
-                    } else {
-                        try (onFinish) {
-                            for (FieldInferenceRequest request : requests) {
-                                inferenceResults.get(request.bulkItemIndex).failures.add(
-                                    new ResourceNotFoundException(
-                                        "Inference service [{}] not found for field [{}]",
-                                        unparsedModel.service(),
-                                        request.field
-                                    )
-                                );
-                            }
+                            );
                         }
                     }
                 }
-
-                @Override
-                public void onFailure(Exception exc) {
-                    try (onFinish) {
-                        for (FieldInferenceRequest request : requests) {
-                            Exception failure;
-                            if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
-                                failure = new ResourceNotFoundException(
-                                    "Inference id [{}] not found for field [{}]",
-                                    inferenceId,
-                                    request.field
-                                );
-                            } else {
-                                failure = new InferenceException(
-                                    "Error loading inference for inference id [{}] on field [{}]",
-                                    exc,
-                                    inferenceId,
-                                    request.field
-                                );
-                            }
-                            inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+            }, exc -> {
+                try (onFinish) {
+                    for (FieldInferenceRequest request : requests) {
+                        Exception failure;
+                        if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
+                            failure = new ResourceNotFoundException(
+                                "Inference id [{}] not found for field [{}]",
+                                inferenceId,
+                                request.field
+                            );
+                        } else {
+                            failure = new InferenceException(
+                                "Error loading inference for inference id [{}] on field [{}]",
+                                exc,
+                                inferenceId,
+                                request.field
+                            );
                         }
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+                    }
+
+                    if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                        List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                        logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
                     }
                 }
-            };
+            });
             modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
             return;
         }
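The `>= 500` guard above (and its twin in the next hunk) keeps the new ERROR log to server-side failures, so expected client errors such as a missing inference id do not spam the logs, and `distinct()` collapses the field names so each batch logs once. A self-contained sketch of the gate, with hypothetical stand-ins for `RestStatus` and `ExceptionsHelper.status` (the real helper derives the status from the exception chain):

```java
import java.util.List;
import java.util.stream.Stream;

public class ServerErrorLogGate {
    // Stand-in for RestStatus: only the numeric code matters here.
    record Status(int code) {
        int getStatus() {
            return code;
        }
    }

    // Stand-in for ExceptionsHelper.status(exc); hardcoded to 500 for the demo.
    static Status status(Exception exc) {
        return new Status(500);
    }

    public static void main(String[] args) {
        Exception exc = new RuntimeException("inference backend unavailable");
        // Mirrors the diff: de-duplicate field names so a batch produces a
        // single log line rather than one per request.
        List<String> fields = Stream.of("title", "body", "title").distinct().toList();
        if (status(exc).getStatus() >= 500) {
            System.err.println("Error loading inference on fields " + fields);
        }
    }
}
```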
@@ -398,65 +401,70 @@ public void onFailure(Exception exc) {
             .map(r -> new ChunkInferenceInput(new InferenceString(r.input, TEXT), r.chunkingSettings))
             .collect(Collectors.toList());
 
-        ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
-
-            @Override
-            public void onResponse(List<ChunkedInference> results) {
-                try (onFinish) {
-                    var requestsIterator = requests.iterator();
-                    int success = 0;
-                    for (ChunkedInference result : results) {
-                        var request = requestsIterator.next();
-                        var acc = inferenceResults.get(request.bulkItemIndex);
-                        if (result instanceof ChunkedInferenceError error) {
-                            recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
-                            acc.addFailure(
-                                new InferenceException(
-                                    "Exception when running inference id [{}] on field [{}]",
-                                    error.exception(),
-                                    inferenceProvider.model.getInferenceEntityId(),
-                                    request.field
-                                )
-                            );
-                        } else {
-                            success++;
-                            acc.addOrUpdateResponse(
-                                new FieldInferenceResponse(
-                                    request.field(),
-                                    request.sourceField(),
-                                    useLegacyFormat ? request.input() : null,
-                                    request.inputOrder(),
-                                    request.offsetAdjustment(),
-                                    inferenceProvider.model,
-                                    result
-                                )
-                            );
-                        }
-                    }
-                    if (success > 0) {
-                        recordRequestCountMetrics(inferenceProvider.model, success, null);
-                    }
-                }
-            }
-
-            @Override
-            public void onFailure(Exception exc) {
-                try (onFinish) {
-                    recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
-                    for (FieldInferenceRequest request : requests) {
-                        addInferenceResponseFailure(
-                            request.bulkItemIndex,
+        ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
+            try (onFinish) {
+                var requestsIterator = requests.iterator();
+                int success = 0;
+                for (ChunkedInference result : results) {
+                    var request = requestsIterator.next();
+                    var acc = inferenceResults.get(request.bulkItemIndex);
+                    if (result instanceof ChunkedInferenceError error) {
+                        recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
+                        acc.addFailure(
                             new InferenceException(
                                 "Exception when running inference id [{}] on field [{}]",
-                                exc,
+                                error.exception(),
                                 inferenceProvider.model.getInferenceEntityId(),
                                 request.field
                             )
                         );
+                    } else {
+                        success++;
+                        acc.addOrUpdateResponse(
+                            new FieldInferenceResponse(
+                                request.field(),
+                                request.sourceField(),
+                                useLegacyFormat ? request.input() : null,
+                                request.inputOrder(),
+                                request.offsetAdjustment(),
+                                inferenceProvider.model,
+                                result
+                            )
+                        );
                     }
                 }
+                if (success > 0) {
+                    recordRequestCountMetrics(inferenceProvider.model, success, null);
+                }
             }
-        };
+        }, exc -> {
+            try (onFinish) {
+                recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
+                for (FieldInferenceRequest request : requests) {
+                    addInferenceResponseFailure(
+                        request.bulkItemIndex,
+                        new InferenceException(
+                            "Exception when running inference id [{}] on field [{}]",
+                            exc,
+                            inferenceProvider.model.getInferenceEntityId(),
+                            request.field
+                        )
+                    );
+                }
+
+                if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                    List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                    logger.error(
+                        "Exception when running inference id ["
+                            + inferenceProvider.model.getInferenceEntityId()
+                            + "] on fields "
+                            + fields,
+                        exc
+                    );
+                }
+            }
+        });
+
         inferenceProvider.service()
             .chunkedInfer(
                 inferenceProvider.model(),
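Both new failure lambdas keep the `try (onFinish)` construct from the listeners they replace: `onFinish` is a `Releasable` passed into the method, and try-with-resources over an existing effectively-final variable (a Java 9+ form) releases it on every exit path, success or failure. A minimal sketch with a stand-in type:

```java
public class ReleasableSketch {
    // Stand-in for org.elasticsearch.core.Releasable: close() throws no
    // checked exception, so callers need no catch block.
    interface Releasable extends AutoCloseable {
        @Override
        void close();
    }

    public static void main(String[] args) {
        Releasable onFinish = () -> System.out.println("released");
        try (onFinish) { // Java 9+: reuse an effectively-final resource variable
            System.out.println("handle the response or the failure");
        }
    }
}
```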