Skip to content

Commit 5463b41

Browse files
committed
MLE-26953 Chunks now capture model name
This is using embeddingModel.getModelName() in the LangChain4j API, and it will allow for Nuclia integration to easily add the model name found in each chunk response.
1 parent d548ee4 commit 5463b41

File tree

14 files changed

+108
-47
lines changed

14 files changed

+108
-47
lines changed

marklogic-spark-connector/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingGenerator.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.langchain4j.embedding;
55

@@ -29,6 +29,7 @@ public class EmbeddingGenerator implements EmbeddingProducer {
2929
private final EmbeddingModel embeddingModel;
3030
private final int batchSize;
3131
private final String prompt;
32+
private final String modelName;
3233

3334
// Only used for debug logging.
3435
private static final AtomicLong tokenCount = new AtomicLong(0);
@@ -38,6 +39,7 @@ public EmbeddingGenerator(EmbeddingModel embeddingModel, int batchSize, String p
3839
this.embeddingModel = embeddingModel;
3940
this.batchSize = batchSize;
4041
this.prompt = prompt;
42+
this.modelName = embeddingModel.modelName();
4143
}
4244

4345
@Override
@@ -62,7 +64,7 @@ private int generateAndAddEmbeddings(List<TextSegment> segments, List<Chunk> chu
6264
List<Embedding> embeddings = generateEmbeddings(segments);
6365
for (int i = 0; i < embeddings.size(); i++) {
6466
Embedding embedding = embeddings.get(i);
65-
chunks.get(chunkCounter).addEmbedding(embedding.vector());
67+
chunks.get(chunkCounter).addEmbedding(embedding.vector(), modelName);
6668
chunkCounter++;
6769
}
6870
return chunkCounter;

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/ChunkInputs.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class ChunkInputs {
1212
private final String text;
1313
private float[] embedding;
1414
private byte[] classification;
15+
private String modelName;
1516

1617
public ChunkInputs(String text) {
1718
this.text = text;
@@ -36,4 +37,12 @@ public byte[] getClassification() {
3637
public void setClassification(byte[] classification) {
3738
this.classification = classification;
3839
}
40+
41+
public String getModelName() {
42+
return modelName;
43+
}
44+
45+
public void setModelName(String modelName) {
46+
this.modelName = modelName;
47+
}
3948
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,15 @@ public void addChunkClassification(byte[] classification) {
8787
throw new IllegalStateException("Cannot add classification: all chunks already have classifications");
8888
}
8989

90-
public void addEmbedding(float[] embedding) {
90+
public void addEmbedding(float[] embedding, String modelName) {
9191
if (chunkInputsList == null || chunkInputsList.isEmpty()) {
9292
throw new IllegalStateException("Cannot add embedding: no chunks exist");
9393
}
9494
// Find the next chunk without an embedding
9595
for (ChunkInputs chunk : chunkInputsList) {
9696
if (chunk.getEmbedding() == null) {
9797
chunk.setEmbedding(embedding);
98+
chunk.setModelName(modelName);
9899
return;
99100
}
100101
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipeline.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core;
55

@@ -119,8 +119,8 @@ public String getEmbeddingText() {
119119
}
120120

121121
@Override
122-
public void addEmbedding(float[] embedding) {
123-
inputs.addEmbedding(embedding);
122+
public void addEmbedding(float[] embedding, String modelName) {
123+
inputs.addEmbedding(embedding, modelName);
124124
}
125125
}
126126

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/embedding/Chunk.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.embedding;
55

@@ -17,6 +17,7 @@ public interface Chunk {
1717
* Add the vector data in the given embedding to the chunk.
1818
*
1919
* @param embedding
20+
* @param modelName
2021
*/
21-
void addEmbedding(float[] embedding);
22+
void addEmbedding(float[] embedding, String modelName);
2223
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/embedding/DOMChunk.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.embedding;
55

@@ -56,7 +56,13 @@ public String getEmbeddingText() {
5656
}
5757

5858
@Override
59-
public void addEmbedding(float[] embedding) {
59+
public void addEmbedding(float[] embedding, String modelName) {
60+
if (modelName != null) {
61+
final Element modelNameElement = document.createElementNS(xmlChunkConfig.getEmbeddingNamespace(), "model-name");
62+
modelNameElement.setTextContent(modelName);
63+
chunkElement.appendChild(modelNameElement);
64+
}
65+
6066
// DOM is fine with null as a value for the namespace.
6167
final Element embeddingElement = document.createElementNS(xmlChunkConfig.getEmbeddingNamespace(), xmlChunkConfig.getEmbeddingName());
6268

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/embedding/JsonChunk.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
33
*/
44
package com.marklogic.spark.core.embedding;
55

@@ -34,7 +34,11 @@ public String getEmbeddingText() {
3434
}
3535

3636
@Override
37-
public void addEmbedding(float[] embedding) {
37+
public void addEmbedding(float[] embedding, String modelName) {
38+
if (modelName != null) {
39+
chunk.put("model-name", modelName);
40+
}
41+
3842
if (base64EncodeVectors) {
3943
String base64Vector = VectorUtil.base64Encode(embedding);
4044
chunk.put(this.embeddingArrayName, base64Vector);

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/splitter/JsonChunkDocumentProducer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ protected DocumentWriteOperation addChunksToSourceDocument() {
5656
}
5757
var jsonChunk = new JsonChunk(chunk, null, chunkConfig.getEmbeddingName(), chunkConfig.isBase64EncodeVectors());
5858
if (chunkInputs.getEmbedding() != null) {
59-
jsonChunk.addEmbedding(chunkInputs.getEmbedding());
59+
jsonChunk.addEmbedding(chunkInputs.getEmbedding(), chunkInputs.getModelName());
6060
}
6161
chunks.add(jsonChunk);
6262
}
@@ -93,7 +93,7 @@ protected DocumentWriteOperation makeChunkDocument() {
9393
}
9494
var jsonChunk = new JsonChunk(chunk, null, chunkConfig.getEmbeddingName(), chunkConfig.isBase64EncodeVectors());
9595
if (chunkInputs.getEmbedding() != null) {
96-
jsonChunk.addEmbedding(chunkInputs.getEmbedding());
96+
jsonChunk.addEmbedding(chunkInputs.getEmbedding(), chunkInputs.getModelName());
9797
}
9898
chunks.add(jsonChunk);
9999
listIndex++;

marklogic-spark-connector/src/main/java/com/marklogic/spark/core/splitter/XmlChunkDocumentProducer.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ protected DocumentWriteOperation makeChunkDocument() {
7070
ChunkInputs chunkInputs = chunkInputsList.get(listIndex);
7171
Element classificationResponseNode = chunkInputs.getClassification() != null ?
7272
getClassificationResponseElement(chunkInputs.getClassification()) : null;
73-
addChunk(doc, chunkInputs.getText(), chunksElement, chunks, classificationResponseNode, chunkInputs.getEmbedding());
73+
addChunk(doc, chunkInputs.getText(), chunksElement, chunks, classificationResponseNode, chunkInputs.getEmbedding(), chunkInputs.getModelName());
7474
listIndex++;
7575
}
7676

@@ -91,7 +91,7 @@ protected DocumentWriteOperation addChunksToSourceDocument() {
9191
for (ChunkInputs chunkInputs : chunkInputsList) {
9292
Element classificationResponseNode = chunkInputs.getClassification() != null ?
9393
getClassificationResponseElement(chunkInputs.getClassification()) : null;
94-
addChunk(doc, chunkInputs.getText(), chunksElement, chunks, classificationResponseNode, chunkInputs.getEmbedding());
94+
addChunk(doc, chunkInputs.getText(), chunksElement, chunks, classificationResponseNode, chunkInputs.getEmbedding(), chunkInputs.getModelName());
9595
}
9696

9797
return new DocumentAndChunks(
@@ -110,7 +110,7 @@ private Element getClassificationResponseElement(byte[] classificationBytes) {
110110
}
111111
}
112112

113-
private void addChunk(Document doc, String textSegment, Element chunksElement, List<Chunk> chunks, Element classificationResponse, float[] embedding) {
113+
private void addChunk(Document doc, String textSegment, Element chunksElement, List<Chunk> chunks, Element classificationResponse, float[] embedding, String modelName) {
114114
Element chunk = doc.createElementNS(chunkConfig.getXmlNamespace(), "chunk");
115115
chunksElement.appendChild(chunk);
116116

@@ -129,7 +129,7 @@ private void addChunk(Document doc, String textSegment, Element chunksElement, L
129129

130130
var domChunk = new DOMChunk(doc, chunk, this.xmlChunkConfig, this.xPathFactory);
131131
if (embedding != null) {
132-
domChunk.addEmbedding(embedding);
132+
domChunk.addEmbedding(embedding, modelName);
133133
}
134134
chunks.add(domChunk);
135135
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import com.marklogic.spark.AbstractIntegrationTest;
7+
import org.junit.jupiter.api.AfterEach;
8+
9+
abstract class AbstractEmbeddingTest extends AbstractIntegrationTest {
10+
11+
static final String TEST_EMBEDDING_FUNCTION_CLASS = "com.marklogic.spark.writer.embedding.MinilmEmbeddingModelFunction";
12+
13+
// The minilm embedding model returns "unknown" as its model name.
14+
static final String EXPECTED_MODEL_NAME = "unknown";
15+
16+
@AfterEach
17+
void teardown() {
18+
TestEmbeddingModel.reset();
19+
}
20+
}

0 commit comments

Comments
 (0)