Skip to content

Commit 065c581

Browse files
authored
Add the ability to store the KnowledgeGraphWriter results in vector indexes via Neo4jEmbeddingStore (#217)
1 parent 56e42d8 commit 065c581

File tree

4 files changed

+203
-17
lines changed

4 files changed

+203
-17
lines changed

content-retrievers/langchain4j-community-neo4j-retriever/src/main/java/dev/langchain4j/community/rag/content/retriever/neo4j/KnowledgeGraphWriter.java

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77

88
import dev.langchain4j.Experimental;
99
import dev.langchain4j.community.data.document.graph.GraphDocument;
10+
import dev.langchain4j.community.data.document.graph.GraphNode;
11+
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStore;
1012
import dev.langchain4j.data.document.Document;
13+
import dev.langchain4j.data.document.Metadata;
14+
import dev.langchain4j.data.embedding.Embedding;
15+
import dev.langchain4j.data.segment.TextSegment;
16+
import dev.langchain4j.model.embedding.EmbeddingModel;
17+
import java.util.ArrayList;
1118
import java.util.HashMap;
1219
import java.util.List;
1320
import java.util.Map;
@@ -34,23 +41,38 @@ public class KnowledgeGraphWriter {
3441
final String sanitizedTextProperty;
3542

3643
private final Neo4jGraph graph;
44+
private final Neo4jEmbeddingStore embeddingStore;
45+
private EmbeddingModel embeddingModel = null;
3746

3847
public KnowledgeGraphWriter(
3948
Neo4jGraph graph,
4049
String idProperty,
4150
String label,
4251
String textProperty,
4352
String relType,
44-
String constraintName) {
53+
String constraintName,
54+
Neo4jEmbeddingStore embeddingStore,
55+
EmbeddingModel embeddingModel) {
4556
this.graph = ensureNotNull(graph, "graph");
57+
58+
this.embeddingStore = embeddingStore;
59+
final boolean storeIsNull = this.embeddingStore == null;
60+
if (!storeIsNull) {
61+
this.embeddingModel = ensureNotNull(embeddingModel, "embeddingModel");
62+
}
63+
4664
this.label = getOrDefault(label, DEFAULT_LABEL);
4765
this.relType = getOrDefault(relType, DEFAULT_REL_TYPE);
66+
4867
this.idProperty = getOrDefault(idProperty, DEFAULT_ID_PROP);
4968
this.textProperty = getOrDefault(textProperty, DEFAULT_TEXT_PROP);
5069
this.constraintName = getOrDefault(constraintName, DEFAULT_CONS_NAME);
5170

5271
/* sanitize labels and property names to prevent Cypher injection */
53-
this.sanitizedLabel = sanitizeOrThrows(this.label, "label");
72+
73+
// if an embeddingStore is provided, the label is taken from its getSanitizedLabel() instead
74+
this.sanitizedLabel =
75+
storeIsNull ? sanitizeOrThrows(this.label, "label") : this.embeddingStore.getSanitizedLabel();
5476
this.sanitizedRelType = sanitizeOrThrows(this.relType, "relType");
5577
this.sanitizedIdProperty = sanitizeOrThrows(this.idProperty, "idProperty");
5678
this.sanitizedTextProperty = sanitizeOrThrows(this.textProperty, "textProperty");
@@ -78,9 +100,10 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
78100

79101
// Import nodes
80102
Map<String, Object> nodeParams = new HashMap<>();
81-
nodeParams.put(
82-
"data", graphDoc.nodes().stream().map(Neo4jUtils::toMap).toList());
83-
103+
if (embeddingStore == null) {
104+
nodeParams.put(
105+
"rows", graphDoc.nodes().stream().map(Neo4jUtils::toMap).toList());
106+
}
84107
if (includeSource) {
85108
// create a copyOf metadata, not to update existing graphDoc,
86109
// subsequent tests could potentially fail
@@ -93,8 +116,7 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
93116
nodeParams.put("document", document);
94117
}
95118

96-
String nodeImportQuery = getNodeImportQuery(includeSource);
97-
graph.executeWrite(nodeImportQuery, nodeParams);
119+
insertNodes(includeSource, graphDoc, nodeParams);
98120

99121
// Import relationships
100122
List<Map<String, String>> relData = graphDoc.relationships().stream()
@@ -105,23 +127,60 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
105127
"target_label", rel.targetNode().type(),
106128
"type", rel.type().replace(" ", "_").toUpperCase()))
107129
.toList();
108-
109130
String relImportQuery = getRelImportQuery();
110131
graph.executeWrite(relImportQuery, Map.of("data", relData));
111132
}
112133
}
113134

