Skip to content

Commit 9ab10b7

Browse files
Adding support for JinaAI late chunking
1 parent 4275bc7 commit 9ab10b7

File tree

5 files changed

+66
-12
lines changed

5 files changed

+66
-12
lines changed

server/src/main/java/org/elasticsearch/inference/TaskSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,8 @@ public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable
1919
boolean isEmpty();
2020

2121
TaskSettings updatedTaskSettings(Map<String, Object> newSettings);
22+
23+
/**
 * Reports whether late chunking has been requested through these task settings.
 *
 * @return {@code null} in this default implementation; implementations that carry a
 *         late-chunking option may override this and return the configured value
 */
default Boolean isLateChunkingEnabled() {
24+
return null;
25+
}
2226
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ public EmbeddingRequestChunker(
8787
List<ChunkInferenceInput> inputs,
8888
int maxNumberOfInputsPerBatch,
8989
ChunkingSettings defaultChunkingSettings
90+
) {
91+
this(inputs, maxNumberOfInputsPerBatch, false, defaultChunkingSettings);
92+
}
93+
94+
public EmbeddingRequestChunker(
95+
List<ChunkInferenceInput> inputs,
96+
int maxNumberOfInputsPerBatch,
97+
Boolean isLateChunkingEnabled,
98+
ChunkingSettings defaultChunkingSettings
9099
) {
91100
this.resultEmbeddings = new ArrayList<>(inputs.size());
92101
this.resultOffsetStarts = new ArrayList<>(inputs.size());
@@ -133,13 +142,24 @@ public EmbeddingRequestChunker(
133142
}
134143
}
135144

136-
AtomicInteger counter = new AtomicInteger();
137-
this.batchRequests = allRequests.stream()
138-
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
139-
.values()
140-
.stream()
141-
.map(BatchRequest::new)
142-
.toList();
145+
if (isLateChunkingEnabled != null && isLateChunkingEnabled) {
146+
// This must be true for late chunking cases otherwise we can't pass all chunks in a single request
147+
assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS);
148+
this.batchRequests = allRequests.stream()
149+
.collect(Collectors.groupingBy(Request::inputIndex))
150+
.values()
151+
.stream()
152+
.map(BatchRequest::new)
153+
.toList();
154+
} else {
155+
AtomicInteger counter = new AtomicInteger();
156+
this.batchRequests = allRequests.stream()
157+
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
158+
.values()
159+
.stream()
160+
.map(BatchRequest::new)
161+
.toList();
162+
}
143163
}
144164

145165
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ protected void doChunkedInfer(
282282
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
283283
inputs.getInputs(),
284284
EMBEDDING_MAX_BATCH_SIZE,
285+
jinaaiModel.getTaskSettings().isLateChunkingEnabled(),
285286
jinaaiModel.getConfigurations().getChunkingSettings()
286287
).batchRequestsWithListeners(listener);
287288

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.Objects;
2525

2626
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
2728
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
2829
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES;
2930

@@ -36,6 +37,7 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings {
3637
public static final String NAME = "jinaai_embeddings_task_settings";
3738
public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null);
3839
static final String INPUT_TYPE = "input_type";
40+
static final String LATE_CHUNKING = "late_chunking";
3941

4042
public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
4143
if (map == null || map.isEmpty()) {
@@ -53,11 +55,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
5355
validationException
5456
);
5557

58+
Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException);
59+
5660
if (validationException.validationErrors().isEmpty() == false) {
5761
throw validationException;
5862
}
5963

60-
return new JinaAIEmbeddingsTaskSettings(inputType);
64+
return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking);
6165
}
6266

6367
/**
@@ -77,7 +81,8 @@ public static JinaAIEmbeddingsTaskSettings of(
7781
) {
7882
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings);
7983

80-
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse);
84+
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, requestTaskSettings.lateChunking);
85+
// TODO: Check the above
8186
}
8287

8388
private static InputType getValidInputType(
@@ -94,14 +99,22 @@ private static InputType getValidInputType(
9499
}
95100

96101
private final InputType inputType;
102+
private final Boolean lateChunking;
97103

98104
public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
99-
this(in.readOptionalEnum(InputType.class));
105+
// NOTE(review): the previous wire format carried only the optional enum. Reading the extra
// optional boolean unconditionally presumably needs a transport-version gate (see the
// "Add a transport version" TODO in toXContent) — confirm before merging.
this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
106+
}
107+
108+
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean lateChunking) {
109+
validateInputType(inputType);
110+
this.inputType = inputType;
111+
this.lateChunking = lateChunking;
100112
}
101113

102114
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) {
103115
validateInputType(inputType);
104116
this.inputType = inputType;
117+
this.lateChunking = null;
105118
}
106119

107120
private static void validateInputType(InputType inputType) {
@@ -124,6 +137,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
124137
builder.field(INPUT_TYPE, inputType);
125138
}
126139

140+
// TODO: Add a transport version
141+
if (lateChunking != null) {
142+
builder.field(LATE_CHUNKING, lateChunking);
143+
}
144+
127145
builder.endObject();
128146
return builder;
129147
}
@@ -132,6 +150,11 @@ public InputType getInputType() {
132150
return inputType;
133151
}
134152

153+
@Override
154+
public Boolean isLateChunkingEnabled() {
155+
return lateChunking;
156+
}
157+
135158
@Override
136159
public String getWriteableName() {
137160
return NAME;
@@ -145,19 +168,20 @@ public TransportVersion getMinimalSupportedVersion() {
145168
@Override
146169
public void writeTo(StreamOutput out) throws IOException {
147170
out.writeOptionalEnum(inputType);
171+
// NOTE(review): written unconditionally, and must stay symmetric with the StreamInput
// constructor. Presumably needs a transport-version gate so that peers expecting only
// the optional enum are not sent this extra field — confirm.
out.writeOptionalBoolean(lateChunking);
148172
}
149173

150174
@Override
151175
public boolean equals(Object o) {
152176
if (this == o) return true;
153177
if (o == null || getClass() != o.getClass()) return false;
154178
JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o;
155-
return Objects.equals(inputType, that.inputType);
179+
// lateChunking is a nullable boxed Boolean: compare by value, not by reference.
return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking);
156180
}
157181

158182
@Override
159183
public int hashCode() {
160-
return Objects.hash(inputType);
184+
return Objects.hash(inputType, lateChunking);
161185
}
162186

163187
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(
3434
private static final String CLASSIFICATION = "classification";
3535
private static final String INPUT_FIELD = "input";
3636
private static final String MODEL_FIELD = "model";
37+
private static final String LATE_CHUNKING = "late_chunking";
3738
public static final String TASK_TYPE_FIELD = "task";
3839
static final String EMBEDDING_TYPE_FIELD = "embedding_type";
3940

@@ -49,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4950
builder.field(INPUT_FIELD, input);
5051
builder.field(MODEL_FIELD, model);
5152

53+
if (taskSettings.isLateChunkingEnabled() != null) {
54+
builder.field(LATE_CHUNKING, taskSettings.isLateChunkingEnabled());
55+
}
56+
5257
if (embeddingType != null) {
5358
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
5459
}

0 commit comments

Comments
 (0)