44import lancedb
55from typing import List , Optional , Sequence , Union , Dict , Any
66from adalflow .core .embedder import Embedder
7- from adalflow .core .types import ModelClientType , RetrieverOutput , RetrieverOutputType
7+ from adalflow .core .retriever import Retriever
8+ from adalflow .core .types import RetrieverOutput
89
910# Initialize logging
1011log = logging .getLogger (__name__ )
1415LanceDBRetrieverDocumentsType = Sequence [LanceDBRetrieverDocumentEmbeddingType ]
1516
1617# Step 2: Define the LanceDBRetriever class
17- class LanceDBRetriever :
18+ class LanceDBRetriever ( Retriever [ LanceDBRetrieverDocumentEmbeddingType , Union [ str , List [ str ]]]) :
1819 def __init__ (self , embedder : Embedder , dimensions : int , db_uri : str = "/tmp/lancedb" , top_k : int = 5 , overwrite : bool = True ):
20+ """
21+ LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.
22+
23+ Attributes:
24+ embedder (Embedder): An instance of the Embedder class used for computing embeddings.
25+ dimensions (int): The dimensionality of the embeddings used.
26+ db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb").
27+ top_k (int): The number of top results to retrieve for a given query (default is 5).
28+ overwrite (bool): If True, the existing table is overwritten; otherwise, new documents are appended.
29+
30+ This retriever supports adding documents with their embeddings to a LanceDB storage and retrieving relevant documents based on a given query.
31+
32+ More information on LanceDB can be found here:(https://github.com/lancedb/lancedb)
33+ Documentations: https://lancedb.github.io/lancedb/
34+ """
35+
1936 self .db = lancedb .connect (db_uri )
2037 self .embedder = embedder
2138 self .top_k = top_k
@@ -31,7 +48,12 @@ def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lanc
3148 self .table = self .db .create_table ("documents" , schema = schema , mode = "overwrite" if overwrite else "append" )
3249
3350 def add_documents (self , documents : Sequence [Dict [str , Any ]]):
34- """Adds documents with pre-computed embeddings."""
51+ """
52+ Adds documents with pre-computed embeddings to LanceDB.
53+ Args:
54+ documents (Sequence[Dict[str, Any]]): A sequence of documents, each with a 'content' field containing text.
55+
56+ """
3557 if not documents :
3658 log .warning ("No documents provided for embedding" )
3759 return
@@ -48,7 +70,15 @@ def add_documents(self, documents: Sequence[Dict[str, Any]]):
4870 log .info (f"Added { len (documents )} documents to the index" )
4971
5072 def retrieve (self , query : Union [str , List [str ]], top_k : Optional [int ] = None ) -> List [RetrieverOutput ]:
51- """Retrieve top-k documents from LanceDB for given query or queries."""
73+ """.
74+ Retrieve top-k documents from LanceDB for a given query or queries.
75+ Args:
76+ query (Union[str, List[str]]): A query string or a list of query strings.
77+ top_k (Optional[int]): The number of top documents to retrieve (if not specified, defaults to the instance's top_k).
78+
79+ Returns:
80+ List[RetrieverOutput]: A list of RetrieverOutput containing the indices and scores of the retrieved documents.
81+ """
5282 if isinstance (query , str ):
5383 query = [query ]
5484
0 commit comments