Skip to content

Commit d548ee4

Browse files
committed
MLE-26953 Refactoring: DocumentInputs now has List<ChunkInputs>
This combines the 3 separate lists of text, classifications, and embeddings. Will make it much easier to add a model name. No change in functionality, just moving things around in the implementation.
1 parent f4a4cde commit d548ee4

File tree

9 files changed

+135
-164
lines changed

9 files changed

+135
-164
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.core;
5+
6+
/**
 * Holds the inputs associated with a single chunk of text produced by a splitter — the text itself
 * plus an optional embedding and an optional classification. Note there's some naming issues to
 * work out with this class and the Chunk interface.
 */
public class ChunkInputs {

    // The chunk text is fixed at construction time; embedding and classification are
    // attached later as downstream processors (embedder, classifier) produce them.
    private final String text;
    private float[] embedding;
    private byte[] classification;

    /**
     * @param text the chunk's text, as produced by a splitter
     */
    public ChunkInputs(String text) {
        this.text = text;
    }

    /**
     * @return the chunk text provided at construction time
     */
    public String getText() {
        return text;
    }

    /**
     * @return the embedding for this chunk, or null if none has been set yet
     */
    public float[] getEmbedding() {
        return embedding;
    }

    /**
     * @param embedding the embedding vector to associate with this chunk
     */
    public void setEmbedding(float[] embedding) {
        this.embedding = embedding;
    }

    /**
     * @return the classification for this chunk, or null if none has been set yet
     */
    public byte[] getClassification() {
        return classification;
    }

