Skip to content

Commit b709393

Browse files
committed
formatting changes
1 parent 022181d commit b709393

File tree

4 files changed

+2518
-2156
lines changed

4 files changed

+2518
-2156
lines changed

adalflow/adalflow/components/retriever/lancedb_retriver.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,24 @@
1111
log = logging.getLogger(__name__)
1212

1313
# Defined data types
14-
LanceDBRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray] # single embedding
14+
LanceDBRetrieverDocumentEmbeddingType = Union[
15+
List[float], np.ndarray
16+
] # single embedding
1517
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]
1618

19+
1720
# Step 2: Define the LanceDBRetriever class
18-
class LanceDBRetriever(Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]):
19-
def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lancedb", top_k: int = 5, overwrite: bool = True):
21+
class LanceDBRetriever(
22+
Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]
23+
):
24+
def __init__(
25+
self,
26+
embedder: Embedder,
27+
dimensions: int,
28+
db_uri: str = "/tmp/lancedb",
29+
top_k: int = 5,
30+
overwrite: bool = True,
31+
):
2032
"""
2133
LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.
2234
@@ -39,13 +51,17 @@ def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lanc
3951
self.dimensions = dimensions
4052

4153
# Define table schema with vector field for embeddings
42-
schema = pa.schema([
43-
pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
44-
pa.field("content", pa.string())
45-
])
54+
schema = pa.schema(
55+
[
56+
pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
57+
pa.field("content", pa.string()),
58+
]
59+
)
4660

4761
# Create or overwrite the table for storing documents and embeddings
48-
self.table = self.db.create_table("documents", schema=schema, mode="overwrite" if overwrite else "append")
62+
self.table = self.db.create_table(
63+
"documents", schema=schema, mode="overwrite" if overwrite else "append"
64+
)
4965

5066
def add_documents(self, documents: Sequence[Dict[str, Any]]):
5167
"""
@@ -63,13 +79,18 @@ def add_documents(self, documents: Sequence[Dict[str, Any]]):
6379
embeddings = self.embedder(input=doc_texts).data
6480

6581
# Format embeddings for LanceDB
66-
data = [{"vector": embedding.embedding, "content": text} for embedding, text in zip(embeddings, doc_texts)]
82+
data = [
83+
{"vector": embedding.embedding, "content": text}
84+
for embedding, text in zip(embeddings, doc_texts)
85+
]
6786

6887
# Add data to LanceDB table
6988
self.table.add(data)
7089
log.info(f"Added {len(documents)} documents to the index")
7190

