Skip to content

Commit e8a87b0

Browse files
akashmangoaifm1320
authored andcommitted
added subclass & doc
1 parent 396160a commit e8a87b0

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

adalflow/adalflow/components/retriever/lancedb_retriver.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import lancedb
55
from typing import List, Optional, Sequence, Union, Dict, Any
66
from adalflow.core.embedder import Embedder
7-
from adalflow.core.types import ModelClientType, RetrieverOutput, RetrieverOutputType
7+
from adalflow.core.retriever import Retriever
8+
from adalflow.core.types import RetrieverOutput
89

910
# Initialize logging
1011
log = logging.getLogger(__name__)
@@ -14,8 +15,24 @@
1415
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]
1516

1617
# Step 2: Define the LanceDBRetriever class
17-
class LanceDBRetriever:
18+
class LanceDBRetriever(Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]):
1819
def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lancedb", top_k: int = 5, overwrite: bool = True):
20+
"""
21+
LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.
22+
23+
Attributes:
24+
embedder (Embedder): An instance of the Embedder class used for computing embeddings.
25+
dimensions (int): The dimensionality of the embeddings used.
26+
db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb").
27+
top_k (int): The number of top results to retrieve for a given query (default is 5).
28+
overwrite (bool): If True, the existing table is overwritten; otherwise, new documents are appended.
29+
30+
This retriever supports adding documents with their embeddings to a LanceDB storage and retrieving relevant documents based on a given query.
31+
32+
More information on LanceDB can be found here:(https://github.com/lancedb/lancedb)
33+
Documentations: https://lancedb.github.io/lancedb/
34+
"""
35+
1936
self.db = lancedb.connect(db_uri)
2037
self.embedder = embedder
2138
self.top_k = top_k
@@ -31,7 +48,12 @@ def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lanc
3148
self.table = self.db.create_table("documents", schema=schema, mode="overwrite" if overwrite else "append")
3249

3350
def add_documents(self, documents: Sequence[Dict[str, Any]]):
34-
"""Adds documents with pre-computed embeddings."""
51+
"""
52+
Adds documents with pre-computed embeddings to LanceDB.
53+
Args:
54+
documents (Sequence[Dict[str, Any]]): A sequence of documents, each with a 'content' field containing text.
55+
56+
"""
3557
if not documents:
3658
log.warning("No documents provided for embedding")
3759
return
@@ -48,7 +70,15 @@ def add_documents(self, documents: Sequence[Dict[str, Any]]):
4870
log.info(f"Added {len(documents)} documents to the index")
4971

5072
def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) -> List[RetrieverOutput]:
51-
"""Retrieve top-k documents from LanceDB for given query or queries."""
73+
""".
74+
Retrieve top-k documents from LanceDB for a given query or queries.
75+
Args:
76+
query (Union[str, List[str]]): A query string or a list of query strings.
77+
top_k (Optional[int]): The number of top documents to retrieve (if not specified, defaults to the instance's top_k).
78+
79+
Returns:
80+
List[RetrieverOutput]: A list of RetrieverOutput containing the indices and scores of the retrieved documents.
81+
"""
5282
if isinstance(query, str):
5383
query = [query]
5484

0 commit comments

Comments
 (0)