    /**
     * @param classification the classification bytes to associate with this chunk
     */
    public void setClassification(byte[] classification) {
        this.classification = classification;
    }
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core;
55

@@ -30,11 +30,7 @@ public class DocumentInputs {
3030
private Map<String, String> extractedMetadata;
3131

3232
private byte[] documentClassification;
33-
private List<byte[]> chunkClassifications;
34-
private List<float[]> embeddings;
35-
36-
// These will be created via a splitter.
37-
private List<String> chunks;
33+
private List<ChunkInputs> chunkInputsList;
3834

3935
public DocumentInputs(String initialUri, AbstractWriteHandle content, JsonNode columnValuesForUriTemplate,
4036
DocumentMetadataHandle initialMetadata) {
@@ -78,17 +74,31 @@ public AbstractWriteHandle getContent() {
7874
}
7975

8076
public void addChunkClassification(byte[] classification) {
81-
if (chunkClassifications == null) {
82-
chunkClassifications = new ArrayList<>();
77+
if (chunkInputsList == null || chunkInputsList.isEmpty()) {
78+
throw new IllegalStateException("Cannot add classification: no chunks exist");
79+
}
80+
// Find the next chunk without a classification
81+
for (ChunkInputs chunk : chunkInputsList) {
82+
if (chunk.getClassification() == null) {
83+
chunk.setClassification(classification);
84+
return;
85+
}
8386
}
84-
chunkClassifications.add(classification);
87+
throw new IllegalStateException("Cannot add classification: all chunks already have classifications");
8588
}
8689

8790
public void addEmbedding(float[] embedding) {
88-
if (embeddings == null) {
89-
embeddings = new ArrayList<>();
91+
if (chunkInputsList == null || chunkInputsList.isEmpty()) {
92+
throw new IllegalStateException("Cannot add embedding: no chunks exist");
9093
}
91-
embeddings.add(embedding);
94+
// Find the next chunk without an embedding
95+
for (ChunkInputs chunk : chunkInputsList) {
96+
if (chunk.getEmbedding() == null) {
97+
chunk.setEmbedding(embedding);
98+
return;
99+
}
100+
}
101+
throw new IllegalStateException("Cannot add embedding: all chunks already have embeddings");
92102
}
93103

94104
public String getInitialUri() {
@@ -124,15 +134,25 @@ public void setExtractedMetadata(Map<String, String> extractedMetadata) {
124134
}
125135

126136
public List<String> getChunks() {
127-
return chunks;
137+
if (chunkInputsList == null) {
138+
return null;
139+
}
140+
List<String> texts = new ArrayList<>(chunkInputsList.size());
141+
for (ChunkInputs chunk : chunkInputsList) {
142+
texts.add(chunk.getText());
143+
}
144+
return texts;
128145
}
129146

130147
public void setChunks(List<String> chunks) {
131-
this.chunks = chunks;
132-
}
133-
134-
public List<byte[]> getClassifications() {
135-
return chunkClassifications;
148+
if (chunks == null) {
149+
this.chunkInputsList = null;
150+
} else {
151+
this.chunkInputsList = new ArrayList<>(chunks.size());
152+
for (String text : chunks) {
153+
this.chunkInputsList.add(new ChunkInputs(text));
154+
}
155+
}
136156
}
137157

138158
public byte[] getDocumentClassification() {
@@ -143,7 +163,7 @@ public void setDocumentClassification(byte[] documentClassification) {
143163
this.documentClassification = documentClassification;
144164
}
145165

146-
public List<float[]> getEmbeddings() {
147-
return embeddings;
166+
public List<ChunkInputs> getChunkInputsList() {
167+
return chunkInputsList;
148168
}
149169
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/splitter/AbstractChunkDocumentProducer.java

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.splitter;
55

66
import com.marklogic.client.document.DocumentWriteOperation;
77
import com.marklogic.client.io.Format;
8+
import com.marklogic.spark.core.ChunkInputs;
89

910
import java.util.Iterator;
1011
import java.util.List;
@@ -16,27 +17,23 @@
1617
abstract class AbstractChunkDocumentProducer implements Iterator<DocumentWriteOperation> {
1718

1819
protected final DocumentWriteOperation sourceDocument;
19-
protected final List<String> textSegments;
20+
protected final List<ChunkInputs> chunkInputsList;
2021
protected final ChunkConfig chunkConfig;
21-
protected final List<byte[]> classifications;
22-
protected final List<float[]> embeddings;
2322
protected final int maxChunksPerDocument;
2423

2524
protected int listIndex = -1;
2625
private int chunkDocumentCounter = 1;
2726

28-
AbstractChunkDocumentProducer(DocumentWriteOperation sourceDocument, Format sourceDocumentFormat, List<String> textSegments, ChunkConfig chunkConfig, List<byte[]> classifications, List<float[]> embeddings) {
27+
AbstractChunkDocumentProducer(DocumentWriteOperation sourceDocument, Format sourceDocumentFormat, List<ChunkInputs> chunkInputsList, ChunkConfig chunkConfig) {
2928
this.sourceDocument = sourceDocument;
30-
this.textSegments = textSegments;
29+
this.chunkInputsList = chunkInputsList;
3130
this.chunkConfig = chunkConfig;
32-
this.classifications = classifications;
33-
this.embeddings = embeddings;
3431

3532
// Chunks cannot be written to the source document unless its format is JSON or XML. So if maxChunks is zero and
3633
// we don't have a JSON or XML document, all chunks will be written to a separate document.
3734
boolean cannotAddChunksToSourceDocument = !Format.JSON.equals(sourceDocumentFormat) && !Format.XML.equals(sourceDocumentFormat);
3835
this.maxChunksPerDocument = cannotAddChunksToSourceDocument && chunkConfig.getMaxChunks() == 0 ?
39-
textSegments.size() :
36+
chunkInputsList.size() :
4037
chunkConfig.getMaxChunks();
4138
}
4239

@@ -47,7 +44,7 @@ abstract class AbstractChunkDocumentProducer implements Iterator<DocumentWriteOp
4744

4845
@Override
4946
public final boolean hasNext() {
50-
return listIndex < textSegments.size();
47+
return listIndex < chunkInputsList.size();
5148
}
5249

5350
// Sonar complains that a NoSuchElementException should be thrown here, but that would only occur if the
@@ -58,7 +55,7 @@ public DocumentWriteOperation next() {
5855
if (listIndex == -1) {
5956
listIndex++;
6057
if (this.maxChunksPerDocument == 0) {
61-
listIndex = textSegments.size();
58+
listIndex = chunkInputsList.size();
6259
return addChunksToSourceDocument();
6360
}
6461
return sourceDocument;
@@ -83,15 +80,4 @@ protected final String makeChunkDocumentUri(DocumentWriteOperation sourceDocumen
8380
}
8481
return uri;
8582
}
86-
87-
/**
88-
* Return the embedding at position n if it exists.
89-
* @param embeddings the embeddings list
90-
* @param n the position for the embedding requests
91-
* @return the embedding float array
92-
*/
93-
protected float[] getEmbeddingIfExists(List<float[]> embeddings, int n) {
94-
return (embeddings != null && n < embeddings.size() ? embeddings.get(n) : null);
95-
}
96-
9783
}
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.splitter;
55

66
import com.marklogic.client.document.DocumentWriteOperation;
7+
import com.marklogic.spark.core.ChunkInputs;
78

89
import java.util.Iterator;
910
import java.util.List;
@@ -15,11 +16,8 @@ public interface ChunkAssembler {
1516

1617
/**
1718
* @param sourceDocument
18-
* @param chunks
19-
* @param classifications
20-
* @param embeddings
19+
* @param chunkInputsList
2120
* @return an iterator, which allows for an implementation to lazily construct documents if necessary.
2221
*/
23-
Iterator<DocumentWriteOperation> assembleChunks(DocumentWriteOperation sourceDocument, List<String> chunks,
24-
List<byte[]> classifications, List<float[]> embeddings);
22+
Iterator<DocumentWriteOperation> assembleChunks(DocumentWriteOperation sourceDocument, List<ChunkInputs> chunkInputsList);
2523
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/splitter/DefaultChunkAssembler.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.splitter;
55

66
import com.marklogic.client.document.DocumentWriteOperation;
77
import com.marklogic.client.io.Format;
88
import com.marklogic.spark.Util;
9+
import com.marklogic.spark.core.ChunkInputs;
910

1011
import java.util.Iterator;
1112
import java.util.List;
@@ -20,7 +21,7 @@ public DefaultChunkAssembler(ChunkConfig chunkConfig) {
2021
}
2122

2223
@Override
23-
public Iterator<DocumentWriteOperation> assembleChunks(DocumentWriteOperation sourceDocument, List<String> textSegments, List<byte[]> classifications, List<float[]> embeddings) {
24+
public Iterator<DocumentWriteOperation> assembleChunks(DocumentWriteOperation sourceDocument, List<ChunkInputs> chunkInputsList) {
2425
final Format sourceDocumentFormat = Util.determineSourceDocumentFormat(sourceDocument.getContent(), sourceDocument.getUri());
2526
if (sourceDocumentFormat == null) {
2627
Util.MAIN_LOGGER.warn("Cannot split document with URI {}; cannot determine the document format.", sourceDocument.getUri());
@@ -30,8 +31,8 @@ public Iterator<DocumentWriteOperation> assembleChunks(DocumentWriteOperation so
3031
final Format chunkDocumentFormat = determineChunkDocumentFormat(sourceDocumentFormat);
3132

3233
return Format.XML.equals(chunkDocumentFormat) ?
33-
new XmlChunkDocumentProducer(sourceDocument, sourceDocumentFormat, textSegments, chunkConfig, classifications, embeddings) :
34-
new JsonChunkDocumentProducer(sourceDocument, sourceDocumentFormat, textSegments, chunkConfig, classifications, embeddings);
34+
new XmlChunkDocumentProducer(sourceDocument, sourceDocumentFormat, chunkInputsList, chunkConfig) :
35+
new JsonChunkDocumentProducer(sourceDocument, sourceDocumentFormat, chunkInputsList, chunkConfig);
3536
}
3637

3738
private Format determineChunkDocumentFormat(Format sourceDocumentFormat) {

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/splitter/JsonChunkDocumentProducer.java

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.splitter;
55

@@ -14,6 +14,7 @@
1414
import com.marklogic.client.io.JacksonHandle;
1515
import com.marklogic.client.io.marker.AbstractWriteHandle;
1616
import com.marklogic.spark.ConnectorException;
17+
import com.marklogic.spark.core.ChunkInputs;
1718
import com.marklogic.spark.core.embedding.Chunk;
1819
import com.marklogic.spark.core.embedding.DocumentAndChunks;
1920
import com.marklogic.spark.core.embedding.JsonChunk;
@@ -30,8 +31,8 @@ class JsonChunkDocumentProducer extends AbstractChunkDocumentProducer {
3031
private final XmlMapper xmlMapper;
3132

3233
JsonChunkDocumentProducer(DocumentWriteOperation sourceDocument, Format sourceDocumentFormat,
33-
List<String> textSegments, ChunkConfig chunkConfig, List<byte[]> classifications, List<float[]> embeddings) {
34-
super(sourceDocument, sourceDocumentFormat, textSegments, chunkConfig, classifications, embeddings);
34+
List<ChunkInputs> chunkInputsList, ChunkConfig chunkConfig) {
35+
super(sourceDocument, sourceDocumentFormat, chunkInputsList, chunkConfig);
3536
xmlMapper = new XmlMapper();
3637
}
3738

@@ -42,24 +43,21 @@ protected DocumentWriteOperation addChunksToSourceDocument() {
4243

4344
ArrayNode chunksArray = doc.putArray(determineChunksArrayName(doc));
4445
List<Chunk> chunks = new ArrayList<>();
45-
int chunksCounter = 0;
46-
for (String text : textSegments) {
46+
for (ChunkInputs chunkInputs : chunkInputsList) {
4747
ObjectNode chunk = chunksArray.addObject();
48-
chunk.put("text", text);
49-
if (classifications != null && classifications.size() > chunksCounter) {
48+
chunk.put("text", chunkInputs.getText());
49+
if (chunkInputs.getClassification() != null) {
5050
try {
51-
JsonNode classification = xmlMapper.readTree(classifications.get(chunksCounter));
51+
JsonNode classification = xmlMapper.readTree(chunkInputs.getClassification());
5252
chunk.set("classification", classification);
5353
} catch (IOException e) {
5454
throw new ConnectorException(String.format("Unable to classify data from document with URI: %s; cause: %s", sourceDocument.getUri(), e.getMessage()), e);
5555
}
5656
}
57-
float[] embedding = getEmbeddingIfExists(embeddings, chunksCounter);
5857
var jsonChunk = new JsonChunk(chunk, null, chunkConfig.getEmbeddingName(), chunkConfig.isBase64EncodeVectors());
59-
if (embedding != null) {
60-
jsonChunk.addEmbedding(embedding);
58+
if (chunkInputs.getEmbedding() != null) {
59+
jsonChunk.addEmbedding(chunkInputs.getEmbedding());
6160
}
62-
chunksCounter++;
6361
chunks.add(jsonChunk);
6462
}
6563

@@ -81,26 +79,23 @@ protected DocumentWriteOperation makeChunkDocument() {
8179

8280
ArrayNode chunksArray = rootField.putArray(DEFAULT_CHUNKS_ARRAY_NAME);
8381
List<Chunk> chunks = new ArrayList<>();
84-
int chunksCounter = 0;
8582
for (int i = 0; i < this.maxChunksPerDocument && hasNext(); i++) {
86-
String text = textSegments.get(listIndex);
83+
ChunkInputs chunkInputs = chunkInputsList.get(listIndex);
8784
ObjectNode chunk = chunksArray.addObject();
88-
chunk.put("text", text);
89-
if (classifications != null && classifications.size() > chunksCounter) {
85+
chunk.put("text", chunkInputs.getText());
86+
if (chunkInputs.getClassification() != null) {
9087
try {
91-
JsonNode classification = xmlMapper.readTree(classifications.get(chunksCounter));
88+
JsonNode classification = xmlMapper.readTree(chunkInputs.getClassification());
9289
chunk.set("classification", classification);
9390
} catch (IOException e) {
9491
throw new ConnectorException(String.format("Unable to classify data from document with URI: %s; cause: %s", uri, e.getMessage()), e);
9592
}
9693
}
97-
float[] embedding = getEmbeddingIfExists(embeddings, listIndex);
9894
var jsonChunk = new JsonChunk(chunk, null, chunkConfig.getEmbeddingName(), chunkConfig.isBase64EncodeVectors());
99-
if (embedding != null) {
100-
jsonChunk.addEmbedding(embedding);
95+
if (chunkInputs.getEmbedding() != null) {
96+
jsonChunk.addEmbedding(chunkInputs.getEmbedding());
10197
}
10298
chunks.add(jsonChunk);
103-
chunksCounter++;
10499
listIndex++;
105100
}
106101

0 commit comments

Comments
 (0)