Skip to content

Commit 17fad44

Browse files
authored
Merge branch 'main' into joshbradley/improve-llm-retry-logic
2 parents 1172e86 + fe46141 commit 17fad44

35 files changed

+559
-67
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "add option to prepend metadata into chunks"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Export NLP community reports prompt."
4+
}

graphrag/cli/initialize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from graphrag.prompts.index.community_report import (
1111
COMMUNITY_REPORT_PROMPT,
1212
)
13+
from graphrag.prompts.index.community_report_text_units import (
14+
COMMUNITY_REPORT_TEXT_PROMPT,
15+
)
1316
from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
1417
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
1518
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
@@ -72,7 +75,8 @@ def initialize_project_at(path: Path, force: bool) -> None:
7275
"extract_graph": GRAPH_EXTRACTION_PROMPT,
7376
"summarize_descriptions": SUMMARIZE_PROMPT,
7477
"extract_claims": EXTRACT_CLAIMS_PROMPT,
75-
"community_report": COMMUNITY_REPORT_PROMPT,
78+
"community_report_graph": COMMUNITY_REPORT_PROMPT,
79+
"community_report_text": COMMUNITY_REPORT_TEXT_PROMPT,
7680
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
7781
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
7882
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,

graphrag/config/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
CHUNK_OVERLAP = 100
6565
CHUNK_GROUP_BY_COLUMNS = ["id"]
6666
CHUNK_STRATEGY = ChunkStrategyType.tokens
67+
CHUNK_PREPEND_METADATA = False
68+
CHUNK_SIZE_INCLUDES_METADATA = False
6769

6870
# Claim extraction
6971
DESCRIPTION = "Any claims or facts that could be relevant to information discovery."

graphrag/config/init_content.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@
123123
124124
community_reports:
125125
model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
126-
prompt: "prompts/community_report.txt"
126+
graph_prompt: "prompts/community_report_graph.txt"
127+
text_prompt: "prompts/community_report_text.txt"
127128
max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
128129
max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
129130

graphrag/config/models/chunking_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,11 @@ class ChunkingConfig(BaseModel):
2626
encoding_model: str = Field(
2727
description="The encoding model to use.", default=defs.ENCODING_MODEL
2828
)
29+
prepend_metadata: bool = Field(
30+
description="Prepend metadata into each chunk.",
31+
default=defs.CHUNK_PREPEND_METADATA,
32+
)
33+
chunk_size_includes_metadata: bool = Field(
34+
description="Count metadata in max tokens.",
35+
default=defs.CHUNK_SIZE_INCLUDES_METADATA,
36+
)

graphrag/config/models/community_reports_config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
class CommunityReportsConfig(BaseModel):
1515
"""Configuration section for community reports."""
1616

