Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130038.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130038
summary: Add support for JinaAI late chunking
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable
boolean isEmpty();

TaskSettings updatedTaskSettings(Map<String, Object> newSettings);

/**
 * Returns whether late chunking is enabled for this task, or {@code null} when the
 * setting is unspecified. The default implementation returns {@code null}; service
 * implementations that support late chunking override this.
 */
default Boolean isLateChunkingEnabled() {
    return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ public EmbeddingRequestChunker(
List<ChunkInferenceInput> inputs,
int maxNumberOfInputsPerBatch,
ChunkingSettings defaultChunkingSettings
) {
this(inputs, maxNumberOfInputsPerBatch, true, defaultChunkingSettings);
}

public EmbeddingRequestChunker(
List<ChunkInferenceInput> inputs,
int maxNumberOfInputsPerBatch,
Boolean shouldBatchAcrossInputs,
ChunkingSettings defaultChunkingSettings
) {
this.resultEmbeddings = new ArrayList<>(inputs.size());
this.resultOffsetStarts = new ArrayList<>(inputs.size());
Expand Down Expand Up @@ -133,13 +142,23 @@ public EmbeddingRequestChunker(
}
}

AtomicInteger counter = new AtomicInteger();
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
.values()
.stream()
.map(BatchRequest::new)
.toList();
if (shouldBatchAcrossInputs == null || shouldBatchAcrossInputs) {
AtomicInteger counter = new AtomicInteger();
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
.values()
.stream()
.map(BatchRequest::new)
.toList();
} else {
assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS);
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(Request::inputIndex))
.values()
.stream()
.map(BatchRequest::new)
.toList();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,12 @@ protected void doChunkedInfer(
JinaAIModel jinaaiModel = (JinaAIModel) model;
var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());

var isLateChunkingEnabled = jinaaiModel.getTaskSettings().isLateChunkingEnabled();
var shouldBatchAcrossInputs = isLateChunkingEnabled == null || isLateChunkingEnabled == false;
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
shouldBatchAcrossInputs,
jinaaiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES;

Expand All @@ -36,6 +37,7 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings {
public static final String NAME = "jinaai_embeddings_task_settings";
public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null);
static final String INPUT_TYPE = "input_type";
static final String LATE_CHUNKING = "late_chunking";

public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
Expand All @@ -53,11 +55,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
validationException
);

Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new JinaAIEmbeddingsTaskSettings(inputType);
return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking);
}

/**
Expand All @@ -76,8 +80,12 @@ public static JinaAIEmbeddingsTaskSettings of(
JinaAIEmbeddingsTaskSettings requestTaskSettings
) {
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings);
// TODO: Should a null late chunking override a non-null late chunking?
var shouldUseLateChunking = requestTaskSettings.lateChunking != null
? requestTaskSettings.lateChunking
: originalSettings.lateChunking;

return new JinaAIEmbeddingsTaskSettings(inputTypeToUse);
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, shouldUseLateChunking);
}

private static InputType getValidInputType(
Expand All @@ -94,14 +102,22 @@ private static InputType getValidInputType(
}

private final InputType inputType;
private final Boolean lateChunking;

/**
 * Reads task settings from the wire: an optional {@link InputType} followed by an
 * optional late-chunking flag (must mirror {@code writeTo}).
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
    // Fix: the diff residue kept the stale old-side delegation
    // `this(in.readOptionalEnum(InputType.class));` alongside the new one,
    // producing two consecutive this(...) calls; only the new form is kept.
    this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
}

/**
 * Creates task settings with an optional input type and an optional late-chunking flag.
 * A {@code null} lateChunking means "unspecified" (the field is then omitted from
 * XContent and request bodies).
 *
 * @param inputType    the input type, may be {@code null}; validated against the
 *                     service's supported values
 * @param lateChunking whether late chunking is enabled, may be {@code null}
 */
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean lateChunking) {
    validateInputType(inputType);
    this.inputType = inputType;
    this.lateChunking = lateChunking;
}

/**
 * Creates task settings with only an input type; late chunking is left unspecified
 * ({@code null}).
 *
 * @param inputType the input type, may be {@code null}
 */
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) {
    // Delegate to the canonical constructor instead of duplicating the
    // validation and field assignments (behavior is identical).
    this(inputType, null);
}

private static void validateInputType(InputType inputType) {
Expand All @@ -124,6 +140,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_TYPE, inputType);
}

// TODO: Add a transport version
if (lateChunking != null) {
builder.field(LATE_CHUNKING, lateChunking);
}

builder.endObject();
return builder;
}
Expand All @@ -132,6 +153,11 @@ public InputType getInputType() {
return inputType;
}

