Skip to content

Commit 7137548

Browse files
authored
Merge pull request #236 from akashAD98/feature/lancedb
Feature/lancedb as vectordb added
2 parents 51ac755 + 177a96c commit 7137548

File tree

9 files changed

+4422
-3456
lines changed

9 files changed

+4422
-3456
lines changed

adalflow/adalflow/components/retriever/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,19 @@
2727
OptionalPackages.QDRANT,
2828
)
2929

30+
LanceDBRetriever = LazyImport(
31+
"adalflow.components.retriever.lancedb_retriver.LanceDBRetriever",
32+
OptionalPackages.LANCEDB,
33+
)
34+
3035
__all__ = [
3136
"BM25Retriever",
3237
"LLMRetriever",
3338
"FAISSRetriever",
3439
"RerankerRetriever",
3540
"PostgresRetriever",
3641
"QdrantRetriever",
42+
"LanceDBRetriever",
3743
"split_text_by_word_fn",
3844
"split_text_by_word_fn_then_lower_tokenized",
3945
]
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import logging
2+
import numpy as np
3+
import pyarrow as pa
4+
import lancedb
5+
from typing import List, Optional, Sequence, Union, Dict, Any
6+
from adalflow.core.embedder import Embedder
7+
from adalflow.core.retriever import Retriever
8+
from adalflow.core.types import RetrieverOutput
9+
10+
# Initialize logging
11+
log = logging.getLogger(__name__)
12+
13+
# Defined data types
14+
# A single embedding: either a plain list of floats or a NumPy array.
LanceDBRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]
# A sequence of single embeddings.
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]
18+
19+
20+
# LanceDB-backed retriever implementation
21+
class LanceDBRetriever(
    Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]
):
    """Retriever that stores document embeddings in LanceDB and searches them.

    Documents are embedded with the supplied ``embedder`` and written to a
    LanceDB table named ``"documents"`` with two columns: ``vector`` (a
    fixed-size list of float32) and ``content`` (the document text).

    More information on LanceDB can be found here: https://github.com/lancedb/lancedb
    Documentation: https://lancedb.github.io/lancedb/
    """

    def __init__(
        self,
        embedder: Embedder,
        dimensions: int,
        db_uri: str = "/tmp/lancedb",
        top_k: int = 5,
        overwrite: bool = True,
    ):
        """Connect to LanceDB and create (or reuse) the ``documents`` table.

        Args:
            embedder (Embedder): Used to compute embeddings for documents and queries.
            dimensions (int): The dimensionality of the embeddings; fixes the
                size of the ``vector`` column.
            db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb").
            top_k (int): The number of results to retrieve per query (default is 5).
            overwrite (bool): If True, any existing ``documents`` table is
                replaced; otherwise an existing table is reused so that new
                documents are appended to it.
        """
        super().__init__()
        self.db = lancedb.connect(db_uri)
        self.embedder = embedder
        self.top_k = top_k
        self.dimensions = dimensions

        # Table schema: fixed-size float32 vector plus the raw document text.
        schema = pa.schema(
            [
                pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
                pa.field("content", pa.string()),
            ]
        )

        # ``create_table`` documents only the "create" and "overwrite" modes.
        # To keep an existing table (so later ``add_documents`` calls append
        # to it), use "create" together with ``exist_ok=True`` instead of the
        # unsupported "append" mode.
        if overwrite:
            self.table = self.db.create_table(
                "documents", schema=schema, mode="overwrite"
            )
        else:
            self.table = self.db.create_table(
                "documents", schema=schema, mode="create", exist_ok=True
            )

    def add_documents(self, documents: Sequence[Dict[str, Any]]):
        """Embed the given documents and add them to the LanceDB table.

        Args:
            documents (Sequence[Dict[str, Any]]): A sequence of documents, each
                with a 'content' field containing the text to embed and store.

        Logs a warning and returns without touching the table when
        ``documents`` is empty.
        """
        if not documents:
            log.warning("No documents provided for embedding")
            return

        # Embed document content using the Embedder.
        doc_texts = [doc["content"] for doc in documents]
        embeddings = self.embedder(input=doc_texts).data

        # One row per document: its embedding vector and its text.
        data = [
            {"vector": embedding.embedding, "content": text}
            for embedding, text in zip(embeddings, doc_texts)
        ]

        self.table.add(data)
        log.info(f"Added {len(documents)} documents to the index")

    def retrieve(
        self, query: Union[str, List[str]], top_k: Optional[int] = None
    ) -> List[RetrieverOutput]:
        """Retrieve top-k documents from LanceDB for a given query or queries.

        Args:
            query (Union[str, List[str]]): A query string or a list of query strings.
            top_k (Optional[int]): The number of documents to retrieve per
                query; a falsy value (None or 0) falls back to the instance's
                ``top_k``.

        Returns:
            List[RetrieverOutput]: One RetrieverOutput per query, carrying the
            result indices, the ``_distance`` values and the query text.

        Raises:
            ValueError: If the query is empty or blank, or if the table is missing.
        """
        queries = [query] if isinstance(query, str) else list(query)

        # Validate after normalising to a list so that a bare "" (or blank
        # entries inside a list) is actually rejected.
        if not queries or any(isinstance(q, str) and not q.strip() for q in queries):
            raise ValueError("Query cannot be empty.")

        # Compare against None explicitly: a table object that defines
        # ``__len__`` would be falsy when merely empty, which is not the same
        # as being uninitialised.
        if self.table is None:
            raise ValueError(
                "The index has not been initialized or the table is missing."
            )

        limit = top_k or self.top_k
        query_embeddings = self.embedder(input=queries).data
        output: List[RetrieverOutput] = []

        # Search once per query and pair each result with its own query text.
        for query_text, query_emb in zip(queries, query_embeddings):
            results = self.table.search(query_emb.embedding).limit(limit).to_pandas()

            # NOTE(review): ``results.index`` is the index of the result frame,
            # presumably positional (result rank) rather than a stable document
            # id, and ``_distance`` is a distance, so lower presumably means
            # closer. Confirm both against how callers use
            # doc_indices / doc_scores.
            output.append(
                RetrieverOutput(
                    doc_indices=results.index.tolist(),
                    doc_scores=results["_distance"].tolist(),
                    query=query_text,
                )
            )
        return output

adalflow/adalflow/utils/lazy_import.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ class OptionalPackages(Enum):
7474
"Please install faiss with: pip install faiss-cpu (or faiss if you use GPU)",
7575
)
7676

77+
LANCEDB = (
78+
"lancedb",
79+
"Please install lancedb with: pip install lancedb .",
80+
)
81+
7782
# db library
7883
SQLALCHEMY = (
7984
"sqlalchemy",

0 commit comments

Comments
 (0)