Skip to content

Commit f44932b

Browse files
akashmangoaifm1320
authored and committed
fixed test cases
1 parent 471cb1c commit f44932b

File tree

1 file changed

+105
-87
lines changed

1 file changed

+105
-87
lines changed
Lines changed: 105 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,141 @@
11
import unittest
2-
from unittest.mock import Mock, MagicMock
2+
from unittest.mock import Mock, patch
33
import numpy as np
4+
import pandas as pd
45
from adalflow.components.retriever import LanceDBRetriever
56
from adalflow.core.embedder import Embedder
6-
from adalflow.core.types import RetrieverOutput, Document
7+
from unittest import mock
8+
from adalflow.core.types import EmbedderOutput, RetrieverOutput
79

8-
# Mock LanceDB and PyArrow imports since they are specific to LanceDB
9-
lancedb = MagicMock()
10-
pa = MagicMock()
10+
# Helper function to create dummy embeddings
11+
def create_dummy_embeddings(num_embeddings, dim, seed=None):
    """Return a random ``(num_embeddings, dim)`` float32 matrix of fake embeddings.

    Args:
        num_embeddings: Number of rows (one per embedding) to generate.
        dim: Length of each embedding vector.
        seed: Optional seed. When given, values are drawn from a dedicated
            ``numpy.random.default_rng(seed)`` generator so repeated calls
            return identical data. When ``None`` (the default) the global
            NumPy random state is used, exactly as before.

    Returns:
        numpy.ndarray: Array of shape ``(num_embeddings, dim)`` and dtype
        ``float32``.
    """
    if seed is None:
        # Legacy path: keep drawing from the global NumPy random state.
        values = np.random.rand(num_embeddings, dim)
    else:
        # Reproducible path: an isolated generator leaves global state untouched.
        values = np.random.default_rng(seed).random((num_embeddings, dim))
    return values.astype(np.float32)
1113

1214
class TestLanceDBRetriever(unittest.TestCase):
1315
def setUp(self):
    # Shared fixture: a LanceDBRetriever built against a mocked embedder and
    # a mocked LanceDB connection.
    self.dimensions = 128
    self.top_k = 5
    self.single_query = ["sample query"]

    # The embedder mock hands back one fake embedding per entry in
    # single_query, taken from a pool of ten random vectors.
    self.dummy_embeddings = create_dummy_embeddings(10, self.dimensions)
    query_count = len(self.single_query)
    fake_items = [
        Mock(embedding=vector) for vector in self.dummy_embeddings[:query_count]
    ]
    self.embedder = Mock(spec=Embedder)
    self.embedder.return_value = EmbedderOutput(data=fake_items)

    # lancedb.connect is patched only while the retriever is constructed.
    # NOTE(review): this presumes the retriever keeps the handle it obtained
    # during construction - confirm against LanceDBRetriever.
    with patch("lancedb.connect") as connect_patch:
        self.mock_table = Mock()
        self.mock_db = connect_patch.return_value
        self.mock_db.create_table.return_value = self.mock_table
        self.retriever = LanceDBRetriever(
            embedder=self.embedder,
            dimensions=self.dimensions,
            db_uri="/tmp/lancedb",
            top_k=self.top_k,
        )
3337

3438
def test_initialization(self):
    # The retriever exposes the configured settings, and the mocked database
    # saw exactly one create_table call during construction.
    retriever = self.retriever
    self.assertEqual(retriever.dimensions, self.dimensions)
    self.assertEqual(retriever.top_k, self.top_k)
    self.mock_db.create_table.assert_called_once()
3842

3943
def test_add_documents(self):
    # Ten documents, each matched by one fake embedding from the embedder mock.
    documents = [{"content": f"Document {n}"} for n in range(10)]
    vectors = create_dummy_embeddings(len(documents), self.dimensions)
    mocked_items = [Mock(embedding=vector) for vector in vectors]
    self.embedder.return_value = EmbedderOutput(data=mocked_items)

    self.retriever.add_documents(documents)

    # One call to table.add whose first positional argument holds one entry
    # per document.
    add_mock = self.mock_table.add
    self.assertEqual(add_mock.call_count, 1)
    positional_args = add_mock.call_args[0]
    self.assertEqual(len(positional_args[0]), len(documents))
56+
57+
def test_add_documents_no_documents(self):
    # Adding an empty batch must never reach the mocked table's add().
    no_documents = []
    self.retriever.add_documents(no_documents)
    self.mock_table.add.assert_not_called()
4860

