Skip to content

Commit 5d94f26

Browse files
authored
Use Wrapped Action Listeners in ShardBulkInferenceActionFilter (#138505) (#138528)
1 parent e3a09d1 commit 5d94f26

File tree

1 file changed

+105
-97
lines changed

1 file changed

+105
-97
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 105 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import org.elasticsearch.inference.telemetry.InferenceStats;
5151
import org.elasticsearch.license.LicenseUtils;
5252
import org.elasticsearch.license.XPackLicenseState;
53+
import org.elasticsearch.logging.LogManager;
54+
import org.elasticsearch.logging.Logger;
5355
import org.elasticsearch.rest.RestStatus;
5456
import org.elasticsearch.tasks.Task;
5557
import org.elasticsearch.xcontent.XContent;
@@ -92,6 +94,8 @@
9294
*
9395
*/
9496
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
97+
private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);
98+
9599
private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
96100

97101
/**
@@ -325,127 +329,131 @@ private void executeChunkedInferenceAsync(
325329
final Releasable onFinish
326330
) {
327331
if (inferenceProvider == null) {
328-
ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
329-
@Override
330-
public void onResponse(UnparsedModel unparsedModel) {
331-
var service = inferenceServiceRegistry.getService(unparsedModel.service());
332-
if (service.isEmpty() == false) {
333-
var provider = new InferenceProvider(
334-
service.get(),
335-
service.get()
336-
.parsePersistedConfigWithSecrets(
337-
inferenceId,
338-
unparsedModel.taskType(),
339-
unparsedModel.settings(),
340-
unparsedModel.secrets()
341-
)
342-
);
343-
executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
344-
} else {
345-
try (onFinish) {
346-
for (FieldInferenceRequest request : requests) {
347-
inferenceResults.get(request.bulkItemIndex).failures.add(
348-
new ResourceNotFoundException(
349-
"Inference service [{}] not found for field [{}]",
350-
unparsedModel.service(),
351-
request.field
352-
)
353-
);
354-
}
355-
}
356-
}
357-
}
358-
359-
@Override
360-
public void onFailure(Exception exc) {
332+
ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
333+
var service = inferenceServiceRegistry.getService(unparsedModel.service());
334+
if (service.isEmpty() == false) {
335+
var provider = new InferenceProvider(
336+
service.get(),
337+
service.get()
338+
.parsePersistedConfigWithSecrets(
339+
inferenceId,
340+
unparsedModel.taskType(),
341+
unparsedModel.settings(),
342+
unparsedModel.secrets()
343+
)
344+
);
345+
executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
346+
} else {
361347
try (onFinish) {
362348
for (FieldInferenceRequest request : requests) {
363-
Exception failure;
364-
if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
365-
failure = new ResourceNotFoundException(
366-
"Inference id [{}] not found for field [{}]",
367-
inferenceId,
349+
inferenceResults.get(request.bulkItemIndex).failures.add(
350+
new ResourceNotFoundException(
351+
"Inference service [{}] not found for field [{}]",
352+
unparsedModel.service(),
368353
request.field
369-
);
370-
} else {
371-
failure = new InferenceException(
372-
"Error loading inference for inference id [{}] on field [{}]",
373-
exc,
374-
inferenceId,
375-
request.field
376-
);
377-
}
378-
inferenceResults.get(request.bulkItemIndex).failures.add(failure);
354+
)
355+
);
379356
}
380357
}
381358
}
382-
};
383-
modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
384-
return;
385-
}
386-
final List<ChunkInferenceInput> inputs = requests.stream()
387-
.map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
388-
.collect(Collectors.toList());
389-
390-
ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
391-
392-
@Override
393-
public void onResponse(List<ChunkedInference> results) {
359+
}, exc -> {
394360
try (onFinish) {
395-
var requestsIterator = requests.iterator();
396-
int success = 0;
397-
for (ChunkedInference result : results) {
398-
var request = requestsIterator.next();
399-
var acc = inferenceResults.get(request.bulkItemIndex);
400-
if (result instanceof ChunkedInferenceError error) {
401-
recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
402-
acc.addFailure(
403-
new InferenceException(
404-
"Exception when running inference id [{}] on field [{}]",
405-
error.exception(),
406-
inferenceProvider.model.getInferenceEntityId(),
407-
request.field
408-
)
361+
for (FieldInferenceRequest request : requests) {
362+
Exception failure;
363+
if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
364+
failure = new ResourceNotFoundException(
365+
"Inference id [{}] not found for field [{}]",
366+
inferenceId,
367+
request.field
409368
);
410369
} else {
411-
success++;
412-
acc.addOrUpdateResponse(
413-
new FieldInferenceResponse(
414-
request.field(),
415-
request.sourceField(),
416-
useLegacyFormat ? request.input() : null,
417-
request.inputOrder(),
418-
request.offsetAdjustment(),
419-
inferenceProvider.model,
420-
result
421-
)
370+
failure = new InferenceException(
371+
"Error loading inference for inference id [{}] on field [{}]",
372+
exc,
373+
inferenceId,
374+
request.field
422375
);
423376
}
377+
inferenceResults.get(request.bulkItemIndex).failures.add(failure);
424378
}
425-
if (success > 0) {
426-
recordRequestCountMetrics(inferenceProvider.model, success, null);
379+
380+
if (ExceptionsHelper.status(exc).getStatus() >= 500) {
381+
List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
382+
logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
427383
}
428384
}
429-
}
385+
});
386+
modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
387+
return;
388+
}
389+
final List<ChunkInferenceInput> inputs = requests.stream()
390+
.map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
391+
.collect(Collectors.toList());
430392

431-
@Override
432-
public void onFailure(Exception exc) {
433-
try (onFinish) {
434-
recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
435-
for (FieldInferenceRequest request : requests) {
436-
addInferenceResponseFailure(
437-
request.bulkItemIndex,
393+
ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
394+
try (onFinish) {
395+
var requestsIterator = requests.iterator();
396+
int success = 0;
397+
for (ChunkedInference result : results) {
398+
var request = requestsIterator.next();
399+
var acc = inferenceResults.get(request.bulkItemIndex);
400+
if (result instanceof ChunkedInferenceError error) {
401+
recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
402+
acc.addFailure(
438403
new InferenceException(
439404
"Exception when running inference id [{}] on field [{}]",
440-
exc,
405+
error.exception(),
441406
inferenceProvider.model.getInferenceEntityId(),
442407
request.field
443408
)
444409
);
410+
} else {
411+
success++;
412+
acc.addOrUpdateResponse(
413+
new FieldInferenceResponse(
414+
request.field(),
415+
request.sourceField(),
416+
useLegacyFormat ? request.input() : null,
417+
request.inputOrder(),
418+
request.offsetAdjustment(),
419+
inferenceProvider.model,
420+
result
421+
)
422+
);
445423
}
446424
}
425+
if (success > 0) {
426+
recordRequestCountMetrics(inferenceProvider.model, success, null);
427+
}
447428
}
448-
};
429+
}, exc -> {
430+
try (onFinish) {
431+
recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
432+
for (FieldInferenceRequest request : requests) {
433+
addInferenceResponseFailure(
434+
request.bulkItemIndex,
435+
new InferenceException(
436+
"Exception when running inference id [{}] on field [{}]",
437+
exc,
438+
inferenceProvider.model.getInferenceEntityId(),
439+
request.field
440+
)
441+
);
442+
}
443+
444+
if (ExceptionsHelper.status(exc).getStatus() >= 500) {
445+
List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
446+
logger.error(
447+
"Exception when running inference id ["
448+
+ inferenceProvider.model.getInferenceEntityId()
449+
+ "] on fields "
450+
+ fields,
451+
exc
452+
);
453+
}
454+
}
455+
});
456+
449457
inferenceProvider.service()
450458
.chunkedInfer(
451459
inferenceProvider.model(),

0 commit comments

Comments
 (0)