Commit 40cc956

Mikep86 authored and afoucret committed
Use Wrapped Action Listeners in ShardBulkInferenceActionFilter (elastic#138505)
1 parent 4abbb19 commit 40cc956
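The refactoring pattern: instead of instantiating an anonymous ActionListener<> and overriding onResponse/onFailure, the filter now builds each listener with ActionListener.wrap(onResponse, onFailure), which takes the two handlers as lambdas. A minimal sketch of the before/after shape, assuming the org.elasticsearch.action.ActionListener API; the String payload, class name, and print statements below are illustrative only, not the filter's actual logic:

import org.elasticsearch.action.ActionListener;

// Hypothetical example class; the listener bodies are made up for
// illustration and do not appear in this commit.
class WrappedListenerSketch {

    // Before: an anonymous ActionListener with explicit @Override methods.
    static ActionListener<String> verbose() {
        return new ActionListener<>() {
            @Override
            public void onResponse(String response) {
                System.out.println("got " + response);
            }

            @Override
            public void onFailure(Exception e) {
                System.err.println("failed: " + e);
            }
        };
    }

    // After: ActionListener.wrap(onResponse, onFailure) builds an
    // equivalent listener from two lambdas.
    static ActionListener<String> wrapped() {
        return ActionListener.wrap(
            response -> System.out.println("got " + response),
            e -> System.err.println("failed: " + e)
        );
    }
}

One behavioral difference worth knowing: a listener built with wrap routes an exception thrown inside its response handler to the failure handler, whereas an anonymous listener lets it propagate to whoever invoked onResponse.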

File tree

1 file changed: +106 -98 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 106 additions & 98 deletions
@@ -50,6 +50,8 @@
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.inference.telemetry.InferenceStats;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
@@ -92,6 +94,8 @@
  *
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
+    private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);
+
     private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);

     /**
@@ -325,61 +329,60 @@ private void executeChunkedInferenceAsync(
         final Releasable onFinish
     ) {
         if (inferenceProvider == null) {
-            ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
-                @Override
-                public void onResponse(UnparsedModel unparsedModel) {
-                    var service = inferenceServiceRegistry.getService(unparsedModel.service());
-                    if (service.isEmpty() == false) {
-                        var provider = new InferenceProvider(
-                            service.get(),
-                            service.get()
-                                .parsePersistedConfigWithSecrets(
-                                    inferenceId,
-                                    unparsedModel.taskType(),
-                                    unparsedModel.settings(),
-                                    unparsedModel.secrets()
+            ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
+                var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                if (service.isEmpty() == false) {
+                    var provider = new InferenceProvider(
+                        service.get(),
+                        service.get()
+                            .parsePersistedConfigWithSecrets(
+                                inferenceId,
+                                unparsedModel.taskType(),
+                                unparsedModel.settings(),
+                                unparsedModel.secrets()
+                            )
+                    );
+                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
+                } else {
+                    try (onFinish) {
+                        for (FieldInferenceRequest request : requests) {
+                            inferenceResults.get(request.bulkItemIndex).failures.add(
+                                new ResourceNotFoundException(
+                                    "Inference service [{}] not found for field [{}]",
+                                    unparsedModel.service(),
+                                    request.field
                                 )
-                        );
-                        executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
-                    } else {
-                        try (onFinish) {
-                            for (FieldInferenceRequest request : requests) {
-                                inferenceResults.get(request.bulkItemIndex).failures.add(
-                                    new ResourceNotFoundException(
-                                        "Inference service [{}] not found for field [{}]",
-                                        unparsedModel.service(),
-                                        request.field
-                                    )
-                                );
-                            }
+                            );
                         }
                     }
                 }
-
-                @Override
-                public void onFailure(Exception exc) {
-                    try (onFinish) {
-                        for (FieldInferenceRequest request : requests) {
-                            Exception failure;
-                            if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
-                                failure = new ResourceNotFoundException(
-                                    "Inference id [{}] not found for field [{}]",
-                                    inferenceId,
-                                    request.field
-                                );
-                            } else {
-                                failure = new InferenceException(
-                                    "Error loading inference for inference id [{}] on field [{}]",
-                                    exc,
-                                    inferenceId,
-                                    request.field
-                                );
-                            }
-                            inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+            }, exc -> {
+                try (onFinish) {
+                    for (FieldInferenceRequest request : requests) {
+                        Exception failure;
+                        if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
+                            failure = new ResourceNotFoundException(
+                                "Inference id [{}] not found for field [{}]",
+                                inferenceId,
+                                request.field
+                            );
+                        } else {
+                            failure = new InferenceException(
+                                "Error loading inference for inference id [{}] on field [{}]",
+                                exc,
+                                inferenceId,
+                                request.field
+                            );
                         }
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+                    }
+
+                    if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                        List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                        logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
                     }
                 }
-            };
+            });
             modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
             return;
         }
@@ -398,65 +401,70 @@ public void onFailure(Exception exc) {
             .map(r -> new ChunkInferenceInput(new InferenceString(r.input, TEXT), r.chunkingSettings))
             .collect(Collectors.toList());

-        ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
-
-            @Override
-            public void onResponse(List<ChunkedInference> results) {
-                try (onFinish) {
-                    var requestsIterator = requests.iterator();
-                    int success = 0;
-                    for (ChunkedInference result : results) {
-                        var request = requestsIterator.next();
-                        var acc = inferenceResults.get(request.bulkItemIndex);
-                        if (result instanceof ChunkedInferenceError error) {
-                            recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
-                            acc.addFailure(
-                                new InferenceException(
-                                    "Exception when running inference id [{}] on field [{}]",
-                                    error.exception(),
-                                    inferenceProvider.model.getInferenceEntityId(),
-                                    request.field
-                                )
-                            );
-                        } else {
-                            success++;
-                            acc.addOrUpdateResponse(
-                                new FieldInferenceResponse(
-                                    request.field(),
-                                    request.sourceField(),
-                                    useLegacyFormat ? request.input() : null,
-                                    request.inputOrder(),
-                                    request.offsetAdjustment(),
-                                    inferenceProvider.model,
-                                    result
-                                )
-                            );
-                        }
-                    }
-                    if (success > 0) {
-                        recordRequestCountMetrics(inferenceProvider.model, success, null);
-                    }
-                }
-            }
-
-            @Override
-            public void onFailure(Exception exc) {
-                try (onFinish) {
-                    recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
-                    for (FieldInferenceRequest request : requests) {
-                        addInferenceResponseFailure(
-                            request.bulkItemIndex,
+        ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
+            try (onFinish) {
+                var requestsIterator = requests.iterator();
+                int success = 0;
+                for (ChunkedInference result : results) {
+                    var request = requestsIterator.next();
+                    var acc = inferenceResults.get(request.bulkItemIndex);
+                    if (result instanceof ChunkedInferenceError error) {
+                        recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
+                        acc.addFailure(
                             new InferenceException(
                                 "Exception when running inference id [{}] on field [{}]",
-                                exc,
+                                error.exception(),
                                 inferenceProvider.model.getInferenceEntityId(),
                                 request.field
                             )
                         );
+                    } else {
+                        success++;
+                        acc.addOrUpdateResponse(
+                            new FieldInferenceResponse(
+                                request.field(),
+                                request.sourceField(),
+                                useLegacyFormat ? request.input() : null,
+                                request.inputOrder(),
+                                request.offsetAdjustment(),
+                                inferenceProvider.model,
+                                result
+                            )
+                        );
                     }
                 }
+                if (success > 0) {
+                    recordRequestCountMetrics(inferenceProvider.model, success, null);
+                }
             }
-        };
+        }, exc -> {
+            try (onFinish) {
+                recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
+                for (FieldInferenceRequest request : requests) {
+                    addInferenceResponseFailure(
+                        request.bulkItemIndex,
+                        new InferenceException(
+                            "Exception when running inference id [{}] on field [{}]",
+                            exc,
+                            inferenceProvider.model.getInferenceEntityId(),
+                            request.field
+                        )
+                    );
+                }
+
+                if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                    List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                    logger.error(
+                        "Exception when running inference id ["
+                            + inferenceProvider.model.getInferenceEntityId()
+                            + "] on fields "
+                            + fields,
+                        exc
+                    );
+                }
+            }
+        });

         inferenceProvider.service()
             .chunkedInfer(
                 inferenceProvider.model(),
