Skip to content

Commit 4121b28

Browse files
vga91 and Martin7-1 authored
Issue 111: GraphRAG retrieval concepts (#147)
* Issue 111: GraphRAG retrieval concepts * changed test * added comments * deleted retrievers * changes review * change variable names * decouple retriever from QA chain * make format * removed duplicated dependency * increased test coverage * changes review: added ParentChildEmbeddingStoreIngestor and Neo4jEmbeddingStoreIngestor * added and renamed tests * make format --------- Co-authored-by: Martin7-1 <[email protected]>
1 parent 8459627 commit 4121b28

File tree

17 files changed

+2137
-60
lines changed

17 files changed

+2137
-60
lines changed

content-retrievers/langchain4j-community-neo4j-retriever/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@
5353
<version>${neo4j.cypher.dsl.version}</version>
5454
</dependency>
5555

56+
<dependency>
57+
<groupId>dev.langchain4j</groupId>
58+
<artifactId>langchain4j-community-neo4j</artifactId>
59+
<version>${project.version}</version>
60+
</dependency>
61+
5662
<!-- test dependencies -->
5763
<dependency>
5864
<groupId>org.testcontainers</groupId>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package dev.langchain4j.community.rag.content.retriever.neo4j;
2+
3+
import org.junit.jupiter.api.AfterAll;
4+
import org.junit.jupiter.api.AfterEach;
5+
import org.junit.jupiter.api.BeforeAll;
6+
import org.neo4j.driver.AuthTokens;
7+
import org.neo4j.driver.Driver;
8+
import org.neo4j.driver.GraphDatabase;
9+
import org.neo4j.driver.Session;
10+
import org.testcontainers.containers.Neo4jContainer;
11+
import org.testcontainers.junit.jupiter.Container;
12+
13+
public class Neo4jContainerBaseTest {
14+
15+
protected static final String NEO4J_VERSION = System.getProperty("neo4jVersion", "5.26");
16+
17+
protected static Driver driver;
18+
19+
@Container
20+
protected static final Neo4jContainer<?> neo4jContainer = new Neo4jContainer<>("neo4j:" + NEO4J_VERSION)
21+
.withoutAuthentication()
22+
.withPlugins("apoc");
23+
24+
@BeforeAll
25+
static void beforeAll() {
26+
neo4jContainer.start();
27+
driver = GraphDatabase.driver(neo4jContainer.getBoltUrl(), AuthTokens.none());
28+
}
29+
30+
@AfterAll
31+
static void afterAll() {
32+
driver.close();
33+
neo4jContainer.stop();
34+
}
35+
36+
@AfterEach
37+
void afterEach() {
38+
try (Session session = driver.session()) {
39+
session.run("MATCH (n) DETACH DELETE n");
40+
}
41+
}
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
package dev.langchain4j.community.rag.content.retriever.neo4j;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
6+
import static org.mockito.ArgumentMatchers.anyList;
7+
import static org.mockito.Mockito.when;
8+
9+
import dev.langchain4j.community.store.embedding.ParentChildEmbeddingStoreIngestor;
10+
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStore;
11+
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStoreIngestor;
12+
import dev.langchain4j.data.document.Document;
13+
import dev.langchain4j.data.document.DocumentSplitter;
14+
import dev.langchain4j.data.document.splitter.DocumentByRegexSplitter;
15+
import dev.langchain4j.data.message.AiMessage;
16+
import dev.langchain4j.model.chat.ChatModel;
17+
import dev.langchain4j.model.chat.response.ChatResponse;
18+
import dev.langchain4j.rag.content.Content;
19+
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
20+
import dev.langchain4j.rag.query.Query;
21+
import java.util.List;
22+
import java.util.Map;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
28+
@ExtendWith(MockitoExtension.class)
29+
public class Neo4jEmbeddingRetrieverTest extends Neo4jEmbeddingStoreIngestorBaseTest {
30+
31+
@Mock
32+
private ChatModel chatLanguageModel;
33+
34+
@Test
35+
public void testBasicRetriever() {
36+
Document parentDoc = getDocumentMiscTopics();
37+
38+
// Child splitter: splits into sentences using OpenNLP
39+
final String expectedQuery = "\\n\\n";
40+
int maxSegmentSize = 250;
41+
DocumentSplitter splitter = new DocumentByRegexSplitter(expectedQuery, expectedQuery, maxSegmentSize, 0);
42+
43+
final ParentChildEmbeddingStoreIngestor build = ParentChildEmbeddingStoreIngestor.builder()
44+
.documentSplitter(splitter)
45+
.embeddingStore(embeddingStore)
46+
.embeddingModel(embeddingModel)
47+
.build();
48+
49+
build.ingest(parentDoc);
50+
51+
// Query and validate results
52+
final String retrieveQuery = "fundamental theory";
53+
final EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
54+
.embeddingStore(embeddingStore)
55+
.maxResults(1)
56+
.minScore(0.4)
57+
.build();
58+
List<Content> results = retriever.retrieve(Query.from(retrieveQuery));
59+
commonResults(results, retrieveQuery);
60+
}
61+
62+
@Test
63+
public void testRetrieverWithChatModel() {
64+
65+
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
66+
.driver(driver)
67+
.retrievalQuery(CUSTOM_RETRIEVAL)
68+
.entityCreationQuery(CUSTOM_CREATION_QUERY)
69+
.label("Chunk")
70+
.indexName("chunk_embedding_index")
71+
.dimension(384)
72+
.build();
73+
74+
when(chatLanguageModel.chat(anyList()))
75+
.thenReturn(ChatResponse.builder()
76+
.aiMessage(AiMessage.aiMessage("Naruto"))
77+
.build());
78+
79+
final EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
80+
.embeddingModel(embeddingModel)
81+
.embeddingStore(neo4jEmbeddingStore)
82+
.maxResults(2)
83+
.minScore(0.4)
84+
.build();
85+
86+
Document parentDoc = getDocumentMiscTopics();
87+
88+
// Child splitter: splits into sentences using OpenNLP
89+
final String expectedQuery = "\\n\\n";
90+
int maxSegmentSize = 250;
91+
DocumentSplitter splitter = new DocumentByRegexSplitter(expectedQuery, expectedQuery, maxSegmentSize, 0);
92+
93+
final Neo4jEmbeddingStoreIngestor build = Neo4jEmbeddingStoreIngestor.builder()
94+
.documentSplitter(splitter)
95+
.embeddingStore(neo4jEmbeddingStore)
96+
.embeddingModel(embeddingModel)
97+
.driver(driver)
98+
.query("CREATE (:MainDoc $metadata)")
99+
.questionModel(chatLanguageModel)
100+
.userPrompt("mock prompt user")
101+
.systemPrompt("mock prompt system")
102+
.build();
103+
build.ingest(parentDoc);
104+
final String retrieveQuery = "naruto";
105+
List<Content> results = retriever.retrieve(Query.from(retrieveQuery));
106+
commonResults(results, retrieveQuery);
107+
}
108+
109+
@Test
110+
void testRetrieverWithCustomRetrievalAndEmbeddingCreationQuery() {
111+
112+
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
113+
.driver(driver)
114+
.retrievalQuery(CUSTOM_RETRIEVAL)
115+
.entityCreationQuery(CUSTOM_CREATION_QUERY)
116+
.label("Chunk")
117+
.indexName("chunk_embedding_index")
118+
.dimension(384)
119+
.build();
120+
121+
final EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
122+
.embeddingModel(embeddingModel)
123+
.maxResults(5)
124+
.minScore(0.4)
125+
.embeddingStore(neo4jEmbeddingStore)
126+
.build();
127+
128+
Document doc = getDocumentAI();
129+
130+
// MainDoc splitter splits on paragraphs (double newlines)
131+
final String expectedQuery = "\\n\\n";
132+
int maxSegmentSize = 250;
133+
DocumentSplitter parentSplitter = new DocumentByRegexSplitter(expectedQuery, expectedQuery, maxSegmentSize, 0);
134+
135+
// Child splitter splits on periods (sentences)
136+
final String expectedQueryChild = "\\. ";
137+
DocumentSplitter childSplitter = new DocumentByRegexSplitter(expectedQueryChild, expectedQuery, 150, 0);
138+
139+
final Neo4jEmbeddingStoreIngestor build = Neo4jEmbeddingStoreIngestor.builder()
140+
.documentSplitter(parentSplitter)
141+
.documentChildSplitter(childSplitter)
142+
.driver(driver)
143+
.query("CREATE (:MainDoc $metadata)")
144+
.embeddingStore(neo4jEmbeddingStore)
145+
.embeddingModel(embeddingModel)
146+
.build();
147+
// Index the document into Neo4j as parent-child nodes
148+
build.ingest(doc);
149+
150+
final String retrieveQuery = "Machine Learning";
151+
List<Content> results = retriever.retrieve(Query.from(retrieveQuery));
152+
assertThat(results).hasSize(1);
153+
}
154+
155+
// TODO - change with cypher-dsl
156+
@Test
157+
void testRetrieverWithCustomRetrievalAndEmbeddingCreationQueryMainDocIdAndParams() {
158+
String customCreationQuery =
159+
"""
160+
UNWIND $rows AS row
161+
MATCH (p:MainDoc {customParentId: $customParentId})
162+
CREATE (p)-[:REFERS_TO]->(u:%1$s {%2$s: row.%2$s})
163+
SET u += row.%3$s
164+
WITH row, u
165+
CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)
166+
RETURN count(*)""";
167+
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
168+
.driver(driver)
169+
.retrievalQuery(CUSTOM_RETRIEVAL)
170+
.entityCreationQuery(customCreationQuery)
171+
.label("Chunk")
172+
.indexName("chunk_embedding_index")
173+
.dimension(384)
174+
.build();
175+
176+
final EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
177+
.embeddingModel(embeddingModel)
178+
.maxResults(5)
179+
.minScore(0.4)
180+
.embeddingStore(neo4jEmbeddingStore)
181+
.build();
182+
183+
Document doc = getDocumentAI();
184+
185+
// MainDoc splitter splits on paragraphs (double newlines)
186+
final String expectedQuery = "\\n\\n";
187+
int maxSegmentSize = 250;
188+
DocumentSplitter parentSplitter = new DocumentByRegexSplitter(expectedQuery, expectedQuery, maxSegmentSize, 0);
189+
190+
// Child splitter splits on periods (sentences)
191+
final String expectedQueryChild = "\\. ";
192+
DocumentSplitter childSplitter =
193+
new DocumentByRegexSplitter(expectedQueryChild, expectedQuery, maxSegmentSize, 0);
194+
195+
final Neo4jEmbeddingStoreIngestor build = Neo4jEmbeddingStoreIngestor.builder()
196+
.documentSplitter(parentSplitter)
197+
.documentChildSplitter(childSplitter)
198+
.driver(driver)
199+
.query("CREATE (:MainDoc $metadata)")
200+
.parentIdKey("customParentId")
201+
.params(Map.of("customMainDocId", "foo", "bar", 1))
202+
.embeddingStore(neo4jEmbeddingStore)
203+
.embeddingModel(embeddingModel)
204+
.build();
205+
// Index the document into Neo4j as parent-child nodes
206+
build.ingest(doc);
207+
208+
final String retrieveQuery = "Machine Learning";
209+
List<Content> results = retriever.retrieve(Query.from(retrieveQuery));
210+
assertThat(results).hasSize(1);
211+
}
212+
213+
@Test
214+
void testRetrieverWithCustomRetrievalAndEmbeddingCreationQueryAndPreInsertedData() {
215+
216+
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
217+
.driver(driver)
218+
.retrievalQuery(CUSTOM_RETRIEVAL)
219+
.entityCreationQuery(CUSTOM_CREATION_QUERY)
220+
.label("Chunk")
221+
.indexName("chunk_embedding_index")
222+
.dimension(384)
223+
.build();
224+
225+
seedMainDocAndChildData();
226+
227+
final EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
228+
.embeddingModel(embeddingModel)
229+
.maxResults(5)
230+
.minScore(0.6)
231+
.embeddingStore(neo4jEmbeddingStore)
232+
.build();
233+
234+
// Act
235+
List<Content> results = retriever.retrieve(new Query("quantum physics"));
236+
237+
// Assert
238+
assertEquals(1, results.size());
239+
Content parent = results.get(0);
240+
241+
assertTrue(parent.textSegment().text().contains("quantum physics"));
242+
assertEquals("science", parent.textSegment().metadata().getString("source"));
243+
}
244+
}

0 commit comments

Comments
 (0)