Skip to content

Commit 88fa0a2

Browse files
akashmangoaifm1320
authored and committed
lancedb as retriever support
1 parent d91638e commit 88fa0a2

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import logging
2+
import numpy as np
3+
import pyarrow as pa
4+
import lancedb
5+
from typing import List, Optional, Sequence, Union, Dict, Any
6+
from adalflow.core.embedder import Embedder
7+
from adalflow.core.types import ModelClientType, RetrieverOutput, RetrieverOutputType
8+
9+
# Initialize logging
10+
log = logging.getLogger(__name__)
11+
12+
# Defined data types
13+
LanceDBRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray] # single embedding
14+
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]
15+
16+
# Retriever backed by a LanceDB table, using an Embedder to vectorize text
17+
class LanceDBRetriever:
    """Retriever that embeds text with an ``Embedder`` and searches a LanceDB table.

    Documents are stored in a table named ``"documents"`` with two columns:
    ``vector`` (fixed-size list of float32) and ``content`` (string).
    """

    def __init__(
        self,
        embedder: Embedder,
        dimensions: int,
        db_uri: str = "/tmp/lancedb",
        top_k: int = 5,
        overwrite: bool = True,
    ):
        """Connect to LanceDB and create (or reuse) the ``documents`` table.

        Args:
            embedder: Callable invoked as ``embedder(input=[...])``; its result
                must expose ``.data``, a sequence of items with ``.embedding``.
            dimensions: Length of every embedding vector stored in the table.
            db_uri: Location passed to ``lancedb.connect``.
            top_k: Default number of results returned per query.
            overwrite: When true, replace any existing ``documents`` table.
                When false, reuse the existing table if there is one.
        """
        self.db = lancedb.connect(db_uri)
        self.embedder = embedder
        self.top_k = top_k
        self.dimensions = dimensions

        # Table schema: a fixed-size float32 vector plus the raw text.
        schema = pa.schema(
            [
                pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
                pa.field("content", pa.string()),
            ]
        )

        # `create_table` only accepts the modes "create" and "overwrite";
        # reusing an existing table is requested with `exist_ok=True`.
        if overwrite:
            self.table = self.db.create_table(
                "documents", schema=schema, mode="overwrite"
            )
        else:
            self.table = self.db.create_table(
                "documents", schema=schema, exist_ok=True
            )

    def add_documents(self, documents: Sequence[Dict[str, Any]]):
        """Embed the ``"content"`` of each document and add it to the table.

        Args:
            documents: Mappings that each contain a ``"content"`` entry.
                An empty sequence only logs a warning.

        Raises:
            KeyError: If a document has no ``"content"`` key.
            ValueError: If the embedder returns fewer embeddings than documents.
        """
        if not documents:
            log.warning("No documents provided for embedding")
            return

        doc_texts = [doc["content"] for doc in documents]
        embeddings = self.embedder(input=doc_texts).data

        # zip() below would silently drop the documents that got no embedding.
        if len(embeddings) < len(doc_texts):
            raise ValueError(
                f"Embedder returned {len(embeddings)} embeddings "
                f"for {len(doc_texts)} documents"
            )

        data = [
            {"vector": item.embedding, "content": text}
            for item, text in zip(embeddings, doc_texts)
        ]

        self.table.add(data)
        log.info("Added %d documents to the index", len(documents))

    def retrieve(
        self, query: Union[str, List[str]], top_k: Optional[int] = None
    ) -> List[RetrieverOutput]:
        """Search the table for each query and return one output per query.

        Args:
            query: A single query string or a list of query strings.
            top_k: Number of results per query; falls back to ``self.top_k``.

        Returns:
            One ``RetrieverOutput`` per query, in input order. Each output's
            ``query`` is the query string it answers. Empty query strings are
            not embedded and produce an output with no indices and no scores.

        Raises:
            ValueError: If the embedder returns fewer embeddings than
                non-empty queries.
        """
        queries = [query] if isinstance(query, str) else list(query)
        limit = top_k or self.top_k

        # Only non-empty queries are sent to the embedder.
        texts_to_embed = [text for text in queries if text]
        embeddings = self.embedder(input=texts_to_embed).data if texts_to_embed else []
        if len(embeddings) < len(texts_to_embed):
            raise ValueError(
                f"Embedder returned {len(embeddings)} embeddings "
                f"for {len(texts_to_embed)} queries"
            )
        remaining_embeddings = iter(embeddings)

        output: List[RetrieverOutput] = []
        for text in queries:
            if not text:
                output.append(RetrieverOutput(doc_indices=[], doc_scores=[], query=text))
                continue

            query_emb = next(remaining_embeddings)
            results = self.table.search(query_emb.embedding).limit(limit).to_pandas()

            # NOTE(review): `results.index` is the index of the returned
            # DataFrame; the table schema has no id column, so these may not be
            # stable document identifiers - confirm against callers.
            # Scores are taken from the `_distance` column of the result.
            output.append(
                RetrieverOutput(
                    doc_indices=results.index.tolist(),
                    doc_scores=results["_distance"].tolist(),
                    query=text,
                )
            )
        return output
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import unittest
2+
from unittest.mock import Mock, MagicMock
3+
import numpy as np
4+
from adalflow.components.retriever import LanceDBRetriever
5+
from adalflow.core.embedder import Embedder
6+
from adalflow.core.types import RetrieverOutput, Document
7+
8+
# Placeholder mocks for LanceDB and PyArrow. NOTE: rebinding these names in this module does not patch the modules imported inside the retriever.
9+
lancedb = MagicMock()
10+
pa = MagicMock()
11+
12+
class TestLanceDBRetriever(unittest.TestCase):
    """Unit tests for LanceDBRetriever with a mocked embedder and LanceDB table."""

    def setUp(self):
        # Basic configuration
        self.dimensions = 128
        self.embedder = Mock(spec=Embedder)
        self.db_uri = "/tmp/test_lancedb"

        # The mocked embedder returns exactly one embedding per input text.
        self.dummy_embeddings = np.random.rand(10, self.dimensions).astype(np.float32)
        self.embedder.side_effect = self._fake_embed

        # Patch the connection so constructing the retriever touches no database.
        connect_patcher = unittest.mock.patch(
            "lancedb.connect", return_value=MagicMock()
        )
        connect_patcher.start()
        self.addCleanup(connect_patcher.stop)

        # Initialize LanceDBRetriever with mocked embedder and connection
        self.retriever = LanceDBRetriever(
            embedder=self.embedder, dimensions=self.dimensions, db_uri=self.db_uri
        )

        # Replace the table with a fresh mock that each test configures.
        self.retriever.table = MagicMock()

    def _fake_embed(self, input):  # parameter name must match the `input=` keyword
        """Return a mock embedder output with one embedding per input text."""
        return Mock(
            data=[
                Mock(embedding=embedding)
                for embedding in self.dummy_embeddings[: len(input)]
            ]
        )

    @staticmethod
    def _fake_search_results(indices, scores):
        """Build a stand-in for the frame returned by ``to_pandas()``.

        Supports the two accesses the retriever performs:
        ``results.index.tolist()`` and ``results["_distance"].tolist()``.
        """
        frame = MagicMock()
        frame.index.tolist.return_value = indices
        frame.__getitem__.return_value.tolist.return_value = scores
        return frame

    def test_initialization(self):
        # Check dimensions, default top_k and embedder assignment
        self.assertEqual(self.retriever.dimensions, self.dimensions)
        self.assertEqual(self.retriever.top_k, 5)
        self.assertIs(self.retriever.embedder, self.embedder)

    def test_add_documents(self):
        # Sample documents
        documents = [{"content": f"Document {i}"} for i in range(5)]

        # Add documents to LanceDBRetriever
        self.retriever.add_documents(documents)

        # Ensure add method was called exactly once
        self.retriever.table.add.assert_called_once()
        # Verify one row per document was passed to the LanceDB add method
        added_data = self.retriever.table.add.call_args[0][0]
        self.assertEqual(len(added_data), len(documents))
        self.assertIn("vector", added_data[0])
        self.assertIn("content", added_data[0])

    def test_retrieve(self):
        # Prepare a sample query and mocked search result from LanceDB
        query = "test query"
        dummy_scores = [0.9, 0.8, 0.7]
        dummy_indices = [0, 1, 2]

        limit_mock = self.retriever.table.search.return_value.limit
        limit_mock.return_value.to_pandas.return_value = self._fake_search_results(
            dummy_indices, dummy_scores
        )

        # Retrieve top-k results for the query
        result = self.retriever.retrieve(query)

        # Check if retrieve method returns expected output structure
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0], RetrieverOutput)
        self.assertEqual(result[0].query, query)
        self.assertEqual(result[0].doc_indices, dummy_indices)
        self.assertEqual(result[0].doc_scores, dummy_scores)
        # The default top_k configured on the retriever is used as the limit
        limit_mock.assert_called_once_with(5)

    def test_retrieve_multiple_queries(self):
        # Prepare multiple queries and one mocked search result per query
        queries = ["query 1", "query 2"]
        dummy_scores = [[0.9, 0.8], [0.85, 0.75]]
        dummy_indices = [[0, 1], [2, 3]]

        to_pandas = self.retriever.table.search.return_value.limit.return_value.to_pandas
        to_pandas.side_effect = [
            self._fake_search_results(indices, scores)
            for indices, scores in zip(dummy_indices, dummy_scores)
        ]

        # Retrieve for multiple queries
        results = self.retriever.retrieve(queries)

        # Each output must carry its own query, indices and scores
        self.assertEqual(len(results), len(queries))
        for i, result in enumerate(results):
            self.assertEqual(result.query, queries[i])
            self.assertEqual(result.doc_indices, dummy_indices[i])
            self.assertEqual(result.doc_scores, dummy_scores[i])

    def test_empty_document_addition(self):
        # Ensure warning log for empty document list
        with self.assertLogs(level='WARNING'):
            self.retriever.add_documents([])

    def test_retrieve_with_empty_query(self):
        # Check empty query handling, expecting a list with empty RetrieverOutput
        result = self.retriever.retrieve("")
        self.assertEqual(result, [RetrieverOutput(doc_indices=[], doc_scores=[], query="")])

    def test_add_documents_embedding_failure(self):
        # Simulate embedding failure
        self.embedder.side_effect = Exception("Embedding failure")
        documents = [{"content": "test document"}]

        with self.assertRaises(Exception) as context:
            self.retriever.add_documents(documents)

        self.assertEqual(str(context.exception), "Embedding failure")
122+
if __name__ == "__main__":
123+
unittest.main()

0 commit comments

Comments
 (0)