135+
private void insertNodes(boolean includeSource, GraphDocument graphDoc, Map<String, Object> nodeParams) {
136+
if (embeddingStore == null) {
137+
String nodeImportQuery = getNodeImportQuery(includeSource);
138+
graph.executeWrite(nodeImportQuery, nodeParams);
139+
return;
140+
}
141+
142+
if (includeSource) {
143+
final String creationQuery = mergeSourceWithDocs(true)
144+
+ """
145+
SET source += row.%3$s
146+
WITH row, source
147+
CALL db.create.setNodeVectorProperty(source, $embeddingProperty, row.%4$s)
148+
RETURN count(*)""";
149+
embeddingStore.setEntityCreationQuery(creationQuery);
150+
embeddingStore.setAdditionalParams(nodeParams);
151+
}
152+
153+
// we save the ids, otherwise it create UUID properties and the merge with import relationships doesn't work
154+
List<String> ids = new ArrayList<>();
155+
List<TextSegment> segments = new ArrayList<>();
156+
for (GraphNode node : graphDoc.nodes()) {
157+
final Map<String, String> properties = new HashMap<>(node.properties());
158+
properties.put("type", node.type());
159+
final String id = node.id();
160+
final TextSegment segment = TextSegment.from(id, Metadata.from(properties));
161+
ids.add(id);
162+
segments.add(segment);
163+
}
164+
165+
final List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
166+
167+
this.embeddingStore.addAll(ids, embeddings, segments);
168+
}
169+
114170
private String getNodeImportQuery(boolean includeSource) {
115171

172+
return mergeSourceWithDocs(includeSource) + "WITH source, row "
173+
+ "SET source:$(row.type) "
174+
+ "RETURN count(*) as total";
175+
}
176+
177+
private String mergeSourceWithDocs(boolean includeSource) {
116178
String includeDocsQuery = getIncludeDocsQuery(includeSource);
117179
final String withDocsRel = includeSource ? String.format("MERGE (d)-[:%s]->(source) ", relType) : "";
118180

119-
return includeDocsQuery + "UNWIND $data AS row "
120-
+ String.format("MERGE (source:%1$s {%2$s: row.id}) ", sanitizedLabel, sanitizedIdProperty)
121-
+ withDocsRel
122-
+ "WITH source, row "
123-
+ "SET source:$(row.type) "
124-
+ "RETURN count(*) as total";
181+
return includeDocsQuery + "UNWIND $rows AS row \n"
182+
+ String.format("MERGE (source:%1$s {%2$s: row.id}) \n", sanitizedLabel, sanitizedIdProperty)
183+
+ withDocsRel;
125184
}
126185

