Skip to content

Commit 612d2fb

Browse files
committed
PR optimization
1 parent 546e333 commit 612d2fb

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.inference.ChunkInferenceInput;
1414
import org.elasticsearch.inference.ChunkedInference;
1515
import org.elasticsearch.inference.ChunkingSettings;
16+
import org.elasticsearch.inference.ChunkingStrategy;
1617
import org.elasticsearch.inference.InferenceServiceResults;
1718
import org.elasticsearch.rest.RestStatus;
1819
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
@@ -22,6 +23,8 @@
2223

2324
import java.util.ArrayList;
2425
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Objects;
2528
import java.util.concurrent.atomic.AtomicInteger;
2629
import java.util.concurrent.atomic.AtomicReferenceArray;
2730
import java.util.function.Supplier;
@@ -94,13 +97,20 @@ public EmbeddingRequestChunker(
9497
defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS;
9598
}
9699

100+
Map<ChunkingStrategy, Chunker> chunkers = inputs.stream()
101+
.map(ChunkInferenceInput::chunkingSettings)
102+
.filter(Objects::nonNull)
103+
.map(ChunkingSettings::getChunkingStrategy)
104+
.distinct()
105+
.collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy));
106+
97107
List<Request> allRequests = new ArrayList<>();
98108
for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) {
99109
ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings();
100110
if (chunkingSettings == null) {
101111
chunkingSettings = defaultChunkingSettings;
102112
}
103-
Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
113+
Chunker chunker = chunkers.get(chunkingSettings.getChunkingStrategy());
104114
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
105115
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
106116
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));

0 commit comments

Comments
 (0)