Skip to content

Commit 724eff9

Browse files
committed
Merge branch 'main' into joshbradley/cleanup-query-api
2 parents 164bba8 + fe46141 commit 724eff9

38 files changed

+602
-74
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "add option to prepend metadata into chunks"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "update multi-index query to support new workflows"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Export NLP community reports prompt."
4+
}

graphrag/api/query.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,12 @@ async def multi_index_global_search(
237237
raise NotImplementedError(message)
238238

239239
links = {
240-
"community": {},
240+
"communities": {},
241241
"community_reports": {},
242242
"entities": {},
243243
}
244244
max_vals = {
245-
"community": -1,
245+
"communities": -1,
246246
"community_reports": -1,
247247
"entities": -1,
248248
}
@@ -272,16 +272,20 @@ async def multi_index_global_search(
272272
communities_df["community"] = communities_df["community"].astype(int)
273273
communities_df["parent"] = communities_df["parent"].astype(int)
274274
for i in communities_df["community"]:
275-
links["community"][i + max_vals["community"] + 1] = {
275+
links["communities"][i + max_vals["communities"] + 1] = {
276276
"index_name": index_name,
277277
"id": str(i),
278278
}
279-
communities_df["community"] += max_vals["community"] + 1
279+
communities_df["community"] += max_vals["communities"] + 1
280280
communities_df["parent"] = communities_df["parent"].apply(
281-
lambda x: x if x == -1 else x + max_vals["community"] + 1
281+
lambda x: x if x == -1 else x + max_vals["communities"] + 1
282+
)
283+
communities_df["human_readable_id"] += max_vals["communities"] + 1
284+
# concat the index name to the entity_ids, since this is used for joining later
285+
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
286+
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
282287
)
283-
communities_df["human_readable_id"] += max_vals["community"] + 1
284-
max_vals["community"] = int(communities_df["community"].max())
288+
max_vals["communities"] = int(communities_df["community"].max())
285289
communities_dfs.append(communities_df)
286290

287291
# Prepare each index's entities dataframe for merging
@@ -514,13 +518,15 @@ async def multi_index_local_search(
514518

515519
links = {
516520
"community_reports": {},
521+
"communities": {},
517522
"entities": {},
518523
"text_units": {},
519524
"relationships": {},
520525
"covariates": {},
521526
}
522527
max_vals = {
523528
"community_reports": -1,
529+
"communities": -1,
524530
"entities": -1,
525531
"text_units": 0,
526532
"relationships": -1,
@@ -544,6 +550,10 @@ async def multi_index_local_search(
544550
}
545551
communities_df["community"] += max_vals["communities"] + 1
546552
communities_df["human_readable_id"] += max_vals["communities"] + 1
553+
# concat the index name to the entity_ids, since this is used for joining later
554+
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
555+
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
556+
)
547557
max_vals["communities"] = int(communities_df["community"].max())
548558
communities_dfs.append(communities_df)
549559

@@ -873,12 +883,14 @@ async def multi_index_drift_search(
873883

874884
links = {
875885
"community_reports": {},
886+
"communities": {},
876887
"entities": {},
877888
"text_units": {},
878889
"relationships": {},
879890
}
880891
max_vals = {
881892
"community_reports": -1,
893+
"communities": -1,
882894
"entities": -1,
883895
"text_units": 0,
884896
"relationships": -1,
@@ -901,6 +913,10 @@ async def multi_index_drift_search(
901913
}
902914
communities_df["community"] += max_vals["communities"] + 1
903915
communities_df["human_readable_id"] += max_vals["communities"] + 1
916+
# concat the index name to the entity_ids, since this is used for joining later
917+
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
918+
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
919+
)
904920
max_vals["communities"] = int(communities_df["community"].max())
905921
communities_dfs.append(communities_df)
906922

graphrag/cli/initialize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from graphrag.prompts.index.community_report import (
1111
COMMUNITY_REPORT_PROMPT,
1212
)
13+
from graphrag.prompts.index.community_report_text_units import (
14+
COMMUNITY_REPORT_TEXT_PROMPT,
15+
)
1316
from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
1417
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
1518
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
@@ -72,7 +75,8 @@ def initialize_project_at(path: Path, force: bool) -> None:
7275
"extract_graph": GRAPH_EXTRACTION_PROMPT,
7376
"summarize_descriptions": SUMMARIZE_PROMPT,
7477
"extract_claims": EXTRACT_CLAIMS_PROMPT,
75-
"community_report": COMMUNITY_REPORT_PROMPT,
78+
"community_report_graph": COMMUNITY_REPORT_PROMPT,
79+
"community_report_text": COMMUNITY_REPORT_TEXT_PROMPT,
7680
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
7781
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
7882
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,

graphrag/cli/query.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def run_global_search(
5858
final_community_reports_list = dataframe_dict["community_reports"]
5959
index_names = dataframe_dict["index_names"]
6060

61+
logger.success(
62+
f"Running Multi-index Global Search: {dataframe_dict['index_names']}"
63+
)
64+
6165
response, context_data = asyncio.run(
6266
api.multi_index_global_search(
6367
config=config,
@@ -169,6 +173,10 @@ def run_local_search(
169173
final_relationships_list = dataframe_dict["relationships"]
170174
index_names = dataframe_dict["index_names"]
171175

176+
logger.success(
177+
f"Running Multi-index Local Search: {dataframe_dict['index_names']}"
178+
)
179+
172180
# If any covariates tables are missing from any index, set the covariates list to None
173181
if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]:
174182
final_covariates_list = None
@@ -293,6 +301,10 @@ def run_drift_search(
293301
final_relationships_list = dataframe_dict["relationships"]
294302
index_names = dataframe_dict["index_names"]
295303

304+
logger.success(
305+
f"Running Multi-index Drift Search: {dataframe_dict['index_names']}"
306+
)
307+
296308
response, context_data = asyncio.run(
297309
api.multi_index_drift_search(
298310
config=config,
@@ -399,6 +411,10 @@ def run_basic_search(
399411
final_text_units_list = dataframe_dict["text_units"]
400412
index_names = dataframe_dict["index_names"]
401413

414+
logger.success(
415+
f"Running Multi-index Basic Search: {dataframe_dict['index_names']}"
416+
)
417+
402418
response, context_data = asyncio.run(
403419
api.multi_index_basic_search(
404420
config=config,

graphrag/config/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
CHUNK_OVERLAP = 100
6868
CHUNK_GROUP_BY_COLUMNS = ["id"]
6969
CHUNK_STRATEGY = ChunkStrategyType.tokens
70+
CHUNK_PREPEND_METADATA = False
71+
CHUNK_SIZE_INCLUDES_METADATA = False
7072

7173
# Claim extraction
7274
DESCRIPTION = "Any claims or facts that could be relevant to information discovery."

graphrag/config/init_content.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@
114114
115115
community_reports:
116116
model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
117-
prompt: "prompts/community_report.txt"
117+
graph_prompt: "prompts/community_report_graph.txt"
118+
text_prompt: "prompts/community_report_text.txt"
118119
max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
119120
max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
120121

graphrag/config/models/chunking_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,11 @@ class ChunkingConfig(BaseModel):
2626
encoding_model: str = Field(
2727
description="The encoding model to use.", default=defs.ENCODING_MODEL
2828
)
29+
prepend_metadata: bool = Field(
30+
description="Prepend metadata into each chunk.",
31+
default=defs.CHUNK_PREPEND_METADATA,
32+
)
33+
chunk_size_includes_metadata: bool = Field(
34+
description="Count metadata in max tokens.",
35+
default=defs.CHUNK_SIZE_INCLUDES_METADATA,
36+
)

graphrag/config/models/community_reports_config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
class CommunityReportsConfig(BaseModel):
1515
"""Configuration section for community reports."""
1616

17-
prompt: str | None = Field(
18-
description="The community report extraction prompt to use.", default=None
17+
graph_prompt: str | None = Field(
18+
description="The community report extraction prompt to use for graph-based summarization.",
19+
default=None,
20+
)
21+
text_prompt: str | None = Field(
22+
description="The community report extraction prompt to use for text-based summarization.",
23+
default=None,
1924
)
2025
max_length: int = Field(
2126
description="The community report maximum length in tokens.",
@@ -46,10 +51,15 @@ def resolved_strategy(
4651
"llm": model_config.model_dump(),
4752
"stagger": model_config.parallelization_stagger,
4853
"num_threads": model_config.parallelization_num_threads,
49-
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
54+
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
55+
encoding="utf-8"
56+
)
57+
if self.graph_prompt
58+
else None,
59+
"text_prompt": (Path(root_dir) / self.text_prompt).read_text(
5060
encoding="utf-8"
5161
)
52-
if self.prompt
62+
if self.text_prompt
5363
else None,
5464
"max_report_length": self.max_length,
5565
"max_input_length": self.max_input_length,

0 commit comments

Comments
 (0)