@@ -348,6 +348,87 @@ async def local_search_streaming(
348348 yield stream_chunk
349349
350350
351+ @validate_call (config = {"arbitrary_types_allowed" : True })
352+ async def drift_search_streaming (
353+ config : GraphRagConfig ,
354+ nodes : pd .DataFrame ,
355+ entities : pd .DataFrame ,
356+ community_reports : pd .DataFrame ,
357+ text_units : pd .DataFrame ,
358+ relationships : pd .DataFrame ,
359+ community_level : int ,
360+ response_type : str ,
361+ query : str ,
362+ ) -> AsyncGenerator :
363+ """Perform a DRIFT search and return the context data and response.
364+
365+ Parameters
366+ ----------
367+ - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
368+ - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
369+ - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
370+ - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
371+ - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
372+ - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
373+ - community_level (int): The community level to search at.
374+ - query (str): The user query to search for.
375+
376+ Returns
377+ -------
378+ TODO: Document the search response type and format.
379+
380+ Raises
381+ ------
382+ TODO: Document any exceptions to expect.
383+ """
384+ vector_store_args = config .embeddings .vector_store
385+ logger .info (f"Vector Store Args: { redact (vector_store_args )} " ) # type: ignore # noqa
386+
387+ description_embedding_store = _get_embedding_store (
388+ config_args = vector_store_args , # type: ignore
389+ embedding_name = entity_description_embedding ,
390+ )
391+
392+ full_content_embedding_store = _get_embedding_store (
393+ config_args = vector_store_args , # type: ignore
394+ embedding_name = community_full_content_embedding ,
395+ )
396+
397+ entities_ = read_indexer_entities (nodes , entities , community_level )
398+ reports = read_indexer_reports (community_reports , nodes , community_level )
399+ read_indexer_report_embeddings (reports , full_content_embedding_store )
400+ prompt = _load_search_prompt (config .root_dir , config .drift_search .prompt )
401+ reduce_prompt = _load_search_prompt (
402+ config .root_dir , config .drift_search .reduce_prompt
403+ )
404+
405+ search_engine = get_drift_search_engine (
406+ config = config ,
407+ reports = reports ,
408+ text_units = read_indexer_text_units (text_units ),
409+ entities = entities_ ,
410+ relationships = read_indexer_relationships (relationships ),
411+ description_embedding_store = description_embedding_store , # type: ignore
412+ local_system_prompt = prompt ,
413+ reduce_system_prompt = reduce_prompt ,
414+ response_type = response_type ,
415+ )
416+
417+ search_result = search_engine .astream_search (query = query )
418+
419+ # when streaming results, a context data object is returned as the first result
420+ # and the query response in subsequent tokens
421+ context_data = None
422+ get_context_data = True
423+ async for stream_chunk in search_result :
424+ if get_context_data :
425+ context_data = _reformat_context_data (stream_chunk ) # type: ignore
426+ yield context_data
427+ get_context_data = False
428+ else :
429+ yield stream_chunk
430+
431+
351432@validate_call (config = {"arbitrary_types_allowed" : True })
352433async def drift_search (
353434 config : GraphRagConfig ,
@@ -401,7 +482,9 @@ async def drift_search(
401482 reports = read_indexer_reports (community_reports , nodes , community_level )
402483 read_indexer_report_embeddings (reports , full_content_embedding_store )
403484 prompt = _load_search_prompt (config .root_dir , config .drift_search .prompt )
404- reduce_prompt = _load_search_prompt (config .root_dir , config .drift_search .reduce_prompt )
485+ reduce_prompt = _load_search_prompt (
486+ config .root_dir , config .drift_search .reduce_prompt
487+ )
405488
406489 search_engine = get_drift_search_engine (
407490 config = config ,
@@ -419,15 +502,7 @@ async def drift_search(
419502 response = result .response
420503 context_data = _reformat_context_data (result .context_data ) # type: ignore
421504
422- # TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
423- # For the time being, return highest scoring response (position 0) and context data
424- match response :
425- case dict ():
426- return response ["nodes" ][0 ]["answer" ], context_data # type: ignore
427- case str ():
428- return response , context_data
429- case list ():
430- return response , context_data
505+ return response , context_data
431506
432507
433508@validate_call (config = {"arbitrary_types_allowed" : True })
0 commit comments