127186
private String getIncludeDocsQuery(boolean includeSource) {
@@ -160,6 +219,8 @@ public static class Builder {
160219
private String relType;
161220
private String constraintName;
162221
private Neo4jGraph graph;
222+
private Neo4jEmbeddingStore embeddingStore;
223+
private EmbeddingModel embeddingModel;
163224

164225
/**
165226
* @param graph the {@link Neo4jGraph} (required)
@@ -212,8 +273,30 @@ public Builder constraintName(String constraintName) {
212273
return this;
213274
}
214275

276+
/**
277+
* Sets the optional embedding store. When one is supplied, the writer embeds every graph node and
278+
* stores it through the store's {@code addAll(ids, embeddings, segments)} instead of the plain node import query.
* The writer then also uses the store's sanitized label rather than its own {@code label} setting.
279+
*
280+
* @param embeddingStore the {@link Neo4jEmbeddingStore} instance to store vector embeddings (optional);
* if non-null, an embedding model must be configured as well
* @return this builder
281+
*/
282+
public Builder embeddingStore(Neo4jEmbeddingStore embeddingStore) {
283+
this.embeddingStore = embeddingStore;
284+
return this;
285+
}
286+
287+
/**
288+
* Sets the embedding model used to embed the graph nodes. It is required when an
* {@code embeddingStore} is provided and is not read by the writer otherwise.
289+
*
290+
* @param embeddingModel the {@link EmbeddingModel} used to generate embeddings
* @return this builder
291+
*/
292+
public Builder embeddingModel(EmbeddingModel embeddingModel) {
293+
this.embeddingModel = embeddingModel;
294+
return this;
295+
}
296+
215297
/**
 * Creates the {@link KnowledgeGraphWriter} from the values collected by this builder.
 * The constructor rejects a null {@code graph}, and a null {@code embeddingModel} when an
 * {@code embeddingStore} has been set.
 *
 * @return a new writer instance
 */
public KnowledgeGraphWriter build() {
216-
return new KnowledgeGraphWriter(graph, idProperty, label, textProperty, relType, constraintName);
298+
return new KnowledgeGraphWriter(
299+
graph, idProperty, label, textProperty, relType, constraintName, embeddingStore, embeddingModel);
217300
}
218301
}
219302
}

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jKnowledgeGraphWriterBaseTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ abstract class Neo4jKnowledgeGraphWriterBaseTest {
8787
public static final String VALUE_KEANU = "value3";
8888
public static String USERNAME = "neo4j";
8989
public static String ADMIN_PASSWORD = "adminPass";
90-
private static final String NEO4J_VERSION = System.getProperty("neo4jVersion", "2025.01.0-enterprise");
90+
private static final String NEO4J_VERSION = System.getProperty("neo4jVersion", "2025.04.0-enterprise");
9191

9292
public static String CAT_ON_THE_TABLE = "Sylvester the cat is on the table";
9393
public static String KEANU_REEVES_ACTED = "Keanu Reeves acted in Matrix";

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jKnowledgeGraphWriterTest.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55
import static org.mockito.ArgumentMatchers.argThat;
66
import static org.mockito.Mockito.when;
77

8+
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStore;
9+
import dev.langchain4j.data.embedding.Embedding;
810
import dev.langchain4j.data.message.AiMessage;
911
import dev.langchain4j.data.message.ChatMessage;
12+
import dev.langchain4j.data.segment.TextSegment;
1013
import dev.langchain4j.model.chat.ChatModel;
1114
import dev.langchain4j.model.chat.response.ChatResponse;
15+
import dev.langchain4j.model.embedding.EmbeddingModel;
16+
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
17+
import dev.langchain4j.store.embedding.EmbeddingMatch;
18+
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
1219
import java.util.List;
1320
import org.junit.jupiter.api.Test;
1421
import org.junit.jupiter.api.extension.ExtendWith;
@@ -56,4 +63,96 @@ void testWrongConstraintName() {
5663
assertThat(e.getMessage()).contains("Error executing query: CREATE CONSTRAINT 111");
5764
}
5865
}
66+
67+
private static final String LABEL_CUSTOM = "Custom";
68+
private static final EmbeddingModel EMBEDDING_MODEL = new AllMiniLmL6V2QuantizedEmbeddingModel();
69+
70+
@Test
71+
void testKnowledgeGraphWithEmbeddingStoreAndNullEmbeddingModel() {
72+
73+
final Neo4jEmbeddingStore embeddingStore = Neo4jEmbeddingStore.builder()
74+
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
75+
.dimension(384)
76+
.label(LABEL_CUSTOM)
77+
.build();
78+
79+
try {
80+
knowledgeGraphWriter = KnowledgeGraphWriter.builder()
81+
.graph(neo4jGraph)
82+
.embeddingStore(embeddingStore)
83+
.build();
84+
fail("Should fail due to null embeddingModel");
85+
} catch (Exception e) {
86+
assertThat(e.getMessage()).contains("embeddingModel cannot be null");
87+
}
88+
}
89+
90+
@Test
91+
void testKnowledgeGraphWithEmbeddingStore() {
92+
93+
final Neo4jEmbeddingStore embeddingStore = Neo4jEmbeddingStore.builder()
94+
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
95+
.dimension(384)
96+
.label(LABEL_CUSTOM)
97+
.build();
98+
99+
testKnowledgeGraphWithEmbeddingStoreCommon(embeddingStore, false);
100+
}
101+
102+
@Test
103+
void testKnowledgeGraphWithEmbeddingStoreAndIncludeSource() {
104+
final Neo4jEmbeddingStore embeddingStore = Neo4jEmbeddingStore.builder()
105+
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
106+
.dimension(384)
107+
.label(LABEL_CUSTOM)
108+
.build();
109+
110+
testKnowledgeGraphWithEmbeddingStoreCommon(embeddingStore, true);
111+
}
112+
113+
@Test
114+
void testKnowledgeGraphWithEmbeddingStoreRetrievalQueryAndIncludeSource() {
115+
116+
final Neo4jEmbeddingStore embeddingStore = Neo4jEmbeddingStore.builder()
117+
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
118+
.dimension(384)
119+
.label(LABEL_CUSTOM)
120+
.retrievalQuery(
121+
"""
122+
MATCH (node)<-[r:HAS_ENTITY]-(d:Document)
123+
WITH d, collect(DISTINCT {chunk: node, score: score}) AS chunks, avg(score) as avg_score
124+
RETURN d.text AS text, avg_score AS score, properties(d) AS metadata
125+
ORDER BY score DESC
126+
LIMIT $maxResults
127+
""")
128+
.build();
129+
130+
testKnowledgeGraphWithEmbeddingStoreCommon(embeddingStore, true);
131+
}
132+
133+
private static void testKnowledgeGraphWithEmbeddingStoreCommon(
134+
Neo4jEmbeddingStore embeddingStore, boolean includeSource) {
135+
final String text = "keanu reeves";
136+
final Embedding queryEmbedding = EMBEDDING_MODEL.embed(text).content();
137+
final EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
138+
.queryEmbedding(queryEmbedding)
139+
.minScore(0.9)
140+
.build();
141+
final List<EmbeddingMatch<TextSegment>> matchesBefore =
142+
embeddingStore.search(request).matches();
143+
assertThat(matchesBefore).isEmpty();
144+
145+
knowledgeGraphWriter = KnowledgeGraphWriter.builder()
146+
.graph(neo4jGraph)
147+
.embeddingStore(embeddingStore)
148+
.embeddingModel(EMBEDDING_MODEL)
149+
.build();
150+
151+
knowledgeGraphWriter.addGraphDocuments(graphDocs, includeSource);
152+
153+
final List<EmbeddingMatch<TextSegment>> matches =
154+
embeddingStore.search(request).matches();
155+
assertThat(matches).hasSize(1);
156+
assertThat(matches.get(0).embedded().text()).containsIgnoringCase(text);
157+
}
59158
}

embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStore.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public class Neo4jEmbeddingStore implements EmbeddingStore<TextSegment> {
131131
private final String sanitizedLabel;
132132
private final String textProperty;
133133
private final String retrievalQuery;
134-
private final String entityCreationQuery;
134+
private String entityCreationQuery;
135135
private final Set<String> notMetaKeys;
136136
private Map<String, Object> additionalParams;
137137

@@ -282,6 +282,10 @@ public void setAdditionalParams(final Map<String, Object> additionalParams) {
282282
this.additionalParams = additionalParams;
283283
}
284284

285+
/**
 * Replaces the Cypher query this store uses to create entities.
 * The value is stored as given; this setter performs no validation.
 * NOTE(review): the field was final before this change, so the store now carries mutable state
 * that is shared by all of its users — confirm this is acceptable for concurrent use.
 *
 * @param entityCreationQuery the new entity creation query
 */
public void setEntityCreationQuery(final String entityCreationQuery) {
286+
this.entityCreationQuery = entityCreationQuery;
287+
}
288+
285289
/*
286290
Methods with `@Override`
287291
*/

0 commit comments

Comments
 (0)