|
| 1 | +import logging |
| 2 | +import numpy as np |
| 3 | +import pyarrow as pa |
| 4 | +import lancedb |
| 5 | +from typing import List, Optional, Sequence, Union, Dict, Any |
| 6 | +from adalflow.core.embedder import Embedder |
| 7 | +from adalflow.core.retriever import Retriever |
| 8 | +from adalflow.core.types import RetrieverOutput |
| 9 | + |
| 10 | +# Initialize logging |
| 11 | +log = logging.getLogger(__name__) |
| 12 | + |
| 13 | +# Defined data types |
| 14 | +LanceDBRetrieverDocumentEmbeddingType = Union[ |
| 15 | + List[float], np.ndarray |
| 16 | +] # single embedding |
| 17 | +LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType] |
| 18 | + |
| 19 | + |
| 20 | +# Step 2: Define the LanceDBRetriever class |
| 21 | +class LanceDBRetriever( |
| 22 | + Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]] |
| 23 | +): |
| 24 | + def __init__( |
| 25 | + self, |
| 26 | + embedder: Embedder, |
| 27 | + dimensions: int, |
| 28 | + db_uri: str = "/tmp/lancedb", |
| 29 | + top_k: int = 5, |
| 30 | + overwrite: bool = True, |
| 31 | + ): |
| 32 | + """ |
| 33 | + LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings. |
| 34 | +
|
| 35 | + Attributes: |
| 36 | + embedder (Embedder): An instance of the Embedder class used for computing embeddings. |
| 37 | + dimensions (int): The dimensionality of the embeddings used. |
| 38 | + db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb"). |
| 39 | + top_k (int): The number of top results to retrieve for a given query (default is 5). |
| 40 | + overwrite (bool): If True, the existing table is overwritten; otherwise, new documents are appended. |
| 41 | +
|
| 42 | + This retriever supports adding documents with their embeddings to a LanceDB storage and retrieving relevant documents based on a given query. |
| 43 | +
|
| 44 | + More information on LanceDB can be found here:(https://github.com/lancedb/lancedb) |
| 45 | + Documentations: https://lancedb.github.io/lancedb/ |
| 46 | + """ |
| 47 | + super().__init__() |
| 48 | + self.db = lancedb.connect(db_uri) |
| 49 | + self.embedder = embedder |
| 50 | + self.top_k = top_k |
| 51 | + self.dimensions = dimensions |
| 52 | + |
| 53 | + # Define table schema with vector field for embeddings |
| 54 | + schema = pa.schema( |
| 55 | + [ |
| 56 | + pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)), |
| 57 | + pa.field("content", pa.string()), |
| 58 | + ] |
| 59 | + ) |
| 60 | + |
| 61 | + # Create or overwrite the table for storing documents and embeddings |
| 62 | + self.table = self.db.create_table( |
| 63 | + "documents", schema=schema, mode="overwrite" if overwrite else "append" |
| 64 | + ) |
| 65 | + |
| 66 | + def add_documents(self, documents: Sequence[Dict[str, Any]]): |
| 67 | + """ |
| 68 | + Adds documents with pre-computed embeddings to LanceDB. |
| 69 | + Args: |
| 70 | + documents (Sequence[Dict[str, Any]]): A sequence of documents, each with a 'content' field containing text. |
| 71 | +
|
| 72 | + """ |
| 73 | + if not documents: |
| 74 | + log.warning("No documents provided for embedding") |
| 75 | + return |
| 76 | + |
| 77 | + # Embed document content using Embedder |
| 78 | + doc_texts = [doc["content"] for doc in documents] |
| 79 | + embeddings = self.embedder(input=doc_texts).data |
| 80 | + |
| 81 | + # Format embeddings for LanceDB |
| 82 | + data = [ |
| 83 | + {"vector": embedding.embedding, "content": text} |
| 84 | + for embedding, text in zip(embeddings, doc_texts) |
| 85 | + ] |
| 86 | + |
| 87 | + # Add data to LanceDB table |
| 88 | + self.table.add(data) |
| 89 | + log.info(f"Added {len(documents)} documents to the index") |
| 90 | + |
| 91 | + def retrieve( |
| 92 | + self, query: Union[str, List[str]], top_k: Optional[int] = None |
| 93 | + ) -> List[RetrieverOutput]: |
| 94 | + """. |
| 95 | + Retrieve top-k documents from LanceDB for a given query or queries. |
| 96 | + Args: |
| 97 | + query (Union[str, List[str]]): A query string or a list of query strings. |
| 98 | + top_k (Optional[int]): The number of top documents to retrieve (if not specified, defaults to the instance's top_k). |
| 99 | +
|
| 100 | + Returns: |
| 101 | + List[RetrieverOutput]: A list of RetrieverOutput containing the indices and scores of the retrieved documents. |
| 102 | + """ |
| 103 | + if isinstance(query, str): |
| 104 | + query = [query] |
| 105 | + |
| 106 | + if not query or (isinstance(query, str) and query.strip() == ""): |
| 107 | + raise ValueError("Query cannot be empty.") |
| 108 | + |
| 109 | + # Check if table (index) exists before performing search |
| 110 | + if not self.table: |
| 111 | + raise ValueError( |
| 112 | + "The index has not been initialized or the table is missing." |
| 113 | + ) |
| 114 | + |
| 115 | + query_embeddings = self.embedder(input=query).data |
| 116 | + output: List[RetrieverOutput] = [] |
| 117 | + |
| 118 | + # Perform search in LanceDB for each query |
| 119 | + for query_emb in query_embeddings: |
| 120 | + results = ( |
| 121 | + self.table.search(query_emb.embedding) |
| 122 | + .limit(top_k or self.top_k) |
| 123 | + .to_pandas() |
| 124 | + ) |
| 125 | + |
| 126 | + # Gather indices and scores from search results |
| 127 | + indices = results.index.tolist() |
| 128 | + scores = results["_distance"].tolist() |
| 129 | + |
| 130 | + # Append results to output |
| 131 | + output.append( |
| 132 | + RetrieverOutput( |
| 133 | + doc_indices=indices, |
| 134 | + doc_scores=scores, |
| 135 | + query=query[0] if len(query) == 1 else query, |
| 136 | + ) |
| 137 | + ) |
| 138 | + return output |
0 commit comments