|
50 | 50 | import org.elasticsearch.inference.telemetry.InferenceStats; |
51 | 51 | import org.elasticsearch.license.LicenseUtils; |
52 | 52 | import org.elasticsearch.license.XPackLicenseState; |
| 53 | +import org.elasticsearch.logging.LogManager; |
| 54 | +import org.elasticsearch.logging.Logger; |
53 | 55 | import org.elasticsearch.rest.RestStatus; |
54 | 56 | import org.elasticsearch.tasks.Task; |
55 | 57 | import org.elasticsearch.xcontent.XContent; |
|
92 | 94 | * |
93 | 95 | */ |
94 | 96 | public class ShardBulkInferenceActionFilter implements MappedActionFilter { |
| 97 | + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); |
| 98 | + |
95 | 99 | private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1); |
96 | 100 |
|
97 | 101 | /** |
@@ -325,127 +329,131 @@ private void executeChunkedInferenceAsync( |
325 | 329 | final Releasable onFinish |
326 | 330 | ) { |
327 | 331 | if (inferenceProvider == null) { |
328 | | - ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() { |
329 | | - @Override |
330 | | - public void onResponse(UnparsedModel unparsedModel) { |
331 | | - var service = inferenceServiceRegistry.getService(unparsedModel.service()); |
332 | | - if (service.isEmpty() == false) { |
333 | | - var provider = new InferenceProvider( |
334 | | - service.get(), |
335 | | - service.get() |
336 | | - .parsePersistedConfigWithSecrets( |
337 | | - inferenceId, |
338 | | - unparsedModel.taskType(), |
339 | | - unparsedModel.settings(), |
340 | | - unparsedModel.secrets() |
341 | | - ) |
342 | | - ); |
343 | | - executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); |
344 | | - } else { |
345 | | - try (onFinish) { |
346 | | - for (FieldInferenceRequest request : requests) { |
347 | | - inferenceResults.get(request.bulkItemIndex).failures.add( |
348 | | - new ResourceNotFoundException( |
349 | | - "Inference service [{}] not found for field [{}]", |
350 | | - unparsedModel.service(), |
351 | | - request.field |
352 | | - ) |
353 | | - ); |
354 | | - } |
355 | | - } |
356 | | - } |
357 | | - } |
358 | | - |
359 | | - @Override |
360 | | - public void onFailure(Exception exc) { |
| 332 | + ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> { |
| 333 | + var service = inferenceServiceRegistry.getService(unparsedModel.service()); |
| 334 | + if (service.isEmpty() == false) { |
| 335 | + var provider = new InferenceProvider( |
| 336 | + service.get(), |
| 337 | + service.get() |
| 338 | + .parsePersistedConfigWithSecrets( |
| 339 | + inferenceId, |
| 340 | + unparsedModel.taskType(), |
| 341 | + unparsedModel.settings(), |
| 342 | + unparsedModel.secrets() |
| 343 | + ) |
| 344 | + ); |
| 345 | + executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); |
| 346 | + } else { |
361 | 347 | try (onFinish) { |
362 | 348 | for (FieldInferenceRequest request : requests) { |
363 | | - Exception failure; |
364 | | - if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { |
365 | | - failure = new ResourceNotFoundException( |
366 | | - "Inference id [{}] not found for field [{}]", |
367 | | - inferenceId, |
| 349 | + inferenceResults.get(request.bulkItemIndex).failures.add( |
| 350 | + new ResourceNotFoundException( |
| 351 | + "Inference service [{}] not found for field [{}]", |
| 352 | + unparsedModel.service(), |
368 | 353 | request.field |
369 | | - ); |
370 | | - } else { |
371 | | - failure = new InferenceException( |
372 | | - "Error loading inference for inference id [{}] on field [{}]", |
373 | | - exc, |
374 | | - inferenceId, |
375 | | - request.field |
376 | | - ); |
377 | | - } |
378 | | - inferenceResults.get(request.bulkItemIndex).failures.add(failure); |
| 354 | + ) |
| 355 | + ); |
379 | 356 | } |
380 | 357 | } |
381 | 358 | } |
382 | | - }; |
383 | | - modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); |
384 | | - return; |
385 | | - } |
386 | | - final List<ChunkInferenceInput> inputs = requests.stream() |
387 | | - .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) |
388 | | - .collect(Collectors.toList()); |
389 | | - |
390 | | - ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() { |
391 | | - |
392 | | - @Override |
393 | | - public void onResponse(List<ChunkedInference> results) { |
| 359 | + }, exc -> { |
394 | 360 | try (onFinish) { |
395 | | - var requestsIterator = requests.iterator(); |
396 | | - int success = 0; |
397 | | - for (ChunkedInference result : results) { |
398 | | - var request = requestsIterator.next(); |
399 | | - var acc = inferenceResults.get(request.bulkItemIndex); |
400 | | - if (result instanceof ChunkedInferenceError error) { |
401 | | - recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); |
402 | | - acc.addFailure( |
403 | | - new InferenceException( |
404 | | - "Exception when running inference id [{}] on field [{}]", |
405 | | - error.exception(), |
406 | | - inferenceProvider.model.getInferenceEntityId(), |
407 | | - request.field |
408 | | - ) |
| 361 | + for (FieldInferenceRequest request : requests) { |
| 362 | + Exception failure; |
| 363 | + if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { |
| 364 | + failure = new ResourceNotFoundException( |
| 365 | + "Inference id [{}] not found for field [{}]", |
| 366 | + inferenceId, |
| 367 | + request.field |
409 | 368 | ); |
410 | 369 | } else { |
411 | | - success++; |
412 | | - acc.addOrUpdateResponse( |
413 | | - new FieldInferenceResponse( |
414 | | - request.field(), |
415 | | - request.sourceField(), |
416 | | - useLegacyFormat ? request.input() : null, |
417 | | - request.inputOrder(), |
418 | | - request.offsetAdjustment(), |
419 | | - inferenceProvider.model, |
420 | | - result |
421 | | - ) |
| 370 | + failure = new InferenceException( |
| 371 | + "Error loading inference for inference id [{}] on field [{}]", |
| 372 | + exc, |
| 373 | + inferenceId, |
| 374 | + request.field |
422 | 375 | ); |
423 | 376 | } |
| 377 | + inferenceResults.get(request.bulkItemIndex).failures.add(failure); |
424 | 378 | } |
425 | | - if (success > 0) { |
426 | | - recordRequestCountMetrics(inferenceProvider.model, success, null); |
| 379 | + |
| 380 | + if (ExceptionsHelper.status(exc).getStatus() >= 500) { |
| 381 | + List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); |
| 382 | + logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc); |
427 | 383 | } |
428 | 384 | } |
429 | | - } |
| 385 | + }); |
| 386 | + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); |
| 387 | + return; |
| 388 | + } |
| 389 | + final List<ChunkInferenceInput> inputs = requests.stream() |
| 390 | + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) |
| 391 | + .collect(Collectors.toList()); |
430 | 392 |
|
431 | | - @Override |
432 | | - public void onFailure(Exception exc) { |
433 | | - try (onFinish) { |
434 | | - recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); |
435 | | - for (FieldInferenceRequest request : requests) { |
436 | | - addInferenceResponseFailure( |
437 | | - request.bulkItemIndex, |
| 393 | + ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> { |
| 394 | + try (onFinish) { |
| 395 | + var requestsIterator = requests.iterator(); |
| 396 | + int success = 0; |
| 397 | + for (ChunkedInference result : results) { |
| 398 | + var request = requestsIterator.next(); |
| 399 | + var acc = inferenceResults.get(request.bulkItemIndex); |
| 400 | + if (result instanceof ChunkedInferenceError error) { |
| 401 | + recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); |
| 402 | + acc.addFailure( |
438 | 403 | new InferenceException( |
439 | 404 | "Exception when running inference id [{}] on field [{}]", |
440 | | - exc, |
| 405 | + error.exception(), |
441 | 406 | inferenceProvider.model.getInferenceEntityId(), |
442 | 407 | request.field |
443 | 408 | ) |
444 | 409 | ); |
| 410 | + } else { |
| 411 | + success++; |
| 412 | + acc.addOrUpdateResponse( |
| 413 | + new FieldInferenceResponse( |
| 414 | + request.field(), |
| 415 | + request.sourceField(), |
| 416 | + useLegacyFormat ? request.input() : null, |
| 417 | + request.inputOrder(), |
| 418 | + request.offsetAdjustment(), |
| 419 | + inferenceProvider.model, |
| 420 | + result |
| 421 | + ) |
| 422 | + ); |
445 | 423 | } |
446 | 424 | } |
| 425 | + if (success > 0) { |
| 426 | + recordRequestCountMetrics(inferenceProvider.model, success, null); |
| 427 | + } |
447 | 428 | } |
448 | | - }; |
| 429 | + }, exc -> { |
| 430 | + try (onFinish) { |
| 431 | + recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); |
| 432 | + for (FieldInferenceRequest request : requests) { |
| 433 | + addInferenceResponseFailure( |
| 434 | + request.bulkItemIndex, |
| 435 | + new InferenceException( |
| 436 | + "Exception when running inference id [{}] on field [{}]", |
| 437 | + exc, |
| 438 | + inferenceProvider.model.getInferenceEntityId(), |
| 439 | + request.field |
| 440 | + ) |
| 441 | + ); |
| 442 | + } |
| 443 | + |
| 444 | + if (ExceptionsHelper.status(exc).getStatus() >= 500) { |
| 445 | + List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); |
| 446 | + logger.error( |
| 447 | + "Exception when running inference id [" |
| 448 | + + inferenceProvider.model.getInferenceEntityId() |
| 449 | + + "] on fields " |
| 450 | + + fields, |
| 451 | + exc |
| 452 | + ); |
| 453 | + } |
| 454 | + } |
| 455 | + }); |
| 456 | + |
449 | 457 | inferenceProvider.service() |
450 | 458 | .chunkedInfer( |
451 | 459 | inferenceProvider.model(), |
|
0 commit comments