Skip to content

Commit 83cc2da

Browse files
Multi-index query CLI support (#1675)
* Add vector store id reference to embeddings config.
* Changed structure of output config section.
* Added CLI integration for multi-index global search.
* Added CLI integration for multi-index local search.
* Added CLI integration for multi-index DRIFT and basic search.
* Finished local testing of multi-index CLI.
* Ruff fixes.
* Partially refactored test code to align with new output section.
* More test changes for new output structure.
* Semversioner.
* Refactored to align with new multi-index config proposal.
* Locally tested new multi-index output proposal.
* Cleaned up tests to align with new structure.

---------

Co-authored-by: Derek Worthen <[email protected]>
1 parent 0805924 commit 83cc2da

File tree

8 files changed

+227
-42
lines changed

8 files changed

+227
-42
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "multi index query cli support"
4+
}

graphrag/api/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
global_search_streaming,
1919
local_search,
2020
local_search_streaming,
21+
multi_index_basic_search,
22+
multi_index_drift_search,
23+
multi_index_global_search,
24+
multi_index_local_search,
2125
)
2226
from graphrag.prompt_tune.types import DocSelectionType
2327

@@ -33,6 +37,10 @@
3337
"drift_search_streaming",
3438
"basic_search",
3539
"basic_search_streaming",
40+
"multi_index_basic_search",
41+
"multi_index_drift_search",
42+
"multi_index_global_search",
43+
"multi_index_local_search",
3644
# prompt tuning API
3745
"DocSelectionType",
3846
"generate_indexing_prompts",

graphrag/api/query.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,10 @@ async def multi_index_global_search(
220220
response_type: str,
221221
streaming: bool,
222222
query: str,
223-
) -> (
224-
tuple[
225-
str | dict[str, Any] | list[dict[str, Any]],
226-
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
227-
]
228-
| AsyncGenerator
229-
):
223+
) -> tuple[
224+
str | dict[str, Any] | list[dict[str, Any]],
225+
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
226+
]:
230227
"""Perform a global search across multiple indexes and return the context data and response.
231228
232229
Parameters
@@ -422,7 +419,6 @@ async def local_search(
422419
entities_ = read_indexer_entities(nodes, entities, community_level)
423420
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
424421
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
425-
426422
search_engine = get_local_search_engine(
427423
config=config,
428424
reports=read_indexer_reports(community_reports, nodes, community_level),
@@ -531,13 +527,10 @@ async def multi_index_local_search(
531527
response_type: str,
532528
streaming: bool,
533529
query: str,
534-
) -> (
535-
tuple[
536-
str | dict[str, Any] | list[dict[str, Any]],
537-
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
538-
]
539-
| AsyncGenerator
540-
):
530+
) -> tuple[
531+
str | dict[str, Any] | list[dict[str, Any]],
532+
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
533+
]:
541534
"""Perform a local search across multiple indexes and return the context data and response.
542535
543536
Parameters
@@ -584,7 +577,6 @@ async def multi_index_local_search(
584577
"relationships": -1,
585578
"covariates": 0,
586579
}
587-
588580
community_reports_dfs = []
589581
entities_dfs = []
590582
nodes_dfs = []
@@ -732,7 +724,6 @@ async def multi_index_local_search(
732724
covariates_combined = pd.concat(
733725
covariates_dfs, axis=0, ignore_index=True, sort=False
734726
)
735-
736727
result = await local_search(
737728
config,
738729
nodes=nodes_combined,
@@ -927,13 +918,10 @@ async def multi_index_drift_search(
927918
response_type: str,
928919
streaming: bool,
929920
query: str,
930-
) -> (
931-
tuple[
932-
str | dict[str, Any] | list[dict[str, Any]],
933-
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
934-
]
935-
| AsyncGenerator
936-
):
921+
) -> tuple[
922+
str | dict[str, Any] | list[dict[str, Any]],
923+
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
924+
]:
937925
"""Perform a DRIFT search across multiple indexes and return the context data and response.
938926
939927
Parameters
@@ -1240,13 +1228,10 @@ async def multi_index_basic_search(
12401228
index_names: list[str],
12411229
streaming: bool,
12421230
query: str,
1243-
) -> (
1244-
tuple[
1245-
str | dict[str, Any] | list[dict[str, Any]],
1246-
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
1247-
]
1248-
| AsyncGenerator
1249-
):
1231+
) -> tuple[
1232+
str | dict[str, Any] | list[dict[str, Any]],
1233+
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
1234+
]:
12501235
"""Perform a basic search across multiple indexes and return the context data and response.
12511236
12521237
Parameters

graphrag/cli/query.py

Lines changed: 164 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import asyncio
77
import sys
88
from pathlib import Path
9-
10-
import pandas as pd
9+
from typing import TYPE_CHECKING, Any
1110

1211
import graphrag.api as api
1312
from graphrag.config.load_config import load_config
@@ -16,6 +15,9 @@
1615
from graphrag.storage.factory import StorageFactory
1716
from graphrag.utils.storage import load_table_from_storage, storage_has_table
1817

18+
if TYPE_CHECKING:
19+
import pandas as pd
20+
1921
logger = PrintProgressLogger("")
2022

2123

@@ -49,14 +51,43 @@ def run_global_search(
4951
],
5052
optional_list=[],
5153
)
54+
55+
# Call the Multi-Index Global Search API
56+
if dataframe_dict["multi-index"]:
57+
final_nodes_list = dataframe_dict["create_final_nodes"]
58+
final_entities_list = dataframe_dict["create_final_entities"]
59+
final_communities_list = dataframe_dict["create_final_communities"]
60+
final_community_reports_list = dataframe_dict["create_final_community_reports"]
61+
index_names = dataframe_dict["index_names"]
62+
63+
response, context_data = asyncio.run(
64+
api.multi_index_global_search(
65+
config=config,
66+
nodes_list=final_nodes_list,
67+
entities_list=final_entities_list,
68+
communities_list=final_communities_list,
69+
community_reports_list=final_community_reports_list,
70+
index_names=index_names,
71+
community_level=community_level,
72+
dynamic_community_selection=dynamic_community_selection,
73+
response_type=response_type,
74+
streaming=streaming,
75+
query=query,
76+
)
77+
)
78+
logger.success(f"Global Search Response:\n{response}")
79+
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
80+
# External users should use the API directly to get the response and context data.
81+
return response, context_data
82+
83+
# Otherwise, call the Single-Index Global Search API
5284
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
5385
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
5486
final_communities: pd.DataFrame = dataframe_dict["create_final_communities"]
5587
final_community_reports: pd.DataFrame = dataframe_dict[
5688
"create_final_community_reports"
5789
]
5890

