Skip to content

Commit 1542b27

Browse files
committed
Add streaming endpoint
1 parent d148589 commit 1542b27

File tree

8 files changed

+723
-600
lines changed

8 files changed

+723
-600
lines changed

graphrag/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
basic_search,
1414
basic_search_streaming,
1515
drift_search,
16+
drift_search_streaming,
1617
global_search,
1718
global_search_streaming,
1819
local_search,
@@ -29,6 +30,7 @@
2930
"local_search",
3031
"local_search_streaming",
3132
"drift_search",
33+
"drift_search_streaming",
3234
"basic_search",
3335
"basic_search_streaming",
3436
# prompt tuning API

graphrag/api/query.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,87 @@ async def local_search_streaming(
348348
yield stream_chunk
349349

350350

351+
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a DRIFT search and stream the context data followed by the response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): Forwarded unchanged to the DRIFT search engine; presumably the
      desired response length/format (e.g. "Multiple paragraphs") - confirm against the prompt.
    - query (str): The user query to search for.

    Yields
    ------
    The first item is the first chunk produced by the search engine, passed through
    `_reformat_context_data` (the context data). Every later item is a chunk of the
    query response, yielded exactly as the search engine produced it.

    Raises
    ------
    TODO: Document any exceptions to expect. Nothing is caught here, so errors raised by
    the embedding-store, prompt-loading or search-engine helpers propagate to the caller.
    """
    vector_store_args = config.embeddings.vector_store
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = _get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=entity_description_embedding,
    )

    full_content_embedding_store = _get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=community_full_content_embedding,
    )

    entities_ = read_indexer_entities(nodes, entities, community_level)
    reports = read_indexer_reports(community_reports, nodes, community_level)
    # Called for its effect on `reports`; the return value is not used here.
    read_indexer_report_embeddings(reports, full_content_embedding_store)
    prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
    reduce_prompt = _load_search_prompt(
        config.root_dir, config.drift_search.reduce_prompt
    )

    search_engine = get_drift_search_engine(
        config=config,
        reports=reports,
        text_units=read_indexer_text_units(text_units),
        entities=entities_,
        relationships=read_indexer_relationships(relationships),
        description_embedding_store=description_embedding_store,  # type: ignore
        local_system_prompt=prompt,
        reduce_system_prompt=reduce_prompt,
        response_type=response_type,
    )

    # When streaming, the engine emits the context data object as the first chunk
    # and the query response in all subsequent chunks.
    is_first_chunk = True
    async for stream_chunk in search_engine.astream_search(query=query):
        if is_first_chunk:
            is_first_chunk = False
            yield _reformat_context_data(stream_chunk)  # type: ignore
        else:
            yield stream_chunk
430+
431+
351432
@validate_call(config={"arbitrary_types_allowed": True})
352433
async def drift_search(
353434
config: GraphRagConfig,
@@ -401,7 +482,9 @@ async def drift_search(
401482
reports = read_indexer_reports(community_reports, nodes, community_level)
402483
read_indexer_report_embeddings(reports, full_content_embedding_store)
403484
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
404-
reduce_prompt = _load_search_prompt(config.root_dir, config.drift_search.reduce_prompt)
485+
reduce_prompt = _load_search_prompt(
486+
config.root_dir, config.drift_search.reduce_prompt
487+
)
405488

406489
search_engine = get_drift_search_engine(
407490
config=config,
@@ -419,15 +502,7 @@ async def drift_search(
419502
response = result.response
420503
context_data = _reformat_context_data(result.context_data) # type: ignore
421504

422-
# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
423-
# For the time being, return highest scoring response (position 0) and context data
424-
match response:
425-
case dict():
426-
return response["nodes"][0]["answer"], context_data # type: ignore
427-
case str():
428-
return response, context_data
429-
case list():
430-
return response, context_data
505+
return response, context_data
431506

432507

433508
@validate_call(config={"arbitrary_types_allowed": True})

graphrag/cli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def _query_cli(
460460
data_dir=data,
461461
root_dir=root,
462462
community_level=community_level,
463-
streaming=False, # Drift search does not support streaming (yet)
463+
streaming=streaming,
464464
response_type=response_type,
465465
query=query,
466466
)

graphrag/cli/query.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,33 @@ def run_drift_search(
235235

236236
# call the Query API
237237
if streaming:
238-
error_msg = "Streaming is not supported yet for DRIFT search."
239-
raise NotImplementedError(error_msg)
238+
239+
async def run_streaming_search():
    """Stream a DRIFT search to stdout and return the collected result.

    Consumes `api.drift_search_streaming`, treating the first streamed item as the
    context data and every later item as a piece of the response text. Each response
    piece is printed as soon as it arrives.

    Returns
    -------
    A `(full_response, context_data)` tuple: the concatenation of all response
    pieces, and the first streamed item (None if nothing was streamed).
    """
    # Collect pieces and join once at the end instead of repeated string `+=`.
    response_parts = []
    context_data = None
    awaiting_context = True
    async for stream_chunk in api.drift_search_streaming(
        config=config,
        nodes=final_nodes,
        entities=final_entities,
        community_reports=final_community_reports,
        text_units=final_text_units,
        relationships=final_relationships,
        community_level=community_level,
        response_type=response_type,
        query=query,
    ):
        if awaiting_context:
            # The first streamed item is the context data, not response text.
            context_data = stream_chunk
            awaiting_context = False
            continue
        response_parts.append(stream_chunk)
        # flush=True pushes each piece to the terminal immediately.
        print(stream_chunk, end="", flush=True)  # noqa: T201
    print()  # noqa: T201
    return "".join(response_parts), context_data
263+
264+
return asyncio.run(run_streaming_search())
240265

241266
# not streaming
242267
response, context_data = asyncio.run(
@@ -283,8 +308,6 @@ def run_basic_search(
283308
)
284309
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
285310

286-
print(streaming) # noqa: T201
287-
288311
# call the Query API
289312
if streaming:
290313

graphrag/prompts/query/drift_search_system_prompt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
107107
---Target response length and format---
108108
109-
Multiple paragraphs
109+
{response_type}
110110
111111
112112
---Goal---
@@ -133,8 +133,6 @@
133133
134134
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
135135
136-
{query}
137-
138136
"""
139137

140138

graphrag/query/structured_search/drift_search/drift_context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from graphrag.model.relationship import Relationship
1919
from graphrag.model.text_unit import TextUnit
2020
from graphrag.prompts.query.drift_search_system_prompt import (
21-
DRIFT_LOCAL_SYSTEM_PROMPT, DRIFT_REDUCE_PROMPT
21+
DRIFT_LOCAL_SYSTEM_PROMPT,
22+
DRIFT_REDUCE_PROMPT,
2223
)
2324
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
2425
from graphrag.query.llm.base import BaseTextEmbedding
@@ -52,7 +53,7 @@ def __init__(
5253
local_system_prompt: str | None = None,
5354
local_mixed_context: LocalSearchMixedContext | None = None,
5455
reduce_system_prompt: str | None = None,
55-
response_type: str | None = None
56+
response_type: str | None = None,
5657
):
5758
"""Initialize the DRIFT search context builder with necessary components."""
5859
self.config = config or DRIFTSearchConfig()

0 commit comments

Comments
 (0)