Skip to content

Commit eb0dfe3

Browse files
authored
Remove strategy dicts (#2090)
* Remove "strategy" from community reports config/workflow * Remove extraction strategy from extract_graph * Remove summarization strategy from extract_graph * Remove strategy from claim extraction * Strongly type prompt templates * Remove strategy from embed_text * Push hydrated params into community report workflows * Push hydrated params into extract covariates * Push hydrated params into extract graph NLP * Push hydrated params into extract graph * Push hydrated params into text embeddings * Remove a few more low-level defaults * Semver * Remove configurable prompt delimiters * Update smoke tests
1 parent 79ad9b9 commit eb0dfe3

File tree

56 files changed

+946
-1405
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+946
-1405
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "major",
3+
"description": "Simplify internal args with stronger types and firmer boundaries."
4+
}

docs/examples_notebooks/index_migration_to_v1.ipynb

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -202,45 +202,44 @@
202202
"metadata": {},
203203
"outputs": [],
204204
"source": [
205-
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
206-
"\n",
207205
"from graphrag.cache.factory import CacheFactory\n",
208206
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
209-
"from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
207+
"from graphrag.config.get_vector_store_settings import get_vector_store_settings\n",
208+
"from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n",
210209
"\n",
211210
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
212211
"# We'll construct the context and run this function flow directly to avoid everything else\n",
213212
"\n",
214213
"\n",
215-
"embedded_fields = get_embedded_fields(config)\n",
216-
"text_embed = get_embedding_settings(config)\n",
214+
"vector_store_config = get_vector_store_settings(config)\n",
215+
"model_config = config.get_language_model_config(config.embed_text.model_id)\n",
217216
"callbacks = NoopWorkflowCallbacks()\n",
218217
"cache_config = config.cache.model_dump() # type: ignore\n",
219218
"cache = CacheFactory().create_cache(\n",
220219
" cache_type=cache_config[\"type\"], # type: ignore\n",
221-
" root_dir=PROJECT_DIRECTORY,\n",
222-
" kwargs=cache_config,\n",
220+
" **cache_config,\n",
223221
")\n",
224222
"\n",
225223
"await generate_text_embeddings(\n",
226-
" final_documents=None,\n",
227-
" final_relationships=None,\n",
228-
" final_text_units=final_text_units,\n",
229-
" final_entities=final_entities,\n",
230-
" final_community_reports=final_community_reports,\n",
224+
" documents=None,\n",
225+
" relationships=None,\n",
226+
" text_units=final_text_units,\n",
227+
" entities=final_entities,\n",
228+
" community_reports=final_community_reports,\n",
231229
" callbacks=callbacks,\n",
232230
" cache=cache,\n",
233-
" storage=storage,\n",
234-
" text_embed_config=text_embed,\n",
235-
" embedded_fields=embedded_fields,\n",
236-
" snapshot_embeddings_enabled=False,\n",
231+
" model_config=model_config,\n",
232+
" batch_size=config.embed_text.batch_size,\n",
233+
" batch_max_tokens=config.embed_text.batch_max_tokens,\n",
234+
" vector_store_config=vector_store_config,\n",
235+
" embedded_fields=config.embed_text.names,\n",
237236
")"
238237
]
239238
}
240239
],
241240
"metadata": {
242241
"kernelspec": {
243-
"display_name": ".venv",
242+
"display_name": "graphrag",
244243
"language": "python",
245244
"name": "python3"
246245
},
@@ -254,7 +253,7 @@
254253
"name": "python",
255254
"nbconvert_exporter": "python",
256255
"pygments_lexer": "ipython3",
257-
"version": "3.11.9"
256+
"version": "3.12.10"
258257
}
259258
},
260259
"nbformat": 4,

graphrag/config/defaults.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
ENCODING_MODEL = "o200k_base"
6161
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
6262

63+
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
6364

6465
DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
6566
"native": NativeRetry,
@@ -125,7 +126,6 @@ class CommunityReportDefaults:
125126
text_prompt: None = None
126127
max_length: int = 2000
127128
max_input_length: int = 8000
128-
strategy: None = None
129129
model_id: str = DEFAULT_CHAT_MODEL_ID
130130

131131

@@ -162,10 +162,9 @@ class DriftSearchDefaults:
162162
class EmbedTextDefaults:
163163
"""Default values for embedding text."""
164164

