Skip to content

Commit c9bfa32

Browse files
committed
A bit of cleanup
1 parent c2c9a52 commit c9bfa32

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
import java.util.LinkedHashMap;
5959
import java.util.List;
6060
import java.util.Map;
61-
import java.util.Optional;
6261
import java.util.stream.Collectors;
6362

6463
/**
@@ -307,24 +306,26 @@ public void onFailure(Exception exc) {
307306
return;
308307
}
309308
int currentBatchSize = Math.min(requests.size(), batchSize);
310-
final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
309+
ChunkingSettings chunkingSettings = requests.get(0).chunkingSettings;
310+
List<FieldInferenceRequest> currentBatch = new ArrayList<>();
311+
List<FieldInferenceRequest> others = new ArrayList<>();
312+
for (int i = 0; i < currentBatchSize; i++) {
313+
FieldInferenceRequest request = requests.get(i);
314+
if ((chunkingSettings == null && request.chunkingSettings == null) || request.chunkingSettings.equals(chunkingSettings)) {
315+
currentBatch.add(request);
316+
} else {
317+
others.add(request);
318+
}
319+
}
320+
311321
final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
312-
//
313-
// List<ChunkedInputs> chunkedInputs = currentBatch.stream()
314-
// .map(request -> new ChunkedInputs(request.chunkingSettings(), List.of(request.input())))
315-
// .toList();
316-
317-
List<ChunkedInputs> chunkedInputs = currentBatch.stream()
318-
.collect(Collectors.groupingBy(request -> Optional.ofNullable(request.chunkingSettings())))
319-
.entrySet()
320-
.stream()
321-
.map(
322-
entry -> new ChunkedInputs(
323-
entry.getKey().orElse(null),
324-
entry.getValue().stream().map(FieldInferenceRequest::input).collect(Collectors.toList())
325-
)
326-
)
327-
.toList();
322+
nextBatch.addAll(others);
323+
324+
// Every request in currentBatch shares the chunking settings of the first request in the batch
325+
ChunkedInputs chunkedInputs = new ChunkedInputs(
326+
chunkingSettings,
327+
currentBatch.stream().map(r -> r.input).collect(Collectors.toList())
328+
);
328329

329330
ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
330331
@Override
@@ -390,19 +391,17 @@ private void onFinish() {
390391
}
391392
};
392393

393-
for (ChunkedInputs chunkedInput : chunkedInputs) {
394-
inferenceProvider.service()
395-
.chunkedInfer(
396-
inferenceProvider.model(),
397-
null,
398-
chunkedInput.inputs(),
399-
Map.of(),
400-
chunkedInput.chunkingSettings(),
401-
InputType.INGEST,
402-
TimeValue.MAX_VALUE,
403-
completionListener
404-
);
405-
}
394+
inferenceProvider.service()
395+
.chunkedInfer(
396+
inferenceProvider.model(),
397+
null,
398+
chunkedInputs.inputs(),
399+
Map.of(),
400+
chunkedInputs.chunkingSettings(),
401+
InputType.INGEST,
402+
TimeValue.MAX_VALUE,
403+
completionListener
404+
);
406405
}
407406

408407
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic
7070
assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues()));
7171
assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings()));
7272
assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size()));
73+
assertThat(newInstance.chunkingSettings(), equalTo(expectedInstance.chunkingSettings()));
7374
MinimalServiceSettings modelSettings = newInstance.inference().modelSettings();
7475
for (var entry : newInstance.inference().chunks().entrySet()) {
7576
var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey());

0 commit comments

Comments
 (0)