49-
# Ensure add method was called
50-
self.retriever.table.add.assert_called_once()
51-
# Verify embeddings were passed to LanceDB add method
52-
added_data = self.retriever.table.add.call_args[0][0]
53-
self.assertEqual(len(added_data), len(documents))
54-
self.assertIn("vector", added_data[0])
55-
self.assertIn("content", added_data[0])
56-
57-
def test_retrieve(self):
58-
# Prepare a sample query and mocked search result from LanceDB
59-
query = "test query"
60-
dummy_scores = [0.9, 0.8, 0.7]
61-
dummy_indices = [0, 1, 2]
62-
63-
# Set up mock search result as if it was retrieved from LanceDB
64-
self.retriever.table.search = MagicMock(return_value=Mock())
65-
self.retriever.table.search().limit().to_pandas.return_value = Mock(
66-
index=dummy_indices, _distance=dummy_scores
61+
def test_retrieve_single_query(self):
    query = "sample query"
    expected_indices = [0, 1, 2]
    expected_scores = [0.1, 0.2, 0.3]

    # The embedder mock returns a single random vector for the query.
    query_vector = create_dummy_embeddings(1, self.dimensions)[0]
    self.embedder.return_value = EmbedderOutput(
        data=[Mock(embedding=query_vector)]
    )

    # table.search(...).limit(...).to_pandas() yields a three-row DataFrame.
    limited_search = self.mock_table.search.return_value.limit.return_value
    limited_search.to_pandas.return_value = pd.DataFrame(
        {"index": expected_indices, "_distance": expected_scores}
    )

    result = self.retriever.retrieve(query)

    first_output = result[0]
    self.assertIsInstance(first_output, RetrieverOutput)
    self.assertEqual(len(first_output.doc_indices), 3)
    self.assertEqual(len(first_output.doc_scores), 3)
    self.assertListEqual(first_output.doc_indices, expected_indices)
    self.assertListEqual(first_output.doc_scores, expected_scores)
7983

8084
def test_retrieve_multiple_queries(self):
    queries = ["query 1", "query 2"]

    # The embedder mock returns one random vector per query.
    vectors = create_dummy_embeddings(len(queries), self.dimensions)
    mocked_items = [Mock(embedding=vector) for vector in vectors]
    self.embedder.return_value = EmbedderOutput(data=mocked_items)

    # Every search(...).limit(...).to_pandas() call yields the same
    # three-row DataFrame.
    limited_search = self.mock_table.search.return_value.limit.return_value
    limited_search.to_pandas.return_value = pd.DataFrame(
        {"index": [0, 1, 2], "_distance": [0.1, 0.2, 0.3]}
    )

    result = self.retriever.retrieve(queries)

    # One RetrieverOutput per query, each carrying three hits.
    self.assertEqual(len(result), len(queries))
    for output in result:
        self.assertIsInstance(output, RetrieverOutput)
        self.assertEqual(len(output.doc_indices), 3)
        self.assertEqual(len(output.doc_scores), 3)
116106

117-
with self.assertRaises(Exception) as context:
118-
self.retriever.add_documents(documents)
107+
def test_retrieve_with_empty_query(self):
    # Mock the empty results DataFrame
    limited_search = self.mock_table.search.return_value.limit.return_value
    limited_search.to_pandas.return_value = pd.DataFrame(
        {"index": [], "_distance": []}
    )

    # The test used to stop after configuring the mock, so it passed without
    # exercising the retriever. Run the empty query and check the outcome.
    # NOTE(review): assumes retrieve("") returns one RetrieverOutput with no
    # hits - confirm against LanceDBRetriever.
    result = self.retriever.retrieve("")

    self.assertIsInstance(result[0], RetrieverOutput)
    self.assertEqual(len(result[0].doc_indices), 0)
    self.assertEqual(len(result[0].doc_scores), 0)
113+
114+
def test_retrieve_with_no_index(self):
    # A retriever built without db_uri and without any added documents is
    # expected to raise ValueError on retrieve().
    # NOTE(review): lancedb.connect is not patched in this test - confirm the
    # constructor's defaults do not touch a real database.
    retriever_without_docs = LanceDBRetriever(
        dimensions=self.dimensions, embedder=self.embedder
    )
    self.assertRaises(ValueError, retriever_without_docs.retrieve, "test query")
120+
121+
def test_overwrite_table_on_initialization(self):
    # With overwrite=True the constructor must call create_table exactly once
    # for the "documents" table in "overwrite" mode (any schema accepted).
    with patch("lancedb.connect") as connect_patch:
        db_handle = connect_patch.return_value
        db_handle.create_table.return_value = Mock()

        init_kwargs = {
            "embedder": self.embedder,
            "dimensions": self.dimensions,
            "db_uri": "/tmp/lancedb",
            "overwrite": True,
        }
        LanceDBRetriever(**init_kwargs)

        db_handle.create_table.assert_called_once_with(
            "documents", schema=mock.ANY, mode="overwrite"
        )
119138

120-
self.assertEqual(str(context.exception), "Embedding failure")
121139

122140
if __name__ == "__main__":
123141
unittest.main()

0 commit comments

Comments
 (0)