66import asyncio
77import sys
88from pathlib import Path
9-
10- import pandas as pd
9+ from typing import TYPE_CHECKING , Any
1110
1211import graphrag .api as api
1312from graphrag .config .load_config import load_config
1615from graphrag .storage .factory import StorageFactory
1716from graphrag .utils .storage import load_table_from_storage , storage_has_table
1817
18+ if TYPE_CHECKING :
19+ import pandas as pd
20+
1921logger = PrintProgressLogger ("" )
2022
2123
@@ -49,14 +51,43 @@ def run_global_search(
4951 ],
5052 optional_list = [],
5153 )
54+
55+ # Call the Multi-Index Global Search API
56+ if dataframe_dict ["multi-index" ]:
57+ final_nodes_list = dataframe_dict ["create_final_nodes" ]
58+ final_entities_list = dataframe_dict ["create_final_entities" ]
59+ final_communities_list = dataframe_dict ["create_final_communities" ]
60+ final_community_reports_list = dataframe_dict ["create_final_community_reports" ]
61+ index_names = dataframe_dict ["index_names" ]
62+
63+ response , context_data = asyncio .run (
64+ api .multi_index_global_search (
65+ config = config ,
66+ nodes_list = final_nodes_list ,
67+ entities_list = final_entities_list ,
68+ communities_list = final_communities_list ,
69+ community_reports_list = final_community_reports_list ,
70+ index_names = index_names ,
71+ community_level = community_level ,
72+ dynamic_community_selection = dynamic_community_selection ,
73+ response_type = response_type ,
74+ streaming = streaming ,
75+ query = query ,
76+ )
77+ )
78+ logger .success (f"Global Search Response:\n { response } " )
79+ # NOTE: we return the response and context data here purely as a complete demonstration of the API.
80+ # External users should use the API directly to get the response and context data.
81+ return response , context_data
82+
83+ # Otherwise, call the Single-Index Global Search API
5284 final_nodes : pd .DataFrame = dataframe_dict ["create_final_nodes" ]
5385 final_entities : pd .DataFrame = dataframe_dict ["create_final_entities" ]
5486 final_communities : pd .DataFrame = dataframe_dict ["create_final_communities" ]
5587 final_community_reports : pd .DataFrame = dataframe_dict [
5688 "create_final_community_reports"
5789 ]
5890
59- # call the Query API
6091 if streaming :
6192
6293 async def run_streaming_search ():
@@ -137,6 +168,46 @@ def run_local_search(
137168 "create_final_covariates" ,
138169 ],
139170 )
171+ # Call the Multi-Index Local Search API
172+ if dataframe_dict ["multi-index" ]:
173+ final_nodes_list = dataframe_dict ["create_final_nodes" ]
174+ final_entities_list = dataframe_dict ["create_final_entities" ]
175+ final_community_reports_list = dataframe_dict ["create_final_community_reports" ]
176+ final_text_units_list = dataframe_dict ["create_final_text_units" ]
177+ final_relationships_list = dataframe_dict ["create_final_relationships" ]
178+ index_names = dataframe_dict ["index_names" ]
179+
180+ # If any covariates tables are missing from any index, set the covariates list to None
181+ if (
182+ len (dataframe_dict ["create_final_covariates" ])
183+ != dataframe_dict ["num_indexes" ]
184+ ):
185+ final_covariates_list = None
186+ else :
187+ final_covariates_list = dataframe_dict ["create_final_covariates" ]
188+
189+ response , context_data = asyncio .run (
190+ api .multi_index_local_search (
191+ config = config ,
192+ nodes_list = final_nodes_list ,
193+ entities_list = final_entities_list ,
194+ community_reports_list = final_community_reports_list ,
195+ text_units_list = final_text_units_list ,
196+ relationships_list = final_relationships_list ,
197+ covariates_list = final_covariates_list ,
198+ index_names = index_names ,
199+ community_level = community_level ,
200+ response_type = response_type ,
201+ streaming = streaming ,
202+ query = query ,
203+ )
204+ )
205+ logger .success (f"Local Search Response:\n { response } " )
206+ # NOTE: we return the response and context data here purely as a complete demonstration of the API.
207+ # External users should use the API directly to get the response and context data.
208+ return response , context_data
209+
210+ # Otherwise, call the Single-Index Local Search API
140211 final_nodes : pd .DataFrame = dataframe_dict ["create_final_nodes" ]
141212 final_community_reports : pd .DataFrame = dataframe_dict [
142213 "create_final_community_reports"
@@ -146,7 +217,6 @@ def run_local_search(
146217 final_entities : pd .DataFrame = dataframe_dict ["create_final_entities" ]
147218 final_covariates : pd .DataFrame | None = dataframe_dict ["create_final_covariates" ]
148219
149- # call the Query API
150220 if streaming :
151221
152222 async def run_streaming_search ():
@@ -226,6 +296,37 @@ def run_drift_search(
226296 "create_final_entities" ,
227297 ],
228298 )
299+
300+ # Call the Multi-Index Drift Search API
301+ if dataframe_dict ["multi-index" ]:
302+ final_nodes_list = dataframe_dict ["create_final_nodes" ]
303+ final_entities_list = dataframe_dict ["create_final_entities" ]
304+ final_community_reports_list = dataframe_dict ["create_final_community_reports" ]
305+ final_text_units_list = dataframe_dict ["create_final_text_units" ]
306+ final_relationships_list = dataframe_dict ["create_final_relationships" ]
307+ index_names = dataframe_dict ["index_names" ]
308+
309+ response , context_data = asyncio .run (
310+ api .multi_index_drift_search (
311+ config = config ,
312+ nodes_list = final_nodes_list ,
313+ entities_list = final_entities_list ,
314+ community_reports_list = final_community_reports_list ,
315+ text_units_list = final_text_units_list ,
316+ relationships_list = final_relationships_list ,
317+ index_names = index_names ,
318+ community_level = community_level ,
319+ response_type = response_type ,
320+ streaming = streaming ,
321+ query = query ,
322+ )
323+ )
324+ logger .success (f"DRIFT Search Response:\n { response } " )
325+ # NOTE: we return the response and context data here purely as a complete demonstration of the API.
326+ # External users should use the API directly to get the response and context data.
327+ return response , context_data
328+
329+ # Otherwise, call the Single-Index Drift Search API
229330 final_nodes : pd .DataFrame = dataframe_dict ["create_final_nodes" ]
230331 final_community_reports : pd .DataFrame = dataframe_dict [
231332 "create_final_community_reports"
@@ -234,7 +335,6 @@ def run_drift_search(
234335 final_relationships : pd .DataFrame = dataframe_dict ["create_final_relationships" ]
235336 final_entities : pd .DataFrame = dataframe_dict ["create_final_entities" ]
236337
237- # call the Query API
238338 if streaming :
239339
240340 async def run_streaming_search ():
@@ -308,9 +408,29 @@ def run_basic_search(
308408 "create_final_text_units" ,
309409 ],
310410 )
411+
412+ # Call the Multi-Index Basic Search API
413+ if dataframe_dict ["multi-index" ]:
414+ final_text_units_list = dataframe_dict ["create_final_text_units" ]
415+ index_names = dataframe_dict ["index_names" ]
416+
417+ response , context_data = asyncio .run (
418+ api .multi_index_basic_search (
419+ config = config ,
420+ text_units_list = final_text_units_list ,
421+ index_names = index_names ,
422+ streaming = streaming ,
423+ query = query ,
424+ )
425+ )
426+ logger .success (f"Basic Search Response:\n { response } " )
427+ # NOTE: we return the response and context data here purely as a complete demonstration of the API.
428+ # External users should use the API directly to get the response and context data.
429+ return response , context_data
430+
431+ # Otherwise, call the Single-Index Basic Search API
311432 final_text_units : pd .DataFrame = dataframe_dict ["create_final_text_units" ]
312433
313- # # call the Query API
314434 if streaming :
315435
316436 async def run_streaming_search ():
@@ -351,9 +471,46 @@ def _resolve_output_files(
351471 config : GraphRagConfig ,
352472 output_list : list [str ],
353473 optional_list : list [str ] | None = None ,
354- ) -> dict [str , pd . DataFrame ]:
474+ ) -> dict [str , Any ]:
355475 """Read indexing output files to a dataframe dict."""
356476 dataframe_dict = {}
477+
478+ # Loading output files for multi-index search
479+ if config .outputs :
480+ dataframe_dict ["multi-index" ] = True
481+ dataframe_dict ["num_indexes" ] = len (config .outputs )
482+ dataframe_dict ["index_names" ] = config .outputs .keys ()
483+ for output in config .outputs .values ():
484+ output_config = output .model_dump ()
485+ storage_obj = StorageFactory ().create_storage (
486+ storage_type = output_config ["type" ], kwargs = output_config
487+ )
488+ for name in output_list :
489+ if name not in dataframe_dict :
490+ dataframe_dict [name ] = []
491+ df_value = asyncio .run (
492+ load_table_from_storage (name = name , storage = storage_obj )
493+ )
494+ dataframe_dict [name ].append (df_value )
495+
496+ # for optional output files, do not append if the dataframe does not exist
497+ if optional_list :
498+ for optional_file in optional_list :
499+ if optional_file not in dataframe_dict :
500+ dataframe_dict [optional_file ] = []
501+ file_exists = asyncio .run (
502+ storage_has_table (optional_file , storage_obj )
503+ )
504+ if file_exists :
505+ df_value = asyncio .run (
506+ load_table_from_storage (
507+ name = optional_file , storage = storage_obj
508+ )
509+ )
510+ dataframe_dict [optional_file ].append (df_value )
511+ return dataframe_dict
512+ # Loading output files for single-index search
513+ dataframe_dict ["multi-index" ] = False
357514 output_config = config .output .model_dump () # type: ignore
358515 storage_obj = StorageFactory ().create_storage (
359516 storage_type = output_config ["type" ], kwargs = output_config
@@ -373,5 +530,4 @@ def _resolve_output_files(
373530 dataframe_dict [optional_file ] = df_value
374531 else :
375532 dataframe_dict [optional_file ] = None
376-
377533 return dataframe_dict
0 commit comments