165-
model: str = "text-embedding-3-small"
165+
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
166166
batch_size: int = 16
167167
batch_max_tokens: int = 8191
168-
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
169168
names: list[str] = field(default_factory=lambda: default_embeddings)
170169
strategy: None = None
171170
vector_store_id: str = DEFAULT_VECTOR_STORE_ID
Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""A module containing get_embedding_settings."""
4+
"""A module containing get_vector_store_settings."""
55

66
from graphrag.config.models.graph_rag_config import GraphRagConfig
77

88

9-
def get_embedding_settings(
9+
def get_vector_store_settings(
1010
settings: GraphRagConfig,
1111
vector_store_params: dict | None = None,
1212
) -> dict:
1313
"""Transform GraphRAG config into settings for workflows."""
14-
embeddings_llm_settings = settings.get_language_model_config(
15-
settings.embed_text.model_id
16-
)
1714
vector_store_settings = settings.get_vector_store_config(
1815
settings.embed_text.vector_store_id
1916
).model_dump()
@@ -23,16 +20,7 @@ def get_embedding_settings(
2320
# settings.vector_store.base contains connection information, or may be undefined
2421
# settings.vector_store.<vector_name> contains the specific settings for this embedding
2522
#
26-
strategy = settings.embed_text.resolved_strategy(
27-
embeddings_llm_settings
28-
) # get the default strategy
29-
strategy.update({
30-
"vector_store": {
31-
**(vector_store_params or {}),
32-
**(vector_store_settings),
33-
}
34-
}) # update the default strategy with the vector store settings
35-
# This ensures the vector store config is part of the strategy and not the global config
3623
return {
37-
"strategy": strategy,
24+
**(vector_store_params or {}),
25+
**(vector_store_settings),
3826
}

graphrag/config/init_content.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
# api_version: 2024-05-01-preview
2929
model_supports_json: true # recommended if this is available for your model.
3030
concurrent_requests: {language_model_defaults.concurrent_requests}
31-
async_mode: {language_model_defaults.async_mode.value} # or asyncio
3231
retry_strategy: {language_model_defaults.retry_strategy}
3332
max_retries: {language_model_defaults.max_retries}
3433
tokens_per_minute: null
@@ -42,7 +41,6 @@
4241
# api_base: https://<instance>.openai.azure.com
4342
# api_version: 2024-05-01-preview
4443
concurrent_requests: {language_model_defaults.concurrent_requests}
45-
async_mode: {language_model_defaults.async_mode.value} # or asyncio
4644
retry_strategy: {language_model_defaults.retry_strategy}
4745
max_retries: {language_model_defaults.max_retries}
4846
tokens_per_minute: null
@@ -102,7 +100,6 @@
102100
extract_graph_nlp:
103101
text_analyzer:
104102
extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
105-
async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio
106103
107104
cluster_graph:
108105
max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}

graphrag/config/models/community_reports_config.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@
33

44
"""Parameterization settings for the default configuration."""
55

6+
from dataclasses import dataclass
67
from pathlib import Path
78

89
from pydantic import BaseModel, Field
910

1011
from graphrag.config.defaults import graphrag_config_defaults
11-
from graphrag.config.models.language_model_config import LanguageModelConfig
12+
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
13+
from graphrag.prompts.index.community_report_text_units import (
14+
COMMUNITY_REPORT_TEXT_PROMPT,
15+
)
16+
17+
18+
@dataclass
19+
class CommunityReportPrompts:
20+
"""Community report prompt templates."""
21+
22+
graph_prompt: str
23+
text_prompt: str
1224

1325

1426
class CommunityReportsConfig(BaseModel):
@@ -34,32 +46,16 @@ class CommunityReportsConfig(BaseModel):
3446
description="The maximum input length in tokens to use when generating reports.",
3547
default=graphrag_config_defaults.community_reports.max_input_length,
3648
)
37-
strategy: dict | None = Field(
38-
description="The override strategy to use.",
39-
default=graphrag_config_defaults.community_reports.strategy,
40-
)
4149

42-
def resolved_strategy(
43-
self, root_dir: str, model_config: LanguageModelConfig
44-
) -> dict:
45-
"""Get the resolved community report extraction strategy."""
46-
from graphrag.index.operations.summarize_communities.typing import (
47-
CreateCommunityReportsStrategyType,
48-
)
49-
50-
return self.strategy or {
51-
"type": CreateCommunityReportsStrategyType.graph_intelligence,
52-
"llm": model_config.model_dump(),
53-
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
50+
def resolved_prompts(self, root_dir: str) -> CommunityReportPrompts:
51+
"""Get the resolved community report extraction prompts."""
52+
return CommunityReportPrompts(
53+
graph_prompt=(Path(root_dir) / self.graph_prompt).read_text(
5454
encoding="utf-8"
5555
)
5656
if self.graph_prompt
57-
else None,
58-
"text_prompt": (Path(root_dir) / self.text_prompt).read_text(
59-
encoding="utf-8"
60-
)
57+
else COMMUNITY_REPORT_PROMPT,
58+
text_prompt=(Path(root_dir) / self.text_prompt).read_text(encoding="utf-8")
6159
if self.text_prompt
62-
else None,
63-
"max_report_length": self.max_length,
64-
"max_input_length": self.max_input_length,
65-
}
60+
else COMMUNITY_REPORT_TEXT_PROMPT,
61+
)

graphrag/config/models/text_embedding_config.py renamed to graphrag/config/models/embed_text_config.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from pydantic import BaseModel, Field
77

88
from graphrag.config.defaults import graphrag_config_defaults
9-
from graphrag.config.models.language_model_config import LanguageModelConfig
109

1110

12-
class TextEmbeddingConfig(BaseModel):
11+
class EmbedTextConfig(BaseModel):
1312
"""Configuration section for text embeddings."""
1413

1514
model_id: str = Field(
@@ -32,21 +31,3 @@ class TextEmbeddingConfig(BaseModel):
3231
description="The specific embeddings to perform.",
3332
default=graphrag_config_defaults.embed_text.names,
3433
)
35-
strategy: dict | None = Field(
36-
description="The override strategy to use.",
37-
default=graphrag_config_defaults.embed_text.strategy,
38-
)
39-
40-
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
41-
"""Get the resolved text embedding strategy."""
42-
from graphrag.index.operations.embed_text.embed_text import (
43-
TextEmbedStrategyType,
44-
)
45-
46-
return self.strategy or {
47-
"type": TextEmbedStrategyType.openai,
48-
"llm": model_config.model_dump(),
49-
"num_threads": model_config.concurrent_requests,
50-
"batch_size": self.batch_size,
51-
"batch_max_tokens": self.batch_max_tokens,
52-
}

graphrag/config/models/extract_claims_config.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,23 @@
33

44
"""Parameterization settings for the default configuration."""
55

6+
from dataclasses import dataclass
67
from pathlib import Path
78

89
from pydantic import BaseModel, Field
910

1011
from graphrag.config.defaults import graphrag_config_defaults
11-
from graphrag.config.models.language_model_config import LanguageModelConfig
12+
from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
1213

1314

14-
class ClaimExtractionConfig(BaseModel):
15+
@dataclass
16+
class ClaimExtractionPrompts:
17+
"""Claim extraction prompt templates."""
18+
19+
extraction_prompt: str
20+
21+
22+
class ExtractClaimsConfig(BaseModel):
1523
"""Configuration section for claim extraction."""
1624

1725
enabled: bool = Field(
@@ -34,22 +42,11 @@ class ClaimExtractionConfig(BaseModel):
3442
description="The maximum number of entity gleanings to use.",
3543
default=graphrag_config_defaults.extract_claims.max_gleanings,
3644
)
37-
strategy: dict | None = Field(
38-
description="The override strategy to use.",
39-
default=graphrag_config_defaults.extract_claims.strategy,
40-
)
4145

42-
def resolved_strategy(
43-
self, root_dir: str, model_config: LanguageModelConfig
44-
) -> dict:
45-
"""Get the resolved claim extraction strategy."""
46-
return self.strategy or {
47-
"llm": model_config.model_dump(),
48-
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
49-
encoding="utf-8"
50-
)
46+
def resolved_prompts(self, root_dir: str) -> ClaimExtractionPrompts:
47+
"""Get the resolved claim extraction prompts."""
48+
return ClaimExtractionPrompts(
49+
extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
5150
if self.prompt
52-
else None,
53-
"claim_description": self.description,
54-
"max_gleanings": self.max_gleanings,
55-
}
51+
else EXTRACT_CLAIMS_PROMPT,
52+
)

graphrag/config/models/extract_graph_config.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33

44
"""Parameterization settings for the default configuration."""
55

6+
from dataclasses import dataclass
67
from pathlib import Path
78

89
from pydantic import BaseModel, Field
910

1011
from graphrag.config.defaults import graphrag_config_defaults
11-
from graphrag.config.models.language_model_config import LanguageModelConfig
12+
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
13+
14+
15+
@dataclass
16+
class ExtractGraphPrompts:
17+
"""Graph extraction prompt templates."""
18+
19+
extraction_prompt: str
1220

1321

1422
class ExtractGraphConfig(BaseModel):
@@ -30,26 +38,11 @@ class ExtractGraphConfig(BaseModel):
3038
description="The maximum number of entity gleanings to use.",
3139
default=graphrag_config_defaults.extract_graph.max_gleanings,
3240
)
33-
strategy: dict | None = Field(
34-
description="Override the default entity extraction strategy",
35-
default=graphrag_config_defaults.extract_graph.strategy,
36-
)
3741

38-
def resolved_strategy(
39-
self, root_dir: str, model_config: LanguageModelConfig
40-
) -> dict:
41-
"""Get the resolved entity extraction strategy."""
42-
from graphrag.index.operations.extract_graph.typing import (
43-
ExtractEntityStrategyType,
44-
)
45-
46-
return self.strategy or {
47-
"type": ExtractEntityStrategyType.graph_intelligence,
48-
"llm": model_config.model_dump(),
49-
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
50-
encoding="utf-8"
51-
)
42+
def resolved_prompts(self, root_dir: str) -> ExtractGraphPrompts:
43+
"""Get the resolved graph extraction prompts."""
44+
return ExtractGraphPrompts(
45+
extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
5246
if self.prompt
53-
else None,
54-
"max_gleanings": self.max_gleanings,
55-
}
47+
else GRAPH_EXTRACTION_PROMPT,
48+
)

0 commit comments

Comments
 (0)