/**
 * Returns the late-chunking flag, or {@code null} when it was not configured.
 * Consumers (e.g. the request entity and the chunker selection logic) treat
 * {@code null} as "unset" rather than {@code false}.
 */
@Override
public Boolean isLateChunkingEnabled() {
    return lateChunking;
}

@Override
public String getWriteableName() {
return NAME;
Expand All @@ -145,19 +171,20 @@ public TransportVersion getMinimalSupportedVersion() {
/**
 * Serializes these settings: optional input type, then the optional late-chunking flag
 * (must mirror the {@code StreamInput} constructor).
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    out.writeOptionalEnum(inputType);
    // NOTE(review): this newly added field is written unconditionally. It likely
    // needs a TransportVersion gate (see the "TODO: Add a transport version"
    // elsewhere in this class) or older nodes in a mixed cluster may fail to
    // deserialize — confirm before merging.
    out.writeOptionalBoolean(lateChunking);
}

/**
 * Equality based on both {@code inputType} and {@code lateChunking}
 * (kept consistent with {@link #hashCode()}).
 */
@Override
public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o;
    // Fix: the diff residue retained the stale old-side
    // `return Objects.equals(inputType, that.inputType);` immediately before the
    // new return, making the new comparison unreachable; only the new form is kept.
    return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking);
}

/**
 * Hash over both fields, consistent with {@link #equals(Object)}.
 */
@Override
public int hashCode() {
    // Fix: the diff residue kept the stale old-side `return Objects.hash(inputType);`
    // ahead of the new return; only the new two-field form is kept.
    return Objects.hash(inputType, lateChunking);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(
private static final String CLASSIFICATION = "classification";
private static final String INPUT_FIELD = "input";
private static final String MODEL_FIELD = "model";
private static final String LATE_CHUNKING = "late_chunking";
public static final String TASK_TYPE_FIELD = "task";
static final String EMBEDDING_TYPE_FIELD = "embedding_type";

Expand All @@ -49,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_FIELD, input);
builder.field(MODEL_FIELD, model);

if (taskSettings.isLateChunkingEnabled() != null) {
builder.field(LATE_CHUNKING, taskSettings.isLateChunkingEnabled());
}

if (embeddingType != null) {
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,40 @@ public void testMergingListener_Sparse() {
}
}

/**
 * When shouldBatchAcrossInputs is false, the chunker must emit one batch per input
 * (grouped by input index) rather than packing chunks from different inputs into
 * shared batches.
 */
public void testShouldBatchAcrossInputsIsFalse_DoesNotBatchRequestsFromSeparateInputs() {
    // Large enough that all chunks of any single input would fit into one
    // cross-input batch if batching across inputs were enabled.
    int batchSize = 512;

    var testSentence = "This is a test sentence with ten words in total. ";

    // Three inputs with 3, 1 and 4 sentences respectively.
    List<ChunkInferenceInput> inputs = List.of(
        new ChunkInferenceInput(testSentence + testSentence + testSentence),
        new ChunkInferenceInput(testSentence),
        new ChunkInferenceInput(testSentence + testSentence + testSentence + testSentence)
    );

    // Sentence-boundary chunking so each sentence becomes its own chunk
    // (assumes SentenceBoundaryChunkingSettings(10, 0) yields one chunk per
    // sentence for these inputs — TODO confirm against the chunker).
    var chunkingSettings = new SentenceBoundaryChunkingSettings(10, 0);

    var finalListener = testListener();
    List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(
        inputs,
        batchSize,
        false,
        chunkingSettings
    ).batchRequestsWithListeners(finalListener);

    // One batch per input; batch sizes match each input's chunk count.
    assertThat(batches, hasSize(3));
    var expectedBatchSizes = List.of(3, 1, 4);
    for (int i = 0; i < batches.size(); i++) {
        assertThat(batches.get(i).batch().inputs().get(), hasSize(expectedBatchSizes.get(i)));
        // Respond to each batch so the final listener can assemble per-input results.
        batches.get(i)
            .listener()
            .onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1f }))));
    }

    // One merged result per original input.
    assertNotNull(finalListener.results);
    assertThat(finalListener.results, hasSize(3));
}

