1111log = logging .getLogger (__name__ )
1212
1313# Defined data types
14- LanceDBRetrieverDocumentEmbeddingType = Union [List [float ], np .ndarray ] # single embedding
14+ LanceDBRetrieverDocumentEmbeddingType = Union [
15+ List [float ], np .ndarray
16+ ] # single embedding
1517LanceDBRetrieverDocumentsType = Sequence [LanceDBRetrieverDocumentEmbeddingType ]
1618
19+
1720# Step 2: Define the LanceDBRetriever class
18- class LanceDBRetriever (Retriever [LanceDBRetrieverDocumentEmbeddingType , Union [str , List [str ]]]):
19- def __init__ (self , embedder : Embedder , dimensions : int , db_uri : str = "/tmp/lancedb" , top_k : int = 5 , overwrite : bool = True ):
21+ class LanceDBRetriever (
22+ Retriever [LanceDBRetrieverDocumentEmbeddingType , Union [str , List [str ]]]
23+ ):
24+ def __init__ (
25+ self ,
26+ embedder : Embedder ,
27+ dimensions : int ,
28+ db_uri : str = "/tmp/lancedb" ,
29+ top_k : int = 5 ,
30+ overwrite : bool = True ,
31+ ):
2032 """
2133 LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.
2234
@@ -39,13 +51,17 @@ def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lanc
3951 self .dimensions = dimensions
4052
4153 # Define table schema with vector field for embeddings
42- schema = pa .schema ([
43- pa .field ("vector" , pa .list_ (pa .float32 (), list_size = self .dimensions )),
44- pa .field ("content" , pa .string ())
45- ])
54+ schema = pa .schema (
55+ [
56+ pa .field ("vector" , pa .list_ (pa .float32 (), list_size = self .dimensions )),
57+ pa .field ("content" , pa .string ()),
58+ ]
59+ )
4660
4761 # Create or overwrite the table for storing documents and embeddings
48- self .table = self .db .create_table ("documents" , schema = schema , mode = "overwrite" if overwrite else "append" )
62+ self .table = self .db .create_table (
63+ "documents" , schema = schema , mode = "overwrite" if overwrite else "append"
64+ )
4965
5066 def add_documents (self , documents : Sequence [Dict [str , Any ]]):
5167 """
@@ -63,13 +79,18 @@ def add_documents(self, documents: Sequence[Dict[str, Any]]):
6379 embeddings = self .embedder (input = doc_texts ).data
6480
6581 # Format embeddings for LanceDB
66- data = [{"vector" : embedding .embedding , "content" : text } for embedding , text in zip (embeddings , doc_texts )]
82+ data = [
83+ {"vector" : embedding .embedding , "content" : text }
84+ for embedding , text in zip (embeddings , doc_texts )
85+ ]
6786
6887 # Add data to LanceDB table
6988 self .table .add (data )
7089 log .info (f"Added { len (documents )} documents to the index" )
7190
72- def retrieve (self , query : Union [str , List [str ]], top_k : Optional [int ] = None ) -> List [RetrieverOutput ]:
91+ def retrieve (
92+ self , query : Union [str , List [str ]], top_k : Optional [int ] = None
93+ ) -> List [RetrieverOutput ]:
7394 """.
7495 Retrieve top-k documents from LanceDB for a given query or queries.
7596 Args:
@@ -83,11 +104,13 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
83104 query = [query ]
84105
85106 if not query or (isinstance (query , str ) and query .strip () == "" ):
86- raise ValueError ("Query cannot be empty." )
107+ raise ValueError ("Query cannot be empty." )
87108
88109 # Check if table (index) exists before performing search
89110 if not self .table :
90- raise ValueError ("The index has not been initialized or the table is missing." )
111+ raise ValueError (
112+ "The index has not been initialized or the table is missing."
113+ )
91114
92115 query_embeddings = self .embedder (input = query ).data
93116 output : List [RetrieverOutput ] = []
@@ -105,9 +128,11 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
105128 scores = results ["_distance" ].tolist ()
106129
107130 # Append results to output
108- output .append (RetrieverOutput (
109- doc_indices = indices ,
110- doc_scores = scores ,
111- query = query [0 ] if len (query ) == 1 else query
112- ))
131+ output .append (
132+ RetrieverOutput (
133+ doc_indices = indices ,
134+ doc_scores = scores ,
135+ query = query [0 ] if len (query ) == 1 else query ,
136+ )
137+ )
113138 return output
0 commit comments