@@ -36,7 +36,7 @@ class Node(Serializable):
36
36
id : Optional [str ]
37
37
"""Unique ID for the node. Shall be generated by the KnowledgeStore if not set"""
38
38
metadata : dict = Field (default_factory = dict )
39
- """Metadata for the node. May contain information used to link this node
39
+ """Metadata for the node. May contain information used to link this node
40
40
with other nodes."""
41
41
42
42
@@ -164,7 +164,7 @@ async def aadd_documents(
164
164
return await self .aadd_nodes (nodes , ** kwargs )
165
165
166
166
@abstractmethod
167
- def traversing_retrieve (
167
+ def traversal_search (
168
168
self ,
169
169
query : str ,
170
170
* ,
@@ -187,7 +187,7 @@ def traversing_retrieve(
187
187
Retrieved documents.
188
188
"""
189
189
190
- async def atraversing_retrieve (
190
+ async def atraversal_search (
191
191
self ,
192
192
query : str ,
193
193
* ,
@@ -210,42 +210,121 @@ async def atraversing_retrieve(
210
210
Retrieved documents.
211
211
"""
212
212
for doc in await run_in_executor (
213
- None , self .traversing_retrieve , query , k = k , depth = depth , ** kwargs
213
+ None , self .traversal_search , query , k = k , depth = depth , ** kwargs
214
214
):
215
215
yield doc
216
216
217
- def similarity_search (
218
- self , query : str , k : int = 4 , ** kwargs : Any
219
- ) -> List [Document ]:
220
- return list (self .traversing_retrieve (query , k = k , depth = 0 ))
217
@abstractmethod
def mmr_traversal_search(
    self,
    query: str,
    *,
    k: int = 4,
    depth: int = 2,
    fetch_k: int = 100,
    lambda_mult: float = 0.5,
    score_threshold: float = 0.0,
    **kwargs: Any,
) -> Iterable[Document]:
    """Retrieve documents from this knowledge store using MMR-traversal.

    This strategy first retrieves the top `fetch_k` results by similarity to
    the question. It then selects the top `k` results based on
    maximum-marginal relevance using the given `lambda_mult`.

    At each step, it considers the (remaining) documents from `fetch_k` as
    well as any documents connected by edges to a selected document
    retrieved based on similarity (a "root").

    Args:
        query: The query string to search for.
        k: Number of Documents to return. Defaults to 4.
        depth: Maximum depth of a node (number of edges) from a node
            retrieved via similarity. Defaults to 2.
        fetch_k: Number of Documents to fetch via similarity.
            Defaults to 100.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to 0.0 so all scores
            are taken.
        **kwargs: Additional keyword arguments accepted by the concrete
            implementation.

    Returns:
        Retrieved documents.
    """
253
+
254
async def ammr_traversal_search(
    self,
    query: str,
    *,
    k: int = 4,
    depth: int = 2,
    fetch_k: int = 100,
    lambda_mult: float = 0.5,
    score_threshold: float = 0.0,
    **kwargs: Any,
) -> AsyncIterable[Document]:
    """Retrieve documents from this knowledge store using MMR-traversal.

    Async counterpart of `mmr_traversal_search`: the synchronous method is
    run through `run_in_executor` and its results are yielded one by one.

    This strategy first retrieves the top `fetch_k` results by similarity to
    the question. It then selects the top `k` results based on
    maximum-marginal relevance using the given `lambda_mult`.

    At each step, it considers the (remaining) documents from `fetch_k` as
    well as any documents connected by edges to a selected document
    retrieved based on similarity (a "root").

    Args:
        query: The query string to search for.
        k: Number of Documents to return. Defaults to 4.
        depth: Maximum depth of a node (number of edges) from a node
            retrieved via similarity. Defaults to 2.
        fetch_k: Number of Documents to fetch via similarity.
            Defaults to 100.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to 0.0 so all scores
            are taken.
        **kwargs: Additional keyword arguments forwarded to
            `mmr_traversal_search`.

    Yields:
        Retrieved documents.
    """
    # Delegate to the MMR variant (not plain `traversal_search`), since the
    # MMR-specific arguments below are only declared by that method.
    for doc in await run_in_executor(
        None,
        self.mmr_traversal_search,
        query,
        k=k,
        fetch_k=fetch_k,
        depth=depth,
        lambda_mult=lambda_mult,
        score_threshold=score_threshold,
        **kwargs,
    ):
        yield doc
301
+
302
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
    """Return the top `k` documents for `query` without following edges.

    Implemented as a traversal search with `depth=0`. Extra keyword
    arguments are accepted for interface compatibility but are not
    forwarded to `traversal_search`.
    """
    matches = self.traversal_search(query, k=k, depth=0)
    return list(matches)
304
+
305
async def asimilarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
    """Asynchronously return the top `k` documents for `query` without following edges.

    Collects the results of `atraversal_search` called with `depth=0`.
    Extra keyword arguments are accepted for interface compatibility but
    are not forwarded.
    """
    collected = []
    async for doc in self.atraversal_search(query, k=k, depth=0):
        collected.append(doc)
    return collected
226
307
227
308
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
    """Run the search strategy selected by `search_type`.

    Args:
        query: The query string to search for.
        search_type: One of 'similarity', 'similarity_score_threshold',
            'mmr', 'traversal' or 'mmr_traversal'.
        **kwargs: Forwarded unchanged to the selected search method.

    Returns:
        The matching documents as a list.

    Raises:
        ValueError: If `search_type` is not one of the supported values.
    """
    if search_type == "similarity":
        return self.similarity_search(query, **kwargs)
    elif search_type == "similarity_score_threshold":
        docs_and_similarities = self.similarity_search_with_relevance_scores(query, **kwargs)
        # Callers of `search` only get documents; the scores are dropped.
        return [doc for doc, _ in docs_and_similarities]
    elif search_type == "mmr":
        return self.max_marginal_relevance_search(query, **kwargs)
    elif search_type == "traversal":
        return list(self.traversal_search(query, **kwargs))
    elif search_type == "mmr_traversal":
        return list(self.mmr_traversal_search(query, **kwargs))
    else:
        # Keep this list in sync with the branches above.
        raise ValueError(
            f"search_type of {search_type} not allowed. Expected "
            "search_type to be 'similarity', 'similarity_score_threshold', "
            "'mmr', 'traversal' or 'mmr_traversal'."
        )
245
326
246
- async def asearch (
247
- self , query : str , search_type : str , ** kwargs : Any
248
- ) -> List [Document ]:
327
+ async def asearch (self , query : str , search_type : str , ** kwargs : Any ) -> List [Document ]:
249
328
if search_type == "similarity" :
250
329
return await self .asimilarity_search (query , ** kwargs )
251
330
elif search_type == "similarity_score_threshold" :
@@ -256,7 +335,7 @@ async def asearch(
256
335
elif search_type == "mmr" :
257
336
return await self .amax_marginal_relevance_search (query , ** kwargs )
258
337
elif search_type == "traversal" :
259
- return [doc async for doc in self .atraversing_retrieve (query , ** kwargs )]
338
+ return [doc async for doc in self .atraversal_search (query , ** kwargs )]
260
339
else :
261
340
raise ValueError (
262
341
f"search_type of { search_type } not allowed. Expected "
@@ -334,15 +413,16 @@ class KnowledgeStoreRetriever(VectorStoreRetriever):
334
413
"similarity_score_threshold" ,
335
414
"mmr" ,
336
415
"traversal" ,
416
+ "mmr_traversal" ,
337
417
)
338
418
339
419
def _get_relevant_documents (
340
420
self , query : str , * , run_manager : CallbackManagerForRetrieverRun
341
421
) -> List [Document ]:
342
422
if self .search_type == "traversal" :
343
- return list (
344
- self .vectorstore . traversing_retrieve ( query , ** self . search_kwargs )
345
- )
423
+ return list (self . vectorstore . traversal_search ( query , ** self . search_kwargs ))
424
+ elif self .search_type == "mmr_traversal" :
425
+ return list ( self . vectorstore . traversal_search ( query , ** self . search_kwargs ) )
346
426
else :
347
427
return super ()._get_relevant_documents (query , run_manager = run_manager )
348
428
@@ -352,11 +432,12 @@ async def _aget_relevant_documents(
352
432
if self .search_type == "traversal" :
353
433
return [
354
434
doc
355
- async for doc in self .vectorstore .atraversing_retrieve (
356
- query , ** self .search_kwargs
357
- )
435
+ async for doc in self .vectorstore .atraversal_search (query , ** self .search_kwargs )
436
+ ]
437
+ elif self .search_type == "mmr_traversal" :
438
+ return [
439
+ doc
440
+ async for doc in self .vectorstore .ammr_traversal_search (query , ** self .search_kwargs )
358
441
]
359
442
else :
360
- return await super ()._aget_relevant_documents (
361
- query , run_manager = run_manager
362
- )
443
+ return await super ()._aget_relevant_documents (query , run_manager = run_manager )
0 commit comments