72-
def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) -> List[RetrieverOutput]:
91+
def retrieve(
92+
self, query: Union[str, List[str]], top_k: Optional[int] = None
93+
) -> List[RetrieverOutput]:
7394
""".
7495
Retrieve top-k documents from LanceDB for a given query or queries.
7596
Args:
@@ -83,11 +104,13 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
83104
query = [query]
84105

85106
if not query or (isinstance(query, str) and query.strip() == ""):
86-
raise ValueError("Query cannot be empty.")
107+
raise ValueError("Query cannot be empty.")
87108

88109
# Check if table (index) exists before performing search
89110
if not self.table:
90-
raise ValueError("The index has not been initialized or the table is missing.")
111+
raise ValueError(
112+
"The index has not been initialized or the table is missing."
113+
)
91114

92115
query_embeddings = self.embedder(input=query).data
93116
output: List[RetrieverOutput] = []
@@ -105,9 +128,11 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
105128
scores = results["_distance"].tolist()
106129

107130
# Append results to output
108-
output.append(RetrieverOutput(
109-
doc_indices=indices,
110-
doc_scores=scores,
111-
query=query[0] if len(query) == 1 else query
112-
))
131+
output.append(
132+
RetrieverOutput(
133+
doc_indices=indices,
134+
doc_scores=scores,
135+
query=query[0] if len(query) == 1 else query,
136+
)
137+
)
113138
return output

adalflow/adalflow/utils/lazy_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class OptionalPackages(Enum):
7575
)
7676

7777
LANCEDB = (
78-
"lancedb",
78+
"lancedb",
7979
"Please install lancedb with: pip install lancedb .",
8080
)
8181

adalflow/tests/test_lancedb_retriver.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
from unittest import mock
88
from adalflow.core.types import EmbedderOutput, RetrieverOutput
99

10+
1011
# Helper function to create dummy embeddings
1112
def create_dummy_embeddings(num_embeddings, dim):
1213
return np.random.rand(num_embeddings, dim).astype(np.float32)
1314

15+
1416
class TestLanceDBRetriever(unittest.TestCase):
1517
def setUp(self):
1618
self.dimensions = 128
@@ -21,7 +23,10 @@ def setUp(self):
2123
# Mock embedder to return dummy embeddings
2224
self.dummy_embeddings = create_dummy_embeddings(10, self.dimensions)
2325
self.embedder.return_value = EmbedderOutput(
24-
data=[Mock(embedding=emb) for emb in self.dummy_embeddings[:len(self.single_query)]]
26+
data=[
27+
Mock(embedding=emb)
28+
for emb in self.dummy_embeddings[: len(self.single_query)]
29+
]
2530
)
2631

2732
with patch("lancedb.connect") as mock_db_connect:
@@ -32,7 +37,7 @@ def setUp(self):
3237
embedder=self.embedder,
3338
dimensions=self.dimensions,
3439
db_uri="/tmp/lancedb",
35-
top_k=self.top_k
40+
top_k=self.top_k,
3641
)
3742

3843
def test_initialization(self):
@@ -68,11 +73,10 @@ def test_retrieve_single_query(self):
6873
)
6974

7075
# Mock search results from LanceDB as pandas DataFrame
71-
results_df = pd.DataFrame({
72-
"index": [0, 1, 2],
73-
"_distance": [0.1, 0.2, 0.3]
74-
})
75-
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = results_df
76+
results_df = pd.DataFrame({"index": [0, 1, 2], "_distance": [0.1, 0.2, 0.3]})
77+
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = (
78+
results_df
79+
)
7680

7781
result = self.retriever.retrieve(query)
7882
self.assertIsInstance(result[0], RetrieverOutput)
@@ -91,11 +95,10 @@ def test_retrieve_multiple_queries(self):
9195
)
9296

9397
# Mock search results for each query
94-
results_df = pd.DataFrame({
95-
"index": [0, 1, 2],
96-
"_distance": [0.1, 0.2, 0.3]
97-
})
98-
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = results_df
98+
results_df = pd.DataFrame({"index": [0, 1, 2], "_distance": [0.1, 0.2, 0.3]})
99+
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = (
100+
results_df
101+
)
99102

100103
result = self.retriever.retrieve(queries)
101104
self.assertEqual(len(result), len(queries))
@@ -106,10 +109,9 @@ def test_retrieve_multiple_queries(self):
106109

107110
def test_retrieve_with_empty_query(self):
108111
# Mock the empty results DataFrame
109-
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = pd.DataFrame({
110-
"index": [],
111-
"_distance": []
112-
})
112+
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = pd.DataFrame(
113+
{"index": [], "_distance": []}
114+
)
113115

114116
def test_retrieve_with_no_index(self):
115117
empty_retriever = LanceDBRetriever(
@@ -128,12 +130,10 @@ def test_overwrite_table_on_initialization(self):
128130
embedder=self.embedder,
129131
dimensions=self.dimensions,
130132
db_uri="/tmp/lancedb",
131-
overwrite=True
133+
overwrite=True,
132134
)
133135
mock_db.create_table.assert_called_once_with(
134-
"documents",
135-
schema=mock.ANY,
136-
mode="overwrite"
136+
"documents", schema=mock.ANY, mode="overwrite"
137137
)
138138

139139

0 commit comments

Comments
 (0)