Skip to content

Commit 336d8a7

Browse files
committed
test: refactor tests using AssertJ
1 parent a935882 commit 336d8a7

File tree

40 files changed

+320
-353
lines changed

40 files changed

+320
-353
lines changed

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/EmbeddingSearchIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ void hybridQuery3() {
9191

9292
debugQuery(query, results);
9393

94-
assertThat(results).hasSize(0);
94+
assertThat(results).isEmpty();
9595
}
9696

9797
@BeforeEach

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/FullTextSearchIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void queryContent() {
109109
results.stream().map(content -> content.textSegment().text()).collect(Collectors.toList());
110110
Collections.sort(actualTextSegments);
111111

112-
assertThat(results).hasSize(hits.length);
112+
assertThat(results).hasSameSizeAs(hits);
113113
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
114114
}
115115

@@ -200,7 +200,7 @@ void retrieverWithBadTokenCountField() {
200200
List<Content> results = contentRetriever.retrieve(query);
201201

202202
// No limiting by token count, since wrong field is used
203-
assertThat(results).hasSize(hits.length);
203+
assertThat(results).hasSameSizeAs(hits);
204204
}
205205

206206
@Test

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/HybridSearchIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import org.slf4j.LoggerFactory;
2020

2121
class HybridSearchIT {
22-
22+
2323
private static final Logger log = LoggerFactory.getLogger(HybridSearchIT.class);
2424

2525
private static final TextEmbedding[] hits = {
@@ -103,7 +103,7 @@ void hybridQuery3() {
103103
List<Content> results = contentRetriever.retrieve(Query.from(queryText));
104104
debugQuery(query, results);
105105

106-
assertThat(results).hasSize(0);
106+
assertThat(results).isEmpty();
107107
}
108108

109109
@BeforeEach

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/IndexerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void addAllEmbeddings() {
5353

5454
List<Content> results = contentRetriever.retrieve(query);
5555

56-
assertThat(results).hasSize(0);
56+
assertThat(results).isEmpty();
5757
}
5858

5959
@Test

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorBaseTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package dev.langchain4j.community.rag.content.retriever.neo4j;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4-
import static org.junit.jupiter.api.Assertions.assertEquals;
54

65
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStore;
76
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStoreIngestor;
@@ -43,7 +42,7 @@ WITH parent, collect(node.text) AS chunks, max(score) AS score
4342
protected static EmbeddingModel embeddingModel;
4443

4544
@BeforeAll
46-
public static void beforeAll() {
45+
static void beforeAll() {
4746
Neo4jContainerBaseTest.beforeAll();
4847

4948
embeddingStore =
@@ -130,8 +129,8 @@ protected static void commonResults(List<Content> results, String... retrieveQue
130129
Content result = results.get(0);
131130

132131
assertThat(result.textSegment().text().toLowerCase()).containsIgnoringWhitespaces(retrieveQuery);
133-
assertEquals("Wikipedia link", result.textSegment().metadata().getString("source"));
134-
assertEquals("https://example.com/ai", result.textSegment().metadata().getString("url"));
132+
assertThat(result.textSegment().metadata().getString("source")).isEqualTo("Wikipedia link");
133+
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
135134
}
136135

137136
protected static EmbeddingStoreContentRetriever getEmbeddingStoreContentRetriever(

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorIT.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2525

2626
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
27-
public class Neo4jEmbeddingStoreIngestorIT extends Neo4jEmbeddingStoreIngestorBaseTest {
27+
class Neo4jEmbeddingStoreIngestorIT extends Neo4jEmbeddingStoreIngestorBaseTest {
2828

2929
ChatModel chatModel = OpenAiChatModel.builder()
3030
.baseUrl(System.getenv("OPENAI_BASE_URL"))
@@ -36,7 +36,7 @@ public class Neo4jEmbeddingStoreIngestorIT extends Neo4jEmbeddingStoreIngestorBa
3636
.build();
3737

3838
@Test
39-
void testRetrieverWithCustomAnswerModelAndPrompt() {
39+
void retrieverWithCustomAnswerModelAndPrompt() {
4040
String promptAnswer =
4141
"""
4242
You are an assistant that helps to form nice and human
@@ -117,8 +117,7 @@ void testRetrieverWithCustomAnswerModelAndPrompt() {
117117
.build();
118118

119119
final String chainResult = chain.execute(Query.from(retrieveQuery));
120-
assertThat(chainResult).containsIgnoringCase("dattebayo");
121-
assertThat(chainResult).containsIgnoringCase("super saiyan");
120+
assertThat(chainResult).containsIgnoringCase("dattebayo").containsIgnoringCase("super saiyan");
122121

123122
RetrievalQAChain chainWithPromptBuilder = RetrievalQAChain.builder()
124123
.chatModel(chatModel)
@@ -127,17 +126,18 @@ void testRetrieverWithCustomAnswerModelAndPrompt() {
127126
.build();
128127

129128
final String chainResultWithPromptBuilder = chainWithPromptBuilder.execute(Query.from(retrieveQuery));
130-
assertThat(chainResultWithPromptBuilder).containsIgnoringCase("dattebayo");
131-
assertThat(chainResultWithPromptBuilder).containsIgnoringCase("super saiyan");
129+
assertThat(chainResultWithPromptBuilder)
130+
.containsIgnoringCase("dattebayo")
131+
.containsIgnoringCase("super saiyan");
132132
}
133133

134134
@Test
135-
public void testSummaryGraphIngestor() {
135+
void summaryGraphIngestor() {
136136
summaryGraphIngestorCommon(chatModel);
137137
}
138138

139139
@Test
140-
public void testHypotheticalQuestionIngestor() {
140+
void hypotheticalQuestionIngestor() {
141141
hypotheticalQuestionIngestorCommon(chatModel);
142142
}
143143
}

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorQAChainTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
import org.mockito.junit.jupiter.MockitoExtension;
2828

2929
@ExtendWith(MockitoExtension.class)
30-
public class Neo4jEmbeddingStoreIngestorQAChainTest extends Neo4jEmbeddingStoreIngestorBaseTest {
30+
class Neo4jEmbeddingStoreIngestorQAChainTest extends Neo4jEmbeddingStoreIngestorBaseTest {
3131
@Mock
3232
private ChatModel chatLanguageModel;
3333

3434
@Test
35-
public void testBasicRetrieverWithChatQuestionAndAnswerModel() {
35+
void basicRetrieverWithChatQuestionAndAnswerModel() {
3636

3737
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
3838
.driver(driver)

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorTest.java

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package dev.langchain4j.community.rag.content.retriever.neo4j;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4-
import static org.junit.jupiter.api.Assertions.assertEquals;
5-
import static org.junit.jupiter.api.Assertions.assertFalse;
6-
import static org.junit.jupiter.api.Assertions.assertTrue;
74
import static org.mockito.ArgumentMatchers.anyList;
85
import static org.mockito.Mockito.when;
96

@@ -33,13 +30,13 @@
3330
import org.mockito.junit.jupiter.MockitoExtension;
3431

3532
@ExtendWith(MockitoExtension.class)
36-
public class Neo4jEmbeddingStoreIngestorTest extends Neo4jEmbeddingStoreIngestorBaseTest {
33+
class Neo4jEmbeddingStoreIngestorTest extends Neo4jEmbeddingStoreIngestorBaseTest {
3734

3835
@Mock
3936
private ChatModel chatLanguageModel;
4037

4138
@Test
42-
public void testBasicRetriever() {
39+
void testBasicRetriever() {
4340
Document parentDoc = getDocumentMiscTopics();
4441

4542
// Child splitter: splits into sentences using OpenNLP
@@ -67,7 +64,7 @@ public void testBasicRetriever() {
6764
}
6865

6966
@Test
70-
public void testRetrieverWithChatModel() {
67+
void testRetrieverWithChatModel() {
7168

7269
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
7370
.driver(driver)
@@ -242,15 +239,15 @@ void testRetrieverWithCustomRetrievalAndEmbeddingCreationQueryAndPreInsertedData
242239
List<Content> results = retriever.retrieve(new Query("quantum physics"));
243240

244241
// Assert
245-
assertEquals(1, results.size());
242+
assertThat(results).hasSize(1);
246243
Content parent = results.get(0);
247244

248-
assertTrue(parent.textSegment().text().contains("quantum physics"));
249-
assertEquals("science", parent.textSegment().metadata().getString("source"));
245+
assertThat(parent.textSegment().text()).contains("quantum physics");
246+
assertThat(parent.textSegment().metadata().getString("source")).isEqualTo("science");
250247
}
251248

252249
@Test
253-
public void testSummaryGraphIngestor() {
250+
void testSummaryGraphIngestor() {
254251

255252
when(chatLanguageModel.chat(anyList()))
256253
.thenReturn(ChatResponse.builder()
@@ -261,7 +258,7 @@ public void testSummaryGraphIngestor() {
261258
}
262259

263260
@Test
264-
public void testHypotheticalQuestionIngestor() {
261+
void testHypotheticalQuestionIngestor() {
265262

266263
when(chatLanguageModel.chat(anyList()))
267264
.thenReturn(ChatResponse.builder()
@@ -272,7 +269,7 @@ public void testHypotheticalQuestionIngestor() {
272269
}
273270

274271
@Test
275-
public void testParentChildRetriever() {
272+
void testParentChildRetriever() {
276273
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
277274

278275
int maxSegmentSize = 250;
@@ -325,10 +322,12 @@ protected static void summaryGraphIngestorCommon(ChatModel chatModel) {
325322
// Query and validate results
326323
List<Content> results = retriever.retrieve(Query.from("What is Machine Learning?"));
327324

328-
assertFalse(results.isEmpty(), "Should retrieve at least one parent document");
325+
assertThat(results.isEmpty())
326+
.as("Should retrieve at least one parent document")
327+
.isFalse();
329328

330329
Content result = results.get(0);
331-
assertTrue(result.textSegment().text().toLowerCase().contains("machine learning"));
330+
assertThat(result.textSegment().text().toLowerCase()).contains("machine learning");
332331
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
333332
}
334333

@@ -352,12 +351,14 @@ protected static void hypotheticalQuestionIngestorCommon(ChatModel chatModel) {
352351
final EmbeddingStoreContentRetriever retriever = getEmbeddingStoreContentRetriever(ingestor);
353352
List<Content> results = retriever.retrieve(Query.from("Who is John Doe?"));
354353

355-
assertFalse(results.isEmpty(), "Should retrieve at least one parent document");
354+
assertThat(results.isEmpty())
355+
.as("Should retrieve at least one parent document")
356+
.isFalse();
356357

357358
Content result = results.get(0);
358359

359360
assertThat(result.textSegment().text().toLowerCase()).containsIgnoringWhitespaces("super saiyan");
360-
assertEquals("Wikipedia link", result.textSegment().metadata().getString("source"));
361-
assertEquals("https://example.com/ai", result.textSegment().metadata().getString("url"));
361+
assertThat(result.textSegment().metadata().getString("source")).isEqualTo("Wikipedia link");
362+
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
362363
}
363364
}

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jKnowledgeGraphWriterBaseTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ void beforeAll() {
134134
List<Document> documents = List.of(docCat, docKeanu);
135135

136136
graphDocs = graphTransformer.transformAll(documents);
137-
assertThat(graphDocs.size()).isEqualTo(2);
137+
assertThat(graphDocs).hasSize(2);
138138
}
139139

140140
abstract ChatModel getModel();
@@ -388,8 +388,7 @@ record = records.get(1);
388388

389389
private static void assertNodeLabels(Node start, String entityLabel) {
390390
Iterable<String> labels = start.labels();
391-
assertThat(labels).hasSize(2);
392-
assertThat(labels).contains(entityLabel);
391+
assertThat(labels).hasSize(2).contains(entityLabel);
393392
}
394393

395394
private static void assertNodeProps(Node start, String propRegex, String idProp) {
@@ -406,8 +405,7 @@ private static void assertionsDocument(
406405
String expectedMetaKey,
407406
String expectedMetaValue) {
408407
Map<String, Object> map = start.asMap();
409-
assertThat(map.size()).isEqualTo(3);
410-
assertThat(map).containsKey(idProp);
408+
assertThat(map).hasSize(3).containsKey(idProp);
411409
Object text = map.get(textProp);
412410
assertThat(text).isEqualTo(expectedText);
413411

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jText2CypherRetrieverIT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,9 @@ void shouldReturnANaturalLanguageResponse() {
267267
// When
268268
String response = neo4jContentRetriever.fromLLM(query);
269269

270-
// Then
271-
assertThat(response).containsIgnoringCase("author");
272-
assertThat(response).containsIgnoringCase("Frank Herbert");
270+
assertThat(response)
271+
// Then
272+
.containsIgnoringCase("author")
273+
.containsIgnoringCase("Frank Herbert");
273274
}
274275
}

0 commit comments

Comments
 (0)