From 9ab10b75425a9d11dcaf1216e57dd53509b8d346 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 25 Jun 2025 13:50:29 -0400 Subject: [PATCH 1/3] Adding support for JinaAI late chunking --- .../elasticsearch/inference/TaskSettings.java | 4 +++ .../chunking/EmbeddingRequestChunker.java | 34 +++++++++++++++---- .../services/jinaai/JinaAIService.java | 1 + .../JinaAIEmbeddingsTaskSettings.java | 34 ++++++++++++++++--- .../JinaAIEmbeddingsRequestEntity.java | 5 +++ 5 files changed, 66 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java index 7dd20688245ba..b7590cca07293 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java @@ -19,4 +19,8 @@ public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable boolean isEmpty(); TaskSettings updatedTaskSettings(Map newSettings); + + default Boolean isLateChunkingEnabled() { + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2df2f1e62f89a..de9b9bdca8098 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -87,6 +87,15 @@ public EmbeddingRequestChunker( List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings defaultChunkingSettings + ) { + this(inputs, maxNumberOfInputsPerBatch, false, defaultChunkingSettings); + } + + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + Boolean isLateChunkingEnabled, + ChunkingSettings defaultChunkingSettings ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); this.resultOffsetStarts = new ArrayList<>(inputs.size()); @@ -133,13 +142,24 @@ public EmbeddingRequestChunker( } } - AtomicInteger counter = new AtomicInteger(); - this.batchRequests = allRequests.stream() - .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) - .values() - .stream() - .map(BatchRequest::new) - .toList(); + if (isLateChunkingEnabled != null && isLateChunkingEnabled) { + // This must be true for late chunking cases otherwise we can't pass all chunks in a single request + assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS); + this.batchRequests = allRequests.stream() + .collect(Collectors.groupingBy(Request::inputIndex)) + .values() + .stream() + .map(BatchRequest::new) + .toList(); + } else { + AtomicInteger counter = new AtomicInteger(); + this.batchRequests = allRequests.stream() + .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) + .values() + .stream() + .map(BatchRequest::new) + .toList(); + } } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..724d411ea91f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -282,6 +282,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, + jinaaiModel.getTaskSettings().isLateChunkingEnabled(), jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java index 8869871abd4c1..11cd220895c5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java @@ -24,6 +24,7 @@ import java.util.Objects; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES; @@ -36,6 +37,7 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "jinaai_embeddings_task_settings"; public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null); static final String INPUT_TYPE = "input_type"; + static final String LATE_CHUNKING = "late_chunking"; public static JinaAIEmbeddingsTaskSettings fromMap(Map map) { if (map == null || map.isEmpty()) { @@ -53,11 +55,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map map) { validationException ); + Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new JinaAIEmbeddingsTaskSettings(inputType); + return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking); } /** @@ -77,7 +81,8 @@ public static JinaAIEmbeddingsTaskSettings of( ) { var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings); - return new JinaAIEmbeddingsTaskSettings(inputTypeToUse); + return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, requestTaskSettings.lateChunking); + // TODO: Check the above } private static InputType getValidInputType( @@ -94,14 +99,22 @@ private static InputType getValidInputType( } private final InputType inputType; + private final Boolean lateChunking; public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalEnum(InputType.class)); + this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean()); + } + + public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean lateChunking) { + validateInputType(inputType); + this.inputType = inputType; + this.lateChunking = lateChunking; } public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) { validateInputType(inputType); this.inputType = inputType; + this.lateChunking = null; } private static void validateInputType(InputType inputType) { @@ -124,6 +137,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_TYPE, inputType); } + // TODO: Add a transport version + if (lateChunking != null) { + builder.field(LATE_CHUNKING, lateChunking); + } + builder.endObject(); return builder; } @@ -132,6 +150,11 @@ public InputType getInputType() { return inputType; } + @Override + public Boolean isLateChunkingEnabled() { + return lateChunking; + } + @Override public String getWriteableName() { return NAME; @@ -145,6 +168,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(inputType); + out.writeOptionalBoolean(lateChunking); } @Override @@ -152,12 +176,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o; - return Objects.equals(inputType, that.inputType); + return Objects.equals(inputType, that.inputType) && lateChunking == that.lateChunking; } @Override public int hashCode() { - return Objects.hash(inputType); + return Objects.hash(inputType, lateChunking); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java index 791c4af76b145..284810c22f94c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java @@ -34,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity( private static final String CLASSIFICATION = "classification"; private static final String INPUT_FIELD = "input"; private static final String MODEL_FIELD = "model"; + private static final String LATE_CHUNKING = "late_chunking"; public static final String TASK_TYPE_FIELD = "task"; static final String EMBEDDING_TYPE_FIELD = "embedding_type"; @@ -49,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_FIELD, input); builder.field(MODEL_FIELD, model); + if (taskSettings.isLateChunkingEnabled() != null) { + builder.field(LATE_CHUNKING, taskSettings.isLateChunkingEnabled()); + } + if (embeddingType != null) { builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString()); } From 0a74e1031c982c1b271b48df608999afe964feea Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Wed, 25 Jun 2025 15:10:10 -0400 Subject: [PATCH 2/3] Update docs/changelog/130038.yaml --- docs/changelog/130038.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130038.yaml diff --git a/docs/changelog/130038.yaml b/docs/changelog/130038.yaml new file mode 100644 index 0000000000000..b4a62c9ee0f99 --- /dev/null +++ b/docs/changelog/130038.yaml @@ -0,0 +1,5 @@ +pr: 130038 +summary: Adding support for JinaAI late chunking +area: Machine Learning +type: enhancement +issues: [] From f94e57459c9e295f5aaaa2055d2427b89119c110 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Fri, 27 Jun 2025 11:35:59 -0400 Subject: [PATCH 3/3] Adding unit tests --- .../chunking/EmbeddingRequestChunker.java | 15 +++--- .../services/jinaai/JinaAIService.java | 4 +- .../JinaAIEmbeddingsTaskSettings.java | 9 ++-- .../EmbeddingRequestChunkerTests.java | 34 ++++++++++++ .../JinaAIEmbeddingsTaskSettingsTests.java | 52 ++++++++++++++++--- .../JinaAIEmbeddingsRequestEntityTests.java | 34 ++++++++++++ 6 files changed, 129 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index de9b9bdca8098..58550187ff361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -88,13 +88,13 @@ public EmbeddingRequestChunker( int maxNumberOfInputsPerBatch, ChunkingSettings defaultChunkingSettings ) { - this(inputs, maxNumberOfInputsPerBatch, false, defaultChunkingSettings); + this(inputs, maxNumberOfInputsPerBatch, true, defaultChunkingSettings); } public EmbeddingRequestChunker( List inputs, int maxNumberOfInputsPerBatch, - Boolean isLateChunkingEnabled, + Boolean shouldBatchAcrossInputs, ChunkingSettings defaultChunkingSettings ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); @@ -142,19 +142,18 @@ public EmbeddingRequestChunker( } } - if (isLateChunkingEnabled != null && isLateChunkingEnabled) { - // This must be true for late chunking cases otherwise we can't pass all chunks in a single request - assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS); + if (shouldBatchAcrossInputs == null || shouldBatchAcrossInputs) { + AtomicInteger counter = new AtomicInteger(); this.batchRequests = allRequests.stream() - .collect(Collectors.groupingBy(Request::inputIndex)) + .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) .values() .stream() .map(BatchRequest::new) .toList(); } else { - AtomicInteger counter = new AtomicInteger(); + assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS); this.batchRequests = allRequests.stream() - .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) + .collect(Collectors.groupingBy(Request::inputIndex)) .values() .stream() .map(BatchRequest::new) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 724d411ea91f1..1632d2a3ff7f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -279,10 +279,12 @@ protected void doChunkedInfer( JinaAIModel jinaaiModel = (JinaAIModel) model; var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + var isLateChunkingEnabled = jinaaiModel.getTaskSettings().isLateChunkingEnabled(); + var shouldBatchAcrossInputs = isLateChunkingEnabled == null || isLateChunkingEnabled == false; List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - jinaaiModel.getTaskSettings().isLateChunkingEnabled(), + shouldBatchAcrossInputs, jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java index 11cd220895c5d..db55b9012d60c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java @@ -80,9 +80,12 @@ public static JinaAIEmbeddingsTaskSettings of( JinaAIEmbeddingsTaskSettings requestTaskSettings ) { var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings); + // TODO: Should a null late chunking override a non-null late chunking? + var shouldUseLateChunking = requestTaskSettings.lateChunking != null + ? requestTaskSettings.lateChunking + : originalSettings.lateChunking; - return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, requestTaskSettings.lateChunking); - // TODO: Check the above + return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, shouldUseLateChunking); } private static InputType getValidInputType( @@ -176,7 +179,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o; - return Objects.equals(inputType, that.inputType) && lateChunking == that.lateChunking; + return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 411d992adfa3d..f263d7c060cfb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -866,6 +866,40 @@ public void testMergingListener_Sparse() { } } + public void testShouldBatchAcrossInputsIsFalse_DoesNotBatchRequestsFromSeparateInputs() { + int batchSize = 512; + + var testSentence = "This is a test sentence with ten words in total. "; + + List inputs = List.of( + new ChunkInferenceInput(testSentence + testSentence + testSentence), + new ChunkInferenceInput(testSentence), + new ChunkInferenceInput(testSentence + testSentence + testSentence + testSentence) + ); + + var chunkingSettings = new SentenceBoundaryChunkingSettings(10, 0); + + var finalListener = testListener(); + List batches = new EmbeddingRequestChunker<>( + inputs, + batchSize, + false, + chunkingSettings + ).batchRequestsWithListeners(finalListener); + + assertThat(batches, hasSize(3)); + var expectedBatchSizes = List.of(3, 1, 4); + for (int i = 0; i < batches.size(); i++) { + assertThat(batches.get(i).batch().inputs().get(), hasSize(expectedBatchSizes.get(i))); + batches.get(i) + .listener() + .onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1f })))); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(3)); + } + public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of( new ChunkInferenceInput("1st small"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java index 9235d26c87d87..8c383d0dd4975 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java @@ -74,11 +74,20 @@ public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { } public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var isLateChunkingEnabled = randomBoolean(); MatcherAssert.assertThat( JinaAIEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString())) + new HashMap<>( + Map.of( + JinaAIEmbeddingsTaskSettings.INPUT_TYPE, + inputType.toString(), + JinaAIEmbeddingsTaskSettings.LATE_CHUNKING, + isLateChunkingEnabled + ) + ) ), - is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)) + is(new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled)) ); } @@ -131,16 +140,45 @@ public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { } public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull() { - var taskSettings = new JinaAIEmbeddingsTaskSettings(InputType.INGEST); + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var isLateChunkingEnabled = randomBoolean(); + var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled); var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(taskSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS); MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); } - public void testOf_UsesRequestTaskSettings() { - var taskSettings = new JinaAIEmbeddingsTaskSettings((InputType) null); - var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of(taskSettings, new JinaAIEmbeddingsTaskSettings(InputType.INGEST)); + public void testOf_UsesRequestTaskSettingsWhenSettingsAreNull() { + var taskSettings = new JinaAIEmbeddingsTaskSettings(null, null); + + var overriddenInputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var overriddenIsLateChunkingEnabled = randomBoolean(); + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled) + ); + + MatcherAssert.assertThat( + overriddenTaskSettings, + is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)) + ); + } + + public void testOf_UsesRequestTaskSettingsWhenSettingsAreNotNull() { + var inputType = randomFrom(VALID_INPUT_TYPE_VALUES); + var isLateChunkingEnabled = randomBoolean(); + var taskSettings = new JinaAIEmbeddingsTaskSettings(inputType, isLateChunkingEnabled); + + var overriddenInputType = randomValueOtherThan(inputType, () -> randomFrom(VALID_INPUT_TYPE_VALUES)); + var overriddenIsLateChunkingEnabled = isLateChunkingEnabled == false; + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled) + ); - MatcherAssert.assertThat(overriddenTaskSettings, is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat( + overriddenTaskSettings, + is(new JinaAIEmbeddingsTaskSettings(overriddenInputType, overriddenIsLateChunkingEnabled)) + ); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java index 733dcc69c278d..7dd990fde9e46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java @@ -24,6 +24,23 @@ public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + InputType.INTERNAL_INGEST, + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), + "model", + JinaAIEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","late_chunking":true,"embedding_type":"float","task":"retrieval.passage"}""")); + } + + public void testXContent_WritesOnlyInputTypeField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( List.of("abc"), InputType.INTERNAL_INGEST, @@ -40,6 +57,23 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {"input":["abc"],"model":"model","embedding_type":"float","task":"retrieval.passage"}""")); } + public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + InputType.INTERNAL_INGEST, + new JinaAIEmbeddingsTaskSettings(null, false), + "model", + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","late_chunking":false,"task":"retrieval.passage"}""")); + } + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( List.of("abc"),