diff --git a/docs/changelog/130038.yaml b/docs/changelog/130038.yaml new file mode 100644 index 0000000000000..b4a62c9ee0f99 --- /dev/null +++ b/docs/changelog/130038.yaml @@ -0,0 +1,5 @@ +pr: 130038 +summary: Adding support for JinaAI late chunking +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java index 7dd20688245ba..b7590cca07293 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java @@ -19,4 +19,8 @@ public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable boolean isEmpty(); TaskSettings updatedTaskSettings(Map newSettings); + + default Boolean isLateChunkingEnabled() { + return null; // default for implementations that do not override this accessor + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2df2f1e62f89a..58550187ff361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -87,6 +87,15 @@ public EmbeddingRequestChunker( List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings defaultChunkingSettings + ) { + this(inputs, maxNumberOfInputsPerBatch, true, defaultChunkingSettings); // keeps batching across inputs as the default + } + + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + Boolean shouldBatchAcrossInputs, // null is handled like true when the batches are built + ChunkingSettings defaultChunkingSettings ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); this.resultOffsetStarts = new ArrayList<>(inputs.size()); @@ -133,13 +142,23 @@ public EmbeddingRequestChunker( } } - AtomicInteger counter = new AtomicInteger(); - 
this.batchRequests = allRequests.stream() - .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) - .values() - .stream() - .map(BatchRequest::new) - .toList(); + if (shouldBatchAcrossInputs == null || shouldBatchAcrossInputs) { // consecutive requests are grouped maxNumberOfInputsPerBatch at a time, regardless of input + AtomicInteger counter = new AtomicInteger(); + this.batchRequests = allRequests.stream() + .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) + .values() + .stream() + .map(BatchRequest::new) + .toList(); + } else { // requests are grouped by the index of the input they came from + assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS); // NOTE(review): presumably guarantees that all chunks of one input fit into a single batch; MAX_CHUNKS is declared outside this diff, confirm + this.batchRequests = allRequests.stream() + .collect(Collectors.groupingBy(Request::inputIndex)) + .values() + .stream() + .map(BatchRequest::new) + .toList(); + } } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..1632d2a3ff7f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -279,9 +279,12 @@ protected void doChunkedInfer( JinaAIModel jinaaiModel = (JinaAIModel) model; var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + var isLateChunkingEnabled = jinaaiModel.getTaskSettings().isLateChunkingEnabled(); // Boolean, may be null + var shouldBatchAcrossInputs = isLateChunkingEnabled == null || isLateChunkingEnabled == false; /* false only when late chunking is explicitly true */ List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, + shouldBatchAcrossInputs, jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java index 8869871abd4c1..db55b9012d60c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java @@ -24,6 +24,7 @@ import java.util.Objects; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES; @@ -36,6 +37,7 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "jinaai_embeddings_task_settings"; public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null); static final String INPUT_TYPE = "input_type"; + static final String LATE_CHUNKING = "late_chunking"; public static JinaAIEmbeddingsTaskSettings fromMap(Map map) { if (map == null || map.isEmpty()) { @@ -53,11 +55,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map map) { validationException ); + Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new JinaAIEmbeddingsTaskSettings(inputType); + return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking); } /** @@ -76,8 +80,12 @@ public static JinaAIEmbeddingsTaskSettings of( JinaAIEmbeddingsTaskSettings requestTaskSettings ) { var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings); + // TODO: Should a null late chunking 
override a non-null late chunking? + var shouldUseLateChunking = requestTaskSettings.lateChunking != null // the request value wins whenever it is set + ? requestTaskSettings.lateChunking + : originalSettings.lateChunking; - return new JinaAIEmbeddingsTaskSettings(inputTypeToUse); + return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, shouldUseLateChunking); } private static InputType getValidInputType( @@ -94,14 +102,20 @@ private static InputType getValidInputType( } private final InputType inputType; + private final Boolean lateChunking; /* may be null, e.g. in EMPTY_SETTINGS */ public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalEnum(InputType.class)); + this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean()); + } + + public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean lateChunking) { + validateInputType(inputType); + this.inputType = inputType; + this.lateChunking = lateChunking; } public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) { - validateInputType(inputType); - this.inputType = inputType; + this(inputType, null); /* delegates to the two-argument constructor; late chunking stays null */ } private static void validateInputType(InputType inputType) { @@ -124,6 +138,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_TYPE, inputType); } + if (lateChunking != null) { + builder.field(LATE_CHUNKING, lateChunking); + } + builder.endObject(); return builder; } @@ -132,6 +150,11 @@ public InputType getInputType() { return inputType; } + @Override + public Boolean isLateChunkingEnabled() { + return lateChunking; + } + @Override public String getWriteableName() { return NAME; @@ -145,6 +168,8 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(inputType); + // TODO: Add a transport version and gate this write, and the matching read in the StreamInput constructor, on it + out.writeOptionalBoolean(lateChunking); } @Override @@ -152,12 +177,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) 
return false; JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o; - return Objects.equals(inputType, that.inputType); + return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking); } @Override public int hashCode() { - return Objects.hash(inputType); + return Objects.hash(inputType, lateChunking); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java index 791c4af76b145..284810c22f94c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java @@ -34,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity( private static final String CLASSIFICATION = "classification"; private static final String INPUT_FIELD = "input"; private static final String MODEL_FIELD = "model"; + private static final String LATE_CHUNKING = "late_chunking"; public static final String TASK_TYPE_FIELD = "task"; static final String EMBEDDING_TYPE_FIELD = "embedding_type"; @@ -49,6 +50,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_FIELD, input); builder.field(MODEL_FIELD, model); + var lateChunking = taskSettings.isLateChunkingEnabled(); // read once; the field is omitted when the setting is null + if (lateChunking != null) { + builder.field(LATE_CHUNKING, lateChunking); + } + if (embeddingType != null) { builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 411d992adfa3d..f263d7c060cfb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -866,6 +866,40 @@ public void testMergingListener_Sparse() { } } + public void testShouldBatchAcrossInputsIsFalse_DoesNotBatchRequestsFromSeparateInputs() { + int batchSize = 512; + + var testSentence = "This is a test sentence with ten words in total. "; + + List inputs = List.of( + new ChunkInferenceInput(testSentence + testSentence + testSentence), + new ChunkInferenceInput(testSentence), + new ChunkInferenceInput(testSentence + testSentence + testSentence + testSentence) + ); + + var chunkingSettings = new SentenceBoundaryChunkingSettings(10, 0); + + var finalListener = testListener(); + List batches = new EmbeddingRequestChunker<>( + inputs, + batchSize, + false, // shouldBatchAcrossInputs + chunkingSettings + ).batchRequestsWithListeners(finalListener); + + assertThat(batches, hasSize(3)); // one batch per input + var expectedBatchSizes = List.of(3, 1, 4); // number of test sentences in each input, in input order + for (int i = 0; i < batches.size(); i++) { + assertThat(batches.get(i).batch().inputs().get(), hasSize(expectedBatchSizes.get(i))); + batches.get(i) + .listener() + .onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1f })))); // NOTE(review): every batch is answered with a single embedding even though the batches hold 3, 1 and 4 inputs; confirm the listener tolerates that + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(3)); // one result entry per input + } + public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of( new ChunkInferenceInput("1st small"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java index 9235d26c87d87..8c383d0dd4975 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java @@ -74,11 +74,20 @@ public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { } public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); // randomized so any valid input type is exercised + var isLateChunkingEnabled = randomBoolean(); MatcherAssert.assertThat( JinaAIEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString())) + new HashMap<>( + Map.of( + JinaAIEmbeddingsTaskSettings.INPUT_TYPE, + inputType.toString(), + JinaAIEmbeddingsTaskSettings.LATE_CHUNKING, + isLateChunkingEnabled + ) + ) ), - is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)) + is(new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled)) ); } @@ -131,16 +140,45 @@ public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { } public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull() { - var taskSettings = new JinaAIEmbeddingsTaskSettings(InputType.INGEST); + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var isLateChunkingEnabled = randomBoolean(); // EMPTY_SETTINGS must not replace whichever value is drawn + var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled); var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(taskSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS); MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); } - public void testOf_UsesRequestTaskSettings() { - var taskSettings = new JinaAIEmbeddingsTaskSettings((InputType) null); - var overriddenTaskSettings = 
JinaAIEmbeddingsTaskSettings.of(taskSettings, new JinaAIEmbeddingsTaskSettings(InputType.INGEST)); + public void testOf_UsesRequestTaskSettingsWhenSettingsAreNull() { + var taskSettings = new JinaAIEmbeddingsTaskSettings(null, null); // original settings with nothing set + + var overriddenInputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var overriddenIsLateChunkingEnabled = randomBoolean(); + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled) + ); + + MatcherAssert.assertThat( + overriddenTaskSettings, + is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)) + ); + } + + public void testOf_UsesRequestTaskSettingsWhenSettingsAreNotNull() { + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var isLateChunkingEnabled = randomBoolean(); + var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled); + + var overriddenInputType = randomValueOtherThan(inputType, () -> randomFrom(VALID_INPUT_TYPE_VALUES)); + var overriddenIsLateChunkingEnabled = isLateChunkingEnabled == false; // negation of the original value, so the override is observable + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled) + ); - MatcherAssert.assertThat(overriddenTaskSettings, is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat( + overriddenTaskSettings, + is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)) + ); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java index 733dcc69c278d..7dd990fde9e46 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java @@ -24,6 +24,23 @@ public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + InputType.INTERNAL_INGEST, + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), // late chunking explicitly enabled + "model", + JinaAIEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","late_chunking":true,"embedding_type":"float","task":"retrieval.passage"}""")); + } + + public void testXContent_WritesOnlyInputTypeField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( List.of("abc"), InputType.INTERNAL_INGEST, @@ -40,6 +57,23 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {"input":["abc"],"model":"model","embedding_type":"float","task":"retrieval.passage"}""")); } + public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + InputType.INTERNAL_INGEST, + new JinaAIEmbeddingsTaskSettings(null, false), // no input type in the task settings, late chunking explicitly disabled + "model", + null // no embedding type + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","late_chunking":false,"task":"retrieval.passage"}""")); + } + public 
void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( List.of("abc"),