77
88import dev .langchain4j .Experimental ;
99import dev .langchain4j .community .data .document .graph .GraphDocument ;
10+ import dev .langchain4j .community .data .document .graph .GraphNode ;
11+ import dev .langchain4j .community .store .embedding .neo4j .Neo4jEmbeddingStore ;
1012import dev .langchain4j .data .document .Document ;
13+ import dev .langchain4j .data .document .Metadata ;
14+ import dev .langchain4j .data .embedding .Embedding ;
15+ import dev .langchain4j .data .segment .TextSegment ;
16+ import dev .langchain4j .model .embedding .EmbeddingModel ;
17+ import java .util .ArrayList ;
1118import java .util .HashMap ;
1219import java .util .List ;
1320import java .util .Map ;
@@ -34,23 +41,38 @@ public class KnowledgeGraphWriter {
3441 final String sanitizedTextProperty ;
3542
3643 private final Neo4jGraph graph ;
44+ private final Neo4jEmbeddingStore embeddingStore ;
45+ private EmbeddingModel embeddingModel = null ;
3746
3847 public KnowledgeGraphWriter (
3948 Neo4jGraph graph ,
4049 String idProperty ,
4150 String label ,
4251 String textProperty ,
4352 String relType ,
44- String constraintName ) {
53+ String constraintName ,
54+ Neo4jEmbeddingStore embeddingStore ,
55+ EmbeddingModel embeddingModel ) {
4556 this .graph = ensureNotNull (graph , "graph" );
57+
58+ this .embeddingStore = embeddingStore ;
59+ final boolean storeIsNull = this .embeddingStore == null ;
60+ if (!storeIsNull ) {
61+ this .embeddingModel = ensureNotNull (embeddingModel , "embeddingModel" );
62+ }
63+
4664 this .label = getOrDefault (label , DEFAULT_LABEL );
4765 this .relType = getOrDefault (relType , DEFAULT_REL_TYPE );
66+
4867 this .idProperty = getOrDefault (idProperty , DEFAULT_ID_PROP );
4968 this .textProperty = getOrDefault (textProperty , DEFAULT_TEXT_PROP );
5069 this .constraintName = getOrDefault (constraintName , DEFAULT_CONS_NAME );
5170
5271 /* sanitize labels and property names, to prevent from Cypher Injections */
53- this .sanitizedLabel = sanitizeOrThrows (this .label , "label" );
72+
73+ // if embeddingStore then label is taken from there getSanitizedLabel()
74+ this .sanitizedLabel =
75+ storeIsNull ? sanitizeOrThrows (this .label , "label" ) : this .embeddingStore .getSanitizedLabel ();
5476 this .sanitizedRelType = sanitizeOrThrows (this .relType , "relType" );
5577 this .sanitizedIdProperty = sanitizeOrThrows (this .idProperty , "idProperty" );
5678 this .sanitizedTextProperty = sanitizeOrThrows (this .textProperty , "textProperty" );
@@ -78,9 +100,10 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
78100
79101 // Import nodes
80102 Map <String , Object > nodeParams = new HashMap <>();
81- nodeParams .put (
82- "data" , graphDoc .nodes ().stream ().map (Neo4jUtils ::toMap ).toList ());
83-
103+ if (embeddingStore == null ) {
104+ nodeParams .put (
105+ "rows" , graphDoc .nodes ().stream ().map (Neo4jUtils ::toMap ).toList ());
106+ }
84107 if (includeSource ) {
85108 // create a copyOf metadata, not to update existing graphDoc,
86109 // subsequent tests could potentially fail
@@ -93,8 +116,7 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
93116 nodeParams .put ("document" , document );
94117 }
95118
96- String nodeImportQuery = getNodeImportQuery (includeSource );
97- graph .executeWrite (nodeImportQuery , nodeParams );
119+ insertNodes (includeSource , graphDoc , nodeParams );
98120
99121 // Import relationships
100122 List <Map <String , String >> relData = graphDoc .relationships ().stream ()
@@ -105,23 +127,60 @@ public void addGraphDocuments(List<GraphDocument> graphDocuments, boolean includ
105127 "target_label" , rel .targetNode ().type (),
106128 "type" , rel .type ().replace (" " , "_" ).toUpperCase ()))
107129 .toList ();
108-
109130 String relImportQuery = getRelImportQuery ();
110131 graph .executeWrite (relImportQuery , Map .of ("data" , relData ));
111132 }
112133 }
113134
135+ private void insertNodes (boolean includeSource , GraphDocument graphDoc , Map <String , Object > nodeParams ) {
136+ if (embeddingStore == null ) {
137+ String nodeImportQuery = getNodeImportQuery (includeSource );
138+ graph .executeWrite (nodeImportQuery , nodeParams );
139+ return ;
140+ }
141+
142+ if (includeSource ) {
143+ final String creationQuery = mergeSourceWithDocs (true )
144+ + """
145+ SET source += row.%3$s
146+ WITH row, source
147+ CALL db.create.setNodeVectorProperty(source, $embeddingProperty, row.%4$s)
148+ RETURN count(*)""" ;
149+ embeddingStore .setEntityCreationQuery (creationQuery );
150+ embeddingStore .setAdditionalParams (nodeParams );
151+ }
152+
153+ // we save the ids, otherwise it create UUID properties and the merge with import relationships doesn't work
154+ List <String > ids = new ArrayList <>();
155+ List <TextSegment > segments = new ArrayList <>();
156+ for (GraphNode node : graphDoc .nodes ()) {
157+ final Map <String , String > properties = new HashMap <>(node .properties ());
158+ properties .put ("type" , node .type ());
159+ final String id = node .id ();
160+ final TextSegment segment = TextSegment .from (id , Metadata .from (properties ));
161+ ids .add (id );
162+ segments .add (segment );
163+ }
164+
165+ final List <Embedding > embeddings = embeddingModel .embedAll (segments ).content ();
166+
167+ this .embeddingStore .addAll (ids , embeddings , segments );
168+ }
169+
114170 private String getNodeImportQuery (boolean includeSource ) {
115171
172+ return mergeSourceWithDocs (includeSource ) + "WITH source, row "
173+ + "SET source:$(row.type) "
174+ + "RETURN count(*) as total" ;
175+ }
176+
177+ private String mergeSourceWithDocs (boolean includeSource ) {
116178 String includeDocsQuery = getIncludeDocsQuery (includeSource );
117179 final String withDocsRel = includeSource ? String .format ("MERGE (d)-[:%s]->(source) " , relType ) : "" ;
118180
119- return includeDocsQuery + "UNWIND $data AS row "
120- + String .format ("MERGE (source:%1$s {%2$s: row.id}) " , sanitizedLabel , sanitizedIdProperty )
121- + withDocsRel
122- + "WITH source, row "
123- + "SET source:$(row.type) "
124- + "RETURN count(*) as total" ;
181+ return includeDocsQuery + "UNWIND $rows AS row \n "
182+ + String .format ("MERGE (source:%1$s {%2$s: row.id}) \n " , sanitizedLabel , sanitizedIdProperty )
183+ + withDocsRel ;
125184 }
126185
127186 private String getIncludeDocsQuery (boolean includeSource ) {
@@ -160,6 +219,8 @@ public static class Builder {
160219 private String relType ;
161220 private String constraintName ;
162221 private Neo4jGraph graph ;
222+ private Neo4jEmbeddingStore embeddingStore ;
223+ private EmbeddingModel embeddingModel ;
163224
164225 /**
165226 * @param graph the {@link Neo4jGraph} (required)
@@ -212,8 +273,30 @@ public Builder constraintName(String constraintName) {
212273 return this ;
213274 }
214275
276+ /**
277+ * Sets the optional embedding store used to store texts as vectors via
278+ * {@link Neo4jEmbeddingStore#add(dev.langchain4j.data.embedding.Embedding)}.
279+ *
280+ * @param embeddingStore the {@link Neo4jEmbeddingStore} instance to store vector embeddings (optional)
281+ */
282+ public Builder embeddingStore (Neo4jEmbeddingStore embeddingStore ) {
283+ this .embeddingStore = embeddingStore ;
284+ return this ;
285+ }
286+
287+ /**
288+ * Sets the embedding model to be used for embedding text, if {@code embeddingStore} is provided.
289+ *
290+ * @param embeddingModel the {@link EmbeddingModel} used to generate embeddings
291+ */
292+ public Builder embeddingModel (EmbeddingModel embeddingModel ) {
293+ this .embeddingModel = embeddingModel ;
294+ return this ;
295+ }
296+
215297 public KnowledgeGraphWriter build () {
216- return new KnowledgeGraphWriter (graph , idProperty , label , textProperty , relType , constraintName );
298+ return new KnowledgeGraphWriter (
299+ graph , idProperty , label , textProperty , relType , constraintName , embeddingStore , embeddingModel );
217300 }
218301 }
219302}
0 commit comments