|
58 | 58 | import java.util.LinkedHashMap; |
59 | 59 | import java.util.List; |
60 | 60 | import java.util.Map; |
61 | | -import java.util.Optional; |
62 | 61 | import java.util.stream.Collectors; |
63 | 62 |
|
64 | 63 | /** |
@@ -307,24 +306,26 @@ public void onFailure(Exception exc) { |
307 | 306 | return; |
308 | 307 | } |
309 | 308 | int currentBatchSize = Math.min(requests.size(), batchSize); |
310 | | - final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize); |
| 309 | + ChunkingSettings chunkingSettings = requests.get(0).chunkingSettings; |
| 310 | + List<FieldInferenceRequest> currentBatch = new ArrayList<>(); |
| 311 | + List<FieldInferenceRequest> others = new ArrayList<>(); |
| 312 | + for (int i = 0; i < currentBatchSize; i++) { |
| 313 | + FieldInferenceRequest request = requests.get(i); |
| 314 | + if ((chunkingSettings == null && request.chunkingSettings == null) || (request.chunkingSettings != null && request.chunkingSettings.equals(chunkingSettings))) { |
| 315 | + currentBatch.add(request); |
| 316 | + } else { |
| 317 | + others.add(request); |
| 318 | + } |
| 319 | + } |
| 320 | + |
311 | 321 | final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size()); |
312 | | - // |
313 | | - // List<ChunkedInputs> chunkedInputs = currentBatch.stream() |
314 | | - // .map(request -> new ChunkedInputs(request.chunkingSettings(), List.of(request.input()))) |
315 | | - // .toList(); |
316 | | - |
317 | | - List<ChunkedInputs> chunkedInputs = currentBatch.stream() |
318 | | - .collect(Collectors.groupingBy(request -> Optional.ofNullable(request.chunkingSettings()))) |
319 | | - .entrySet() |
320 | | - .stream() |
321 | | - .map( |
322 | | - entry -> new ChunkedInputs( |
323 | | - entry.getKey().orElse(null), |
324 | | - entry.getValue().stream().map(FieldInferenceRequest::input).collect(Collectors.toList()) |
325 | | - ) |
326 | | - ) |
327 | | - .toList(); |
| 322 | + nextBatch.addAll(others); |
| 323 | + |
| 324 | + // We can assume current batch has all the same chunking settings |
| 325 | + ChunkedInputs chunkedInputs = new ChunkedInputs( |
| 326 | + chunkingSettings, |
| 327 | + currentBatch.stream().map(r -> r.input).collect(Collectors.toList()) |
| 328 | + ); |
328 | 329 |
|
329 | 330 | ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() { |
330 | 331 | @Override |
@@ -390,19 +391,17 @@ private void onFinish() { |
390 | 391 | } |
391 | 392 | }; |
392 | 393 |
|
393 | | - for (ChunkedInputs chunkedInput : chunkedInputs) { |
394 | | - inferenceProvider.service() |
395 | | - .chunkedInfer( |
396 | | - inferenceProvider.model(), |
397 | | - null, |
398 | | - chunkedInput.inputs(), |
399 | | - Map.of(), |
400 | | - chunkedInput.chunkingSettings(), |
401 | | - InputType.INGEST, |
402 | | - TimeValue.MAX_VALUE, |
403 | | - completionListener |
404 | | - ); |
405 | | - } |
| 394 | + inferenceProvider.service() |
| 395 | + .chunkedInfer( |
| 396 | + inferenceProvider.model(), |
| 397 | + null, |
| 398 | + chunkedInputs.inputs(), |
| 399 | + Map.of(), |
| 400 | + chunkedInputs.chunkingSettings(), |
| 401 | + InputType.INGEST, |
| 402 | + TimeValue.MAX_VALUE, |
| 403 | + completionListener |
| 404 | + ); |
406 | 405 | } |
407 | 406 |
|
408 | 407 | private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { |
|
0 commit comments