59-
# call the Query API
6091
if streaming:
6192

6293
async def run_streaming_search():
@@ -137,6 +168,46 @@ def run_local_search(
137168
"create_final_covariates",
138169
],
139170
)
171+
# Call the Multi-Index Local Search API
172+
if dataframe_dict["multi-index"]:
173+
final_nodes_list = dataframe_dict["create_final_nodes"]
174+
final_entities_list = dataframe_dict["create_final_entities"]
175+
final_community_reports_list = dataframe_dict["create_final_community_reports"]
176+
final_text_units_list = dataframe_dict["create_final_text_units"]
177+
final_relationships_list = dataframe_dict["create_final_relationships"]
178+
index_names = dataframe_dict["index_names"]
179+
180+
# If any covariates tables are missing from any index, set the covariates list to None
181+
if (
182+
len(dataframe_dict["create_final_covariates"])
183+
!= dataframe_dict["num_indexes"]
184+
):
185+
final_covariates_list = None
186+
else:
187+
final_covariates_list = dataframe_dict["create_final_covariates"]
188+
189+
response, context_data = asyncio.run(
190+
api.multi_index_local_search(
191+
config=config,
192+
nodes_list=final_nodes_list,
193+
entities_list=final_entities_list,
194+
community_reports_list=final_community_reports_list,
195+
text_units_list=final_text_units_list,
196+
relationships_list=final_relationships_list,
197+
covariates_list=final_covariates_list,
198+
index_names=index_names,
199+
community_level=community_level,
200+
response_type=response_type,
201+
streaming=streaming,
202+
query=query,
203+
)
204+
)
205+
logger.success(f"Local Search Response:\n{response}")
206+
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
207+
# External users should use the API directly to get the response and context data.
208+
return response, context_data
209+
210+
# Otherwise, call the Single-Index Local Search API
140211
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
141212
final_community_reports: pd.DataFrame = dataframe_dict[
142213
"create_final_community_reports"
@@ -146,7 +217,6 @@ def run_local_search(
146217
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
147218
final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]
148219

149-
# call the Query API
150220
if streaming:
151221

152222
async def run_streaming_search():
@@ -226,6 +296,37 @@ def run_drift_search(
226296
"create_final_entities",
227297
],
228298
)
299+
300+
# Call the Multi-Index Drift Search API
301+
if dataframe_dict["multi-index"]:
302+
final_nodes_list = dataframe_dict["create_final_nodes"]
303+
final_entities_list = dataframe_dict["create_final_entities"]
304+
final_community_reports_list = dataframe_dict["create_final_community_reports"]
305+
final_text_units_list = dataframe_dict["create_final_text_units"]
306+
final_relationships_list = dataframe_dict["create_final_relationships"]
307+
index_names = dataframe_dict["index_names"]
308+
309+
response, context_data = asyncio.run(
310+
api.multi_index_drift_search(
311+
config=config,
312+
nodes_list=final_nodes_list,
313+
entities_list=final_entities_list,
314+
community_reports_list=final_community_reports_list,
315+
text_units_list=final_text_units_list,
316+
relationships_list=final_relationships_list,
317+
index_names=index_names,
318+
community_level=community_level,
319+
response_type=response_type,
320+
streaming=streaming,
321+
query=query,
322+
)
323+
)
324+
logger.success(f"DRIFT Search Response:\n{response}")
325+
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
326+
# External users should use the API directly to get the response and context data.
327+
return response, context_data
328+
329+
# Otherwise, call the Single-Index Drift Search API
229330
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
230331
final_community_reports: pd.DataFrame = dataframe_dict[
231332
"create_final_community_reports"
@@ -234,7 +335,6 @@ def run_drift_search(
234335
final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
235336
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
236337

237-
# call the Query API
238338
if streaming:
239339

240340
async def run_streaming_search():
@@ -308,9 +408,29 @@ def run_basic_search(
308408
"create_final_text_units",
309409
],
310410
)
411+
412+
# Call the Multi-Index Basic Search API
413+
if dataframe_dict["multi-index"]:
414+
final_text_units_list = dataframe_dict["create_final_text_units"]
415+
index_names = dataframe_dict["index_names"]
416+
417+
response, context_data = asyncio.run(
418+
api.multi_index_basic_search(
419+
config=config,
420+
text_units_list=final_text_units_list,
421+
index_names=index_names,
422+
streaming=streaming,
423+
query=query,
424+
)
425+
)
426+
logger.success(f"Basic Search Response:\n{response}")
427+
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
428+
# External users should use the API directly to get the response and context data.
429+
return response, context_data
430+
431+
# Otherwise, call the Single-Index Basic Search API
311432
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
312433

313-
# # call the Query API
314434
if streaming:
315435

316436
async def run_streaming_search():
@@ -351,9 +471,46 @@ def _resolve_output_files(
351471
config: GraphRagConfig,
352472
output_list: list[str],
353473
optional_list: list[str] | None = None,
354-
) -> dict[str, pd.DataFrame]:
474+
) -> dict[str, Any]:
355475
"""Read indexing output files to a dataframe dict."""
356476
dataframe_dict = {}
477+
478+
# Loading output files for multi-index search
479+
if config.outputs:
480+
dataframe_dict["multi-index"] = True
481+
dataframe_dict["num_indexes"] = len(config.outputs)
482+
dataframe_dict["index_names"] = config.outputs.keys()
483+
for output in config.outputs.values():
484+
output_config = output.model_dump()
485+
storage_obj = StorageFactory().create_storage(
486+
storage_type=output_config["type"], kwargs=output_config
487+
)
488+
for name in output_list:
489+
if name not in dataframe_dict:
490+
dataframe_dict[name] = []
491+
df_value = asyncio.run(
492+
load_table_from_storage(name=name, storage=storage_obj)
493+
)
494+
dataframe_dict[name].append(df_value)
495+
496+
# for optional output files, do not append if the dataframe does not exist
497+
if optional_list:
498+
for optional_file in optional_list:
499+
if optional_file not in dataframe_dict:
500+
dataframe_dict[optional_file] = []
501+
file_exists = asyncio.run(
502+
storage_has_table(optional_file, storage_obj)
503+
)
504+
if file_exists:
505+
df_value = asyncio.run(
506+
load_table_from_storage(
507+
name=optional_file, storage=storage_obj
508+
)
509+
)
510+
dataframe_dict[optional_file].append(df_value)
511+
return dataframe_dict
512+
# Loading output files for single-index search
513+
dataframe_dict["multi-index"] = False
357514
output_config = config.output.model_dump() # type: ignore
358515
storage_obj = StorageFactory().create_storage(
359516
storage_type=output_config["type"], kwargs=output_config
@@ -373,5 +530,4 @@ def _resolve_output_files(
373530
dataframe_dict[optional_file] = df_value
374531
else:
375532
dataframe_dict[optional_file] = None
376-
377533
return dataframe_dict

graphrag/config/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
SNAPSHOTS_EMBEDDINGS = False
133133
SNAPSHOTS_TRANSIENT = False
134134
OUTPUT_BASE_DIR = "output"
135+
OUTPUT_DEFAULT_ID = "default_output"
135136
OUTPUT_TYPE = OutputType.file
136137
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
137138
SUMMARIZE_MODEL_ID = DEFAULT_CHAT_MODEL_ID

0 commit comments

Comments
 (0)