@@ -27,7 +27,7 @@ def __init__(
2727 pass
2828
2929 @abstractmethod
30- async def retrieve (self , partition : list [str ], query : str ) -> list [Document ]:
30+ async def retrieve (self , partition : list [str ], query : str , filter : dict | None = None ) -> list [Document ]:
3131 pass
3232
3333
@@ -43,13 +43,15 @@ async def retrieve(
4343 self ,
4444 partition : list [str ],
4545 query : str ,
46+ filter : dict | None = None ,
4647 ) -> list [Document ]:
4748 db = get_vectordb ()
4849 chunks = await db .async_search .remote (
4950 query = query ,
5051 partition = partition ,
5152 top_k = self .top_k ,
5253 similarity_threshold = self .similarity_threshold ,
54+ filter = filter ,
5355 with_surrounding_chunks = self .with_surrounding_chunks ,
5456 )
5557 return chunks
@@ -78,7 +80,7 @@ def __init__(
7880 prompt : ChatPromptTemplate = ChatPromptTemplate .from_template (MULTI_QUERY_PROMPT )
7981 self .generate_queries = prompt | llm | StrOutputParser () | (lambda x : x .split ("[SEP]" ))
8082
81- async def retrieve (self , partition : list [str ], query : str ) -> list [Document ]:
83+ async def retrieve (self , partition : list [str ], query : str , filter : dict | None = None ) -> list [Document ]:
8284 db = get_vectordb ()
8385 logger .debug ("Generating multiple queries" , k_queries = self .k_queries )
8486 generated_queries = await self .generate_queries .ainvoke (
@@ -92,6 +94,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]:
9294 partition = partition ,
9395 top_k_per_query = self .top_k ,
9496 similarity_threshold = self .similarity_threshold ,
97+ filter = filter ,
9598 with_surrounding_chunks = self .with_surrounding_chunks ,
9699 )
97100 return chunks
@@ -121,7 +124,7 @@ async def get_hyde(self, query: str):
121124 hyde_document = await self .hyde_generator .ainvoke ({"query" : query })
122125 return hyde_document
123126
124- async def retrieve (self , partition : list [str ], query : str ) -> list [Document ]:
127+ async def retrieve (self , partition : list [str ], query : str , filter : dict | None = None ) -> list [Document ]:
125128 db = get_vectordb ()
126129 hyde = await self .get_hyde (query )
127130 queries = [hyde ]
@@ -133,6 +136,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]:
133136 partition = partition ,
134137 top_k_per_query = self .top_k ,
135138 similarity_threshold = self .similarity_threshold ,
139+ filter = filter ,
136140 with_surrounding_chunks = self .with_surrounding_chunks ,
137141 )
138142
0 commit comments