|
13 | 13 | import org.elasticsearch.inference.ChunkInferenceInput; |
14 | 14 | import org.elasticsearch.inference.ChunkedInference; |
15 | 15 | import org.elasticsearch.inference.ChunkingSettings; |
| 16 | +import org.elasticsearch.inference.ChunkingStrategy; |
16 | 17 | import org.elasticsearch.inference.InferenceServiceResults; |
17 | 18 | import org.elasticsearch.rest.RestStatus; |
18 | 19 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; |
|
22 | 23 |
|
23 | 24 | import java.util.ArrayList; |
24 | 25 | import java.util.List; |
| 26 | +import java.util.Map; |
| 27 | +import java.util.Objects; |
25 | 28 | import java.util.concurrent.atomic.AtomicInteger; |
26 | 29 | import java.util.concurrent.atomic.AtomicReferenceArray; |
27 | 30 | import java.util.function.Supplier; |
@@ -94,13 +97,20 @@ public EmbeddingRequestChunker( |
94 | 97 | defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS; |
95 | 98 | } |
96 | 99 |
|
| 100 | + Map<ChunkingStrategy, Chunker> chunkers = inputs.stream() |
| 101 | + .map(ChunkInferenceInput::chunkingSettings) |
| 102 | + .filter(Objects::nonNull) |
| 103 | + .map(ChunkingSettings::getChunkingStrategy) |
| 104 | + .distinct() |
| 105 | + .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy)); |
| 106 | + |
97 | 107 | List<Request> allRequests = new ArrayList<>(); |
98 | 108 | for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { |
99 | 109 | ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); |
100 | 110 | if (chunkingSettings == null) { |
101 | 111 | chunkingSettings = defaultChunkingSettings; |
102 | 112 | } |
103 | | - Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); |
| 113 | + Chunker chunker = chunkers.get(chunkingSettings.getChunkingStrategy()); |
104 | 114 | List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); |
105 | 115 | int resultCount = Math.min(chunks.size(), MAX_CHUNKS); |
106 | 116 | resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); |
|
0 commit comments