public void testListenerErrorsWithWrongNumberOfResponses() {
List<ChunkInferenceInput> inputs = List.of(
new ChunkInferenceInput("1st small"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,20 @@ public void testFromMap_CreatesEmptySettings_WhenMapIsNull() {
}

public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
var inputType = randomFrom(VALID_INPUT_TYPE_VALUES);
var isLateChunkingEnabled = randomBoolean();
MatcherAssert.assertThat(
JinaAIEmbeddingsTaskSettings.fromMap(
new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString()))
new HashMap<>(
Map.of(
JinaAIEmbeddingsTaskSettings.INPUT_TYPE,
inputType.toString(),
JinaAIEmbeddingsTaskSettings.LATE_CHUNKING,
isLateChunkingEnabled
)
)
),
is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))
is(new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled))
);
}

Expand Down Expand Up @@ -131,16 +140,45 @@ public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
}

/**
 * of() must keep the original settings unchanged when the request-level settings
 * are empty.
 */
public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull() {
    // Fix: the diff residue kept the stale old-side line
    // `var taskSettings = new JinaAIEmbeddingsTaskSettings(InputType.INGEST);`
    // next to the new randomized construction; only the new form is kept.
    var inputType = randomFrom(VALID_INPUT_TYPE_VALUES);
    var isLateChunkingEnabled = randomBoolean();
    var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled);
    var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(taskSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS);
    MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
}

public void testOf_UsesRequestTaskSettings() {
var taskSettings = new JinaAIEmbeddingsTaskSettings((InputType) null);
var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(taskSettings, new JinaAIEmbeddingsTaskSettings(InputType.INGEST));
public void testOf_UsesRequestTaskSettingsWhenSettingsAreNull() {
var taskSettings = new JinaAIEmbeddingsTaskSettings(null, null);

var overriddenInputType = randomFrom(VALID_INPUT_TYPE_VALUES);
var overriddenIsLateChunkingEnabled = randomBoolean();
var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(
taskSettings,
new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)
);

MatcherAssert.assertThat(
overriddenTaskSettings,
is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled))
);
}

/**
 * of() must let non-null request-level settings override non-null original
 * settings for both fields.
 */
public void testOf_UsesRequestTaskSettingsWhenSettingsAreNotNull() {
    // Fix: the diff residue kept the stale old-side assertion
    // `MatcherAssert.assertThat(overriddenTaskSettings, is(new
    // JinaAIEmbeddingsTaskSettings(InputType.INGEST)));` ahead of the new
    // assertion; only the new form is kept.
    var inputType = randomFrom(VALID_INPUT_TYPE_VALUES);
    var isLateChunkingEnabled = randomBoolean();
    var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled);

    // Overrides are guaranteed to differ from the originals so the assertion
    // actually proves the request settings won.
    var overriddenInputType = randomValueOtherThan(inputType, () -> randomFrom(VALID_INPUT_TYPE_VALUES));
    var overriddenIsLateChunkingEnabled = isLateChunkingEnabled == false;
    var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(
        taskSettings,
        new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)
    );

    MatcherAssert.assertThat(
        overriddenTaskSettings,
        is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled))
    );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@

public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase {
/**
 * With every optional field defined (input type, late chunking, embedding type),
 * the request entity must serialize all of them, including "late_chunking":true.
 */
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
    var entity = new JinaAIEmbeddingsRequestEntity(
        List.of("abc"),
        InputType.INTERNAL_INGEST,
        new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true),
        "model",
        JinaAIEmbeddingType.FLOAT
    );

    XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
    entity.toXContent(builder, null);
    String xContentResult = Strings.toString(builder);

    MatcherAssert.assertThat(xContentResult, is("""
        {"input":["abc"],"model":"model","late_chunking":true,"embedding_type":"float","task":"retrieval.passage"}"""));
}

public void testXContent_WritesOnlyInputTypeField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException {
var entity = new JinaAIEmbeddingsRequestEntity(
List.of("abc"),
InputType.INTERNAL_INGEST,
Expand All @@ -40,6 +57,23 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException
{"input":["abc"],"model":"model","embedding_type":"float","task":"retrieval.passage"}"""));
}

/**
 * When late chunking is the only optional field set (here explicitly false), the
 * entity must serialize "late_chunking":false and omit the other optional fields —
 * i.e. an explicit false is distinct from an unset (null) flag.
 */
public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException {
    var entity = new JinaAIEmbeddingsRequestEntity(
        List.of("abc"),
        InputType.INTERNAL_INGEST,
        new JinaAIEmbeddingsTaskSettings(null, false),
        "model",
        null
    );

    XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
    entity.toXContent(builder, null);
    String xContentResult = Strings.toString(builder);

    MatcherAssert.assertThat(xContentResult, is("""
        {"input":["abc"],"model":"model","late_chunking":false,"task":"retrieval.passage"}"""));
}

public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
var entity = new JinaAIEmbeddingsRequestEntity(
List.of("abc"),
Expand Down