Commit ff687d0
Use Wrapped Action Listeners in ShardBulkInferenceActionFilter (elastic#138505) (elastic#138532)
1 parent 6beebea commit ff687d0
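
For orientation, this change swaps hand-rolled anonymous `ActionListener<>` classes for the `ActionListener.wrap(onResponse, onFailure)` factory from `org.elasticsearch.action.ActionListener`. A minimal sketch of the before/after shape of that refactor (the `String` payload and print handlers are illustrative placeholders, not code from this commit):

import org.elasticsearch.action.ActionListener;

class WrappedListenerSketch {
    // Before: an anonymous class with two overridden methods.
    static ActionListener<String> anonymous() {
        return new ActionListener<>() {
            @Override
            public void onResponse(String response) {
                System.out.println("response: " + response);
            }

            @Override
            public void onFailure(Exception e) {
                System.err.println("failure: " + e);
            }
        };
    }

    // After: ActionListener.wrap builds the equivalent listener from two lambdas.
    // If the response lambda throws, wrap delivers the exception to the failure
    // lambda instead of letting it escape to the caller.
    static ActionListener<String> wrapped() {
        return ActionListener.wrap(
            response -> System.out.println("response: " + response),
            e -> System.err.println("failure: " + e)
        );
    }
}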

File tree: 1 file changed, +105 -96 lines

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 105 additions & 96 deletions
@@ -48,6 +48,8 @@
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
@@ -92,6 +94,8 @@
  *
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
+    private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);
+
     private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
 
     /**
@@ -325,126 +329,131 @@ private void executeChunkedInferenceAsync(
         final Releasable onFinish
     ) {
         if (inferenceProvider == null) {
-            ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
-                @Override
-                public void onResponse(UnparsedModel unparsedModel) {
-                    var service = inferenceServiceRegistry.getService(unparsedModel.service());
-                    if (service.isEmpty() == false) {
-                        var provider = new InferenceProvider(
-                            service.get(),
-                            service.get()
-                                .parsePersistedConfigWithSecrets(
-                                    inferenceId,
-                                    unparsedModel.taskType(),
-                                    unparsedModel.settings(),
-                                    unparsedModel.secrets()
-                                )
-                        );
-                        executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
-                    } else {
-                        try (onFinish) {
-                            for (FieldInferenceRequest request : requests) {
-                                inferenceResults.get(request.bulkItemIndex).failures.add(
-                                    new ResourceNotFoundException(
-                                        "Inference service [{}] not found for field [{}]",
-                                        unparsedModel.service(),
-                                        request.field
-                                    )
-                                );
-                            }
-                        }
-                    }
-                }
-
-                @Override
-                public void onFailure(Exception exc) {
+            ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
+                var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                if (service.isEmpty() == false) {
+                    var provider = new InferenceProvider(
+                        service.get(),
+                        service.get()
+                            .parsePersistedConfigWithSecrets(
+                                inferenceId,
+                                unparsedModel.taskType(),
+                                unparsedModel.settings(),
+                                unparsedModel.secrets()
+                            )
+                    );
+                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
+                } else {
                     try (onFinish) {
                         for (FieldInferenceRequest request : requests) {
-                            Exception failure;
-                            if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
-                                failure = new ResourceNotFoundException(
-                                    "Inference id [{}] not found for field [{}]",
-                                    inferenceId,
+                            inferenceResults.get(request.bulkItemIndex).failures.add(
+                                new ResourceNotFoundException(
+                                    "Inference service [{}] not found for field [{}]",
+                                    unparsedModel.service(),
                                     request.field
-                                );
-                            } else {
-                                failure = new InferenceException(
-                                    "Error loading inference for inference id [{}] on field [{}]",
-                                    exc,
-                                    inferenceId,
-                                    request.field
-                                );
-                            }
-                            inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+                                )
+                            );
                         }
                     }
                 }
-            };
-            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
-            return;
-        }
-        final List<ChunkInferenceInput> inputs = requests.stream()
-            .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
-            .collect(Collectors.toList());
-
-        ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
-            @Override
-            public void onResponse(List<ChunkedInference> results) {
+            }, exc -> {
                 try (onFinish) {
-                    var requestsIterator = requests.iterator();
-                    int success = 0;
-                    for (ChunkedInference result : results) {
-                        var request = requestsIterator.next();
-                        var acc = inferenceResults.get(request.bulkItemIndex);
-                        if (result instanceof ChunkedInferenceError error) {
-                            recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
-                            acc.addFailure(
-                                new InferenceException(
-                                    "Exception when running inference id [{}] on field [{}]",
-                                    error.exception(),
-                                    inferenceProvider.model.getInferenceEntityId(),
-                                    request.field
-                                )
+                    for (FieldInferenceRequest request : requests) {
+                        Exception failure;
+                        if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
+                            failure = new ResourceNotFoundException(
+                                "Inference id [{}] not found for field [{}]",
+                                inferenceId,
+                                request.field
                             );
                         } else {
-                            success++;
-                            acc.addOrUpdateResponse(
-                                new FieldInferenceResponse(
-                                    request.field(),
-                                    request.sourceField(),
-                                    useLegacyFormat ? request.input() : null,
-                                    request.inputOrder(),
-                                    request.offsetAdjustment(),
-                                    inferenceProvider.model,
-                                    result
-                                )
+                            failure = new InferenceException(
+                                "Error loading inference for inference id [{}] on field [{}]",
+                                exc,
+                                inferenceId,
+                                request.field
                             );
                         }
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                     }
-                    if (success > 0) {
-                        recordRequestCountMetrics(inferenceProvider.model, success, null);
+
+                    if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                        List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                        logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
                     }
                 }
-            }
+            });
+            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
+            return;
+        }
+        final List<ChunkInferenceInput> inputs = requests.stream()
+            .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
+            .collect(Collectors.toList());
 
-            @Override
-            public void onFailure(Exception exc) {
-                try (onFinish) {
-                    recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
-                    for (FieldInferenceRequest request : requests) {
-                        addInferenceResponseFailure(
-                            request.bulkItemIndex,
+        ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
+            try (onFinish) {
+                var requestsIterator = requests.iterator();
+                int success = 0;
+                for (ChunkedInference result : results) {
+                    var request = requestsIterator.next();
+                    var acc = inferenceResults.get(request.bulkItemIndex);
+                    if (result instanceof ChunkedInferenceError error) {
+                        recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
+                        acc.addFailure(
                             new InferenceException(
                                 "Exception when running inference id [{}] on field [{}]",
-                                exc,
+                                error.exception(),
                                 inferenceProvider.model.getInferenceEntityId(),
                                 request.field
                             )
                         );
+                    } else {
+                        success++;
+                        acc.addOrUpdateResponse(
+                            new FieldInferenceResponse(
+                                request.field(),
+                                request.sourceField(),
+                                useLegacyFormat ? request.input() : null,
+                                request.inputOrder(),
+                                request.offsetAdjustment(),
+                                inferenceProvider.model,
+                                result
+                            )
+                        );
                     }
                 }
+                if (success > 0) {
+                    recordRequestCountMetrics(inferenceProvider.model, success, null);
+                }
             }
-        };
+        }, exc -> {
+            try (onFinish) {
+                recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
+                for (FieldInferenceRequest request : requests) {
+                    addInferenceResponseFailure(
+                        request.bulkItemIndex,
+                        new InferenceException(
+                            "Exception when running inference id [{}] on field [{}]",
+                            exc,
+                            inferenceProvider.model.getInferenceEntityId(),
+                            request.field
+                        )
+                    );
+                }
+
+                if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                    List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                    logger.error(
+                        "Exception when running inference id ["
+                            + inferenceProvider.model.getInferenceEntityId()
+                            + "] on fields "
+                            + fields,
+                        exc
+                    );
+                }
+            }
+        });
+
         inferenceProvider.service()
             .chunkedInfer(
                 inferenceProvider.model(),