17-
prompt: str | None = Field(
18-
description="The community report extraction prompt to use.", default=None
17+
graph_prompt: str | None = Field(
18+
description="The community report extraction prompt to use for graph-based summarization.",
19+
default=None,
20+
)
21+
text_prompt: str | None = Field(
22+
description="The community report extraction prompt to use for text-based summarization.",
23+
default=None,
1924
)
2025
max_length: int = Field(
2126
description="The community report maximum length in tokens.",
@@ -45,10 +50,15 @@ def resolved_strategy(
4550
"type": CreateCommunityReportsStrategyType.graph_intelligence,
4651
"llm": model_config.model_dump(),
4752
"num_threads": model_config.concurrent_requests,
48-
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
53+
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
54+
encoding="utf-8"
55+
)
56+
if self.graph_prompt
57+
else None,
58+
"text_prompt": (Path(root_dir) / self.text_prompt).read_text(
4959
encoding="utf-8"
5060
)
51-
if self.prompt
61+
if self.text_prompt
5262
else None,
5363
"max_report_length": self.max_length,
5464
"max_input_length": self.max_input_length,

graphrag/index/flows/create_base_text_units.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
"""All the steps to transform base text_units."""
55

6-
from typing import cast
6+
import json
7+
from typing import Any, cast
78

89
import pandas as pd
910

1011
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1112
from graphrag.config.models.chunking_config import ChunkStrategyType
1213
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
14+
from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
1315
from graphrag.index.utils.hashing import gen_sha512_hash
1416
from graphrag.logger.progress import Progress
1517

@@ -22,6 +24,8 @@ def create_base_text_units(
2224
overlap: int,
2325
encoding_model: str,
2426
strategy: ChunkStrategyType,
27+
prepend_metadata: bool = False,
28+
chunk_size_includes_metadata: bool = False,
2529
) -> pd.DataFrame:
2630
"""All the steps to transform base text_units."""
2731
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -32,25 +36,66 @@ def create_base_text_units(
3236

3337
callbacks.progress(Progress(percent=0))
3438

39+
agg_dict = {"text_with_ids": list}
40+
if "metadata" in documents:
41+
agg_dict["metadata"] = "first" # type: ignore
42+
3543
aggregated = (
3644
(
3745
sort.groupby(group_by_columns, sort=False)
3846
if len(group_by_columns) > 0
3947
else sort.groupby(lambda _x: True)
4048
)
41-
.agg(texts=("text_with_ids", list))
49+
.agg(agg_dict)
4250
.reset_index()
4351
)
52+
aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)
4453

45-
aggregated["chunks"] = chunk_text(
46-
aggregated,
47-
column="texts",
48-
size=size,
49-
overlap=overlap,
50-
encoding_model=encoding_model,
51-
strategy=strategy,
52-
callbacks=callbacks,
53-
)
54+
def chunker(row: dict[str, Any]) -> Any:
55+
line_delimiter = ".\n"
56+
metadata_str = ""
57+
metadata_tokens = 0
58+
59+
if prepend_metadata and "metadata" in row:
60+
metadata = row["metadata"]
61+
if isinstance(metadata, str):
62+
metadata = json.loads(metadata)
63+
if isinstance(metadata, dict):
64+
metadata_str = (
65+
line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
66+
+ line_delimiter
67+
)
68+
69+
if chunk_size_includes_metadata:
70+
encode, _ = get_encoding_fn(encoding_model)
71+
metadata_tokens = len(encode(metadata_str))
72+
if metadata_tokens >= size:
73+
message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
74+
raise ValueError(message)
75+
76+
chunked = chunk_text(
77+
pd.DataFrame([row]).reset_index(drop=True),
78+
column="texts",
79+
size=size - metadata_tokens,
80+
overlap=overlap,
81+
encoding_model=encoding_model,
82+
strategy=strategy,
83+
callbacks=callbacks,
84+
)[0]
85+
86+
if prepend_metadata:
87+
for index, chunk in enumerate(chunked):
88+
if isinstance(chunk, str):
89+
chunked[index] = metadata_str + chunk
90+
else:
91+
chunked[index] = (
92+
(chunk[0], metadata_str + chunk[1], chunk[2]) if chunk else None
93+
)
94+
95+
row["chunks"] = chunked
96+
return row
97+
98+
aggregated = aggregated.apply(lambda row: chunker(row), axis=1)
5499

55100
aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
56101
aggregated = aggregated.explode("chunks")

graphrag/index/flows/create_community_reports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ async def create_community_reports(
4646
if claims_input is not None:
4747
claims = _prep_claims(claims_input)
4848

49+
summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
50+
4951
max_input_length = summarization_strategy.get(
5052
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
5153
)

graphrag/index/flows/create_community_reports_text.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
build_level_context,
2525
build_local_context,
2626
)
27-
from graphrag.prompts.index.community_report_text_units import (
28-
COMMUNITY_REPORT_PROMPT,
29-
)
3027

3128
log = logging.getLogger(__name__)
3229

@@ -44,8 +41,7 @@ async def create_community_reports_text(
4441
"""All the steps to transform community reports."""
4542
nodes = explode_communities(communities, entities)
4643

47-
# TEMP: forcing override of the prompt until we can put it into config
48-
summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
44+
summarization_strategy["extraction_prompt"] = summarization_strategy["text_prompt"]
4945

5046
max_input_length = summarization_strategy.get(
5147
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH

0 commit comments

Comments
 (0)