
Commit 51912b2

Move prompts (#1404)
* Move indexing prompts to root
* Move query prompts to root
* Export query prompts during init
* Extract general knowledge prompt
* Load query prompts from disk
* Semver
* Fix unit tests
1 parent c8c354e commit 51912b2


43 files changed: +269 additions, −93 deletions
Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Centralized prompts and export all for easier injection."
+}
```

docs/prompt_tuning/manual_prompt_tuning.md

Lines changed: 4 additions & 4 deletions

```diff
@@ -8,7 +8,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Entity/Relationship Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/graph/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/entity_extraction.py)

 ### Tokens (values provided by extractor)

@@ -20,7 +20,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Summarize Entity/Relationship Descriptions

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/summarize_descriptions.py)

 ### Tokens (values provided by extractor)

@@ -29,7 +29,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Claim Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/claims/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/claim_extraction.py)

 ### Tokens (values provided by extractor)

@@ -47,7 +47,7 @@ See the [configuration documentation](../config/overview.md) for details on how

 ## Generate Community Reports

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/community_reports/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/community_report.py)

 ### Tokens (values provided by extractor)
```
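The sections in the tuning doc above all describe the same override mechanism: a plain-text prompt file containing `{token}` placeholders that the extractor fills in at runtime. A minimal sketch of that substitution follows; the template text and the `{input_text}` token name are illustrative stand-ins, so treat the real token set as whatever the "Tokens (values provided by extractor)" lists in your GraphRAG version document.

```python
# Sketch of placeholder substitution for a custom override prompt file.
# CUSTOM_PROMPT and the {input_text} token are illustrative, not the
# library's shipped template.
CUSTOM_PROMPT = (
    "Extract all entities and relationships from the following text.\n"
    "-Real Data-\n"
    "text: {input_text}\n"
)


def render_prompt(template: str, **tokens: str) -> str:
    """Fill {token} placeholders with extractor-provided values."""
    return template.format(**tokens)


print(render_prompt(CUSTOM_PROMPT, input_text="Acme Corp acquired Widget Inc."))
```

Because substitution is plain `str.format`-style, a custom prompt file only needs to keep the documented token names intact; everything else in the file is free text.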
docs/query/global_search.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -56,11 +56,11 @@ Below are the key parameters of the [GlobalSearch class](https://github.com/micr

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/community_context.py) object to be used for preparing context data from community reports
-* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/map_system_prompt.py)
-* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_map_system_prompt.py)
+* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_reduce_system_prompt.py)
 * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
 * `allow_general_knowledge`: setting this to True will include additional instructions to the `reduce_system_prompt` to prompt the LLM to incorporate relevant real-world knowledge outside of the dataset. Note that this may increase hallucinations, but can be useful for certain scenarios. Default is False
-* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_knowledge_system_prompt.py)
 * `max_data_tokens`: token budget for the context data
 * `map_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `map` stage
 * `reduce_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `reduce` stage
```
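As a rough sketch of the `allow_general_knowledge` behavior the parameter list describes (this models the documented effect only, not the library's actual composition code): the general-knowledge instruction is appended to the reduce-stage prompt when, and only when, the flag is enabled.

```python
# Illustrative only: models the documented effect of allow_general_knowledge.
def build_reduce_prompt(
    reduce_system_prompt: str,
    allow_general_knowledge: bool,
    general_knowledge_instruction: str,
) -> str:
    """Append the general-knowledge instruction only when the flag is set."""
    if allow_general_knowledge:
        return f"{reduce_system_prompt}\n{general_knowledge_instruction}"
    return reduce_system_prompt


base = "Answer strictly from the provided community reports."
extra = "You may also draw on relevant real-world knowledge."
print(build_reduce_prompt(base, False, extra))  # base prompt only
print(build_reduce_prompt(base, True, extra))   # base prompt + instruction
```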

docs/query/local_search.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ Below are the key parameters of the [LocalSearch class](https://github.com/micro

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects
-* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/system_prompt.py)
+* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/local_search_system_prompt.py)
 * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
 * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
 * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the search prompt
```

docs/query/question_generation.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,7 +13,7 @@ Below are the key parameters of the [Question Generation class](https://github.c

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects, using the same context builder class as in local search
-* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/question_gen/system_prompt.py)
+* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/question_gen_system_prompt.py)
 * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
 * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the question generation prompt
 * `callbacks`: optional callback functions, can be used to provide custom event handlers for LLM's completion streaming events
```

graphrag/api/query.py

Lines changed: 42 additions & 1 deletion

```diff
@@ -98,13 +98,24 @@ async def global_search(
         dynamic_community_selection=dynamic_community_selection,
     )
     _entities = read_indexer_entities(nodes, entities, community_level=community_level)
+    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.reduce_prompt
+    )
+    knowledge_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.knowledge_prompt
+    )
+
     search_engine = get_global_search_engine(
         config,
         reports=reports,
         entities=_entities,
         communities=_communities,
         response_type=response_type,
         dynamic_community_selection=dynamic_community_selection,
+        map_system_prompt=map_prompt,
+        reduce_system_prompt=reduce_prompt,
+        general_knowledge_inclusion_prompt=knowledge_prompt,
     )
     result: SearchResult = await search_engine.asearch(query=query)
     response = result.response
@@ -156,13 +167,24 @@ async def global_search_streaming(
         dynamic_community_selection=dynamic_community_selection,
     )
     _entities = read_indexer_entities(nodes, entities, community_level=community_level)
+    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.reduce_prompt
+    )
+    knowledge_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.knowledge_prompt
+    )
+
     search_engine = get_global_search_engine(
         config,
         reports=reports,
         entities=_entities,
         communities=_communities,
         response_type=response_type,
         dynamic_community_selection=dynamic_community_selection,
+        map_system_prompt=map_prompt,
+        reduce_system_prompt=reduce_prompt,
+        general_knowledge_inclusion_prompt=knowledge_prompt,
     )
     search_result = search_engine.astream_search(query=query)

@@ -238,6 +260,7 @@ async def local_search(

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
+    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

     search_engine = get_local_search_engine(
         config=config,
@@ -248,6 +271,7 @@ async def local_search(
         covariates={"claims": _covariates},
         description_embedding_store=description_embedding_store, # type: ignore
         response_type=response_type,
+        system_prompt=prompt,
     )

     result: SearchResult = await search_engine.asearch(query=query)
@@ -312,6 +336,7 @@ async def local_search_streaming(

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
+    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

     search_engine = get_local_search_engine(
         config=config,
@@ -322,6 +347,7 @@ async def local_search_streaming(
         covariates={"claims": _covariates},
         description_embedding_store=description_embedding_store, # type: ignore
         response_type=response_type,
+        system_prompt=prompt,
     )
     search_result = search_engine.astream_search(query=query)

@@ -401,14 +427,15 @@ async def drift_search(
     _entities = read_indexer_entities(nodes, entities, community_level)
     _reports = read_indexer_reports(community_reports, nodes, community_level)
     read_indexer_report_embeddings(_reports, full_content_embedding_store)
-
+    prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
     search_engine = get_drift_search_engine(
         config=config,
         reports=_reports,
         text_units=read_indexer_text_units(text_units),
         entities=_entities,
         relationships=read_indexer_relationships(relationships),
         description_embedding_store=description_embedding_store, # type: ignore
+        local_system_prompt=prompt,
     )

     result: SearchResult = await search_engine.asearch(query=query)
@@ -551,3 +578,17 @@ def _reformat_context_data(context_data: dict) -> dict:
             continue
         final_format[key] = records
     return final_format
+
+
+def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
+    """
+    Load the search prompt from disk if configured.
+
+    If not, leave it empty - the search functions will load their defaults.
+
+    """
+    if prompt_config:
+        prompt_file = Path(root_dir) / prompt_config
+        if prompt_file.exists():
+            return prompt_file.read_bytes().decode(encoding="utf-8")
+    return None
```

graphrag/cli/initialize.py

Lines changed: 33 additions & 30 deletions

```diff
@@ -5,14 +5,24 @@

 from pathlib import Path

-from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
-from graphrag.index.graph.extractors.community_reports.prompts import (
-    COMMUNITY_REPORT_PROMPT,
-)
-from graphrag.index.graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
-from graphrag.index.graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
 from graphrag.index.init_content import INIT_DOTENV, INIT_YAML
 from graphrag.logging import ReporterType, create_progress_reporter
+from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
+from graphrag.prompts.index.community_report import (
+    COMMUNITY_REPORT_PROMPT,
+)
+from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
+from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
+from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
+from graphrag.prompts.query.global_search_knowledge_system_prompt import (
+    GENERAL_KNOWLEDGE_INSTRUCTION,
+)
+from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT
+from graphrag.prompts.query.global_search_reduce_system_prompt import (
+    REDUCE_SYSTEM_PROMPT,
+)
+from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
+from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT


 def initialize_project_at(path: Path) -> None:
@@ -40,28 +50,21 @@ def initialize_project_at(path: Path) -> None:
     if not prompts_dir.exists():
         prompts_dir.mkdir(parents=True, exist_ok=True)

-    entity_extraction = prompts_dir / "entity_extraction.txt"
-    if not entity_extraction.exists():
-        with entity_extraction.open("wb") as file:
-            file.write(
-                GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
-
-    summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
-    if not summarize_descriptions.exists():
-        with summarize_descriptions.open("wb") as file:
-            file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
-
-    claim_extraction = prompts_dir / "claim_extraction.txt"
-    if not claim_extraction.exists():
-        with claim_extraction.open("wb") as file:
-            file.write(
-                CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
+    prompts = {
+        "entity_extraction": GRAPH_EXTRACTION_PROMPT,
+        "summarize_descriptions": SUMMARIZE_PROMPT,
+        "claim_extraction": CLAIM_EXTRACTION_PROMPT,
+        "community_report": COMMUNITY_REPORT_PROMPT,
+        "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
+        "global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
+        "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
+        "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
+        "local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
+        "question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
+    }

-    community_report = prompts_dir / "community_report.txt"
-    if not community_report.exists():
-        with community_report.open("wb") as file:
-            file.write(
-                COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
+    for name, content in prompts.items():
+        prompt_file = prompts_dir / f"{name}.txt"
+        if not prompt_file.exists():
+            with prompt_file.open("wb") as file:
+                file.write(content.encode(encoding="utf-8", errors="strict"))
```
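The refactor above replaces four copy-pasted write blocks with a single name-to-text table and one loop, and the `if not prompt_file.exists()` guard means re-running init never clobbers a user's edited prompt. A minimal standalone sketch of that table-driven export (function and file names here are illustrative):

```python
# Standalone sketch of the dict-driven prompt export pattern used above.
import tempfile
from pathlib import Path


def export_prompts(prompts_dir: Path, prompts: dict[str, str]) -> list[str]:
    """Write each prompt to <name>.txt, skipping files that already exist."""
    prompts_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for name, content in prompts.items():
        prompt_file = prompts_dir / f"{name}.txt"
        if not prompt_file.exists():  # never overwrite a user customization
            prompt_file.write_bytes(content.encode("utf-8"))
            written.append(name)
    return written


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "prompts"
    print(export_prompts(target, {"local_search_system_prompt": "Hello"}))   # written
    print(export_prompts(target, {"local_search_system_prompt": "Changed"}))  # [] (kept)
```

Adding a new default prompt now only requires a new dictionary entry, not another hand-rolled write block.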

graphrag/config/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -51,6 +51,7 @@
     ClaimExtractionConfig,
     ClusterGraphConfig,
     CommunityReportsConfig,
+    DRIFTSearchConfig,
     EmbedGraphConfig,
     EntityExtractionConfig,
     GlobalSearchConfig,
@@ -85,6 +86,7 @@
     "ClusterGraphConfigInput",
     "CommunityReportsConfig",
     "CommunityReportsConfigInput",
+    "DRIFTSearchConfig",
     "EmbedGraphConfig",
     "EmbedGraphConfigInput",
     "EntityExtractionConfig",
```
