Skip to content

Commit a2647da

Browse files
authored
Simplify flow config (microsoft#1554)
* Flatten compute_communities config
* Remove cluster strategy type
* Flatten create_base_text_units config
* Move cluster seed to config default, leave as None in functions
* Remove "prechunked" logic
* Remove hard-coded encoding model
* Remove unused variables
* Strongly type embed_config
* Simplify layout_graph config
* Semver
* Fix integration test
* Fix config unit tests: ignore new config defaults
* Remove pipeline integ test
1 parent e6de713 commit a2647da

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed: +285 -626 lines changed
Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Simplify and streamline internal config."
4+
}

graphrag/config/create_graphrag_config.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
3232
from graphrag.config.input_models.llm_config_input import LLMConfigInput
3333
from graphrag.config.models.cache_config import CacheConfig
34-
from graphrag.config.models.chunking_config import ChunkingConfig
34+
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
3535
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
3636
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
3737
from graphrag.config.models.community_reports_config import CommunityReportsConfig
@@ -318,13 +318,16 @@ def hydrate_parallelization_params(
318318
reader.envvar_prefix(Section.node2vec),
319319
reader.use(values.get("embed_graph")),
320320
):
321+
use_lcc = reader.bool("use_lcc")
321322
embed_graph_model = EmbedGraphConfig(
322323
enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED,
324+
dimensions=reader.int("dimensions") or defs.NODE2VEC_DIMENSIONS,
323325
num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS,
324326
walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH,
325327
window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE,
326328
iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS,
327329
random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED,
330+
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
328331
)
329332
with reader.envvar_prefix(Section.input), reader.use(values.get("input")):
330333
input_type = reader.str("type")
@@ -412,12 +415,15 @@ def hydrate_parallelization_params(
412415
encoding_model = (
413416
reader.str(Fragment.encoding_model) or global_encoding_model
414417
)
415-
418+
strategy = reader.str("strategy")
416419
chunks_model = ChunkingConfig(
417420
size=reader.int("size") or defs.CHUNK_SIZE,
418421
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
419422
group_by_columns=group_by_columns,
420423
encoding_model=encoding_model,
424+
strategy=ChunkStrategyType(strategy)
425+
if strategy
426+
else ChunkStrategyType.tokens,
421427
)
422428
with (
423429
reader.envvar_prefix(Section.snapshot),
@@ -522,8 +528,13 @@ def hydrate_parallelization_params(
522528
)
523529

524530
with reader.use(values.get("cluster_graph")):
531+
use_lcc = reader.bool("use_lcc")
532+
seed = reader.int("seed")
525533
cluster_graph_model = ClusterGraphConfig(
526-
max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE
534+
max_cluster_size=reader.int("max_cluster_size")
535+
or defs.MAX_CLUSTER_SIZE,
536+
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
537+
seed=seed if seed is not None else defs.CLUSTER_GRAPH_SEED,
527538
)
528539

529540
with (

graphrag/config/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,8 @@
6060
CLAIM_MAX_GLEANINGS = 1
6161
CLAIM_EXTRACTION_ENABLED = False
6262
MAX_CLUSTER_SIZE = 10
63+
USE_LCC = True
64+
CLUSTER_GRAPH_SEED = 0xDEADBEEF
6365
COMMUNITY_REPORT_MAX_LENGTH = 2000
6466
COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000
6567
ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@@ -74,6 +76,7 @@
7476
PARALLELIZATION_STAGGER = 0.3
7577
PARALLELIZATION_NUM_THREADS = 50
7678
NODE2VEC_ENABLED = False
79+
NODE2VEC_DIMENSIONS = 1536
7780
NODE2VEC_NUM_WALKS = 10
7881
NODE2VEC_WALK_LENGTH = 40
7982
NODE2VEC_WINDOW_SIZE = 2

graphrag/config/init_content.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
### LLM settings ###
1313
## There are a number of settings to tune the threading and token limits for LLM calls - check the docs.
1414
15-
encoding_model: cl100k_base # this needs to be matched to your model!
15+
encoding_model: {defs.ENCODING_MODEL} # this needs to be matched to your model!
1616
1717
llm:
1818
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
@@ -111,7 +111,7 @@
111111
enabled: false # if true, will generate node2vec embeddings for nodes
112112
113113
umap:
114-
enabled: false # if true, will generate UMAP embeddings for nodes
114+
enabled: false # if true, will generate UMAP embeddings for nodes (embed_graph must also be enabled)
115115
116116
snapshots:
117117
graphml: false

graphrag/config/models/chunking_config.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,24 @@
33

44
"""Parameterization settings for the default configuration."""
55

6+
from enum import Enum
7+
68
from pydantic import BaseModel, Field
79

810
import graphrag.config.defaults as defs
911

1012

13+
class ChunkStrategyType(str, Enum):
14+
"""ChunkStrategy class definition."""
15+
16+
tokens = "tokens"
17+
sentence = "sentence"
18+
19+
def __repr__(self):
20+
"""Get a string representation."""
21+
return f'"{self.value}"'
22+
23+
1124
class ChunkingConfig(BaseModel):
1225
"""Configuration section for chunking."""
1326

@@ -19,22 +32,9 @@ class ChunkingConfig(BaseModel):
1932
description="The chunk by columns to use.",
2033
default=defs.CHUNK_GROUP_BY_COLUMNS,
2134
)
22-
strategy: dict | None = Field(
23-
description="The chunk strategy to use, overriding the default tokenization strategy",
24-
default=None,
35+
strategy: ChunkStrategyType = Field(
36+
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
2537
)
26-
encoding_model: str | None = Field(
27-
default=None, description="The encoding model to use."
38+
encoding_model: str = Field(
39+
description="The encoding model to use.", default=defs.ENCODING_MODEL
2840
)
29-
30-
def resolved_strategy(self, encoding_model: str | None) -> dict:
31-
"""Get the resolved chunking strategy."""
32-
from graphrag.index.operations.chunk_text import ChunkStrategyType
33-
34-
return self.strategy or {
35-
"type": ChunkStrategyType.tokens,
36-
"chunk_size": self.size,
37-
"chunk_overlap": self.overlap,
38-
"group_by_columns": self.group_by_columns,
39-
"encoding_name": encoding_model or self.encoding_model,
40-
}

graphrag/config/models/cluster_graph_config.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,11 @@ class ClusterGraphConfig(BaseModel):
1414
max_cluster_size: int = Field(
1515
description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE
1616
)
17-
strategy: dict | None = Field(
18-
description="The cluster strategy to use.", default=None
17+
use_lcc: bool = Field(
18+
description="Whether to use the largest connected component.",
19+
default=defs.USE_LCC,
20+
)
21+
seed: int | None = Field(
22+
description="The seed to use for the clustering.",
23+
default=defs.CLUSTER_GRAPH_SEED,
1924
)
20-
21-
def resolved_strategy(self) -> dict:
22-
"""Get the resolved cluster strategy."""
23-
from graphrag.index.operations.cluster_graph import GraphCommunityStrategyType
24-
25-
return self.strategy or {
26-
"type": GraphCommunityStrategyType.leiden,
27-
"max_cluster_size": self.max_cluster_size,
28-
}

graphrag/config/models/embed_graph_config.py

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,9 @@ class EmbedGraphConfig(BaseModel):
1515
description="A flag indicating whether to enable node2vec.",
1616
default=defs.NODE2VEC_ENABLED,
1717
)
18+
dimensions: int = Field(
19+
description="The node2vec vector dimensions.", default=defs.NODE2VEC_DIMENSIONS
20+
)
1821
num_walks: int = Field(
1922
description="The node2vec number of walks.", default=defs.NODE2VEC_NUM_WALKS
2023
)
@@ -30,21 +33,7 @@ class EmbedGraphConfig(BaseModel):
3033
random_seed: int = Field(
3134
description="The node2vec random seed.", default=defs.NODE2VEC_RANDOM_SEED
3235
)
33-
strategy: dict | None = Field(
34-
description="The graph embedding strategy override.", default=None
36+
use_lcc: bool = Field(
37+
description="Whether to use the largest connected component.",
38+
default=defs.USE_LCC,
3539
)
36-
37-
def resolved_strategy(self) -> dict:
38-
"""Get the resolved node2vec strategy."""
39-
from graphrag.index.operations.embed_graph.typing import (
40-
EmbedGraphStrategyType,
41-
)
42-
43-
return self.strategy or {
44-
"type": EmbedGraphStrategyType.node2vec,
45-
"num_walks": self.num_walks,
46-
"walk_length": self.walk_length,
47-
"window_size": self.window_size,
48-
"iterations": self.iterations,
49-
"random_seed": self.iterations,
50-
}

graphrag/config/models/entity_extraction_config.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
4848
if self.prompt
4949
else None,
5050
"max_gleanings": self.max_gleanings,
51-
# It's prechunked in create_base_text_units
5251
"encoding_name": encoding_model or self.encoding_model,
53-
"prechunked": True,
5452
}

graphrag/index/create_pipeline_config.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -176,13 +176,8 @@ def _text_unit_workflows(
176176
PipelineWorkflowReference(
177177
name=create_base_text_units,
178178
config={
179+
"chunks": settings.chunks,
179180
"snapshot_transient": settings.snapshots.transient,
180-
"chunk_by": settings.chunks.group_by_columns,
181-
"text_chunk": {
182-
"strategy": settings.chunks.resolved_strategy(
183-
settings.encoding_model
184-
)
185-
},
186181
},
187182
),
188183
PipelineWorkflowReference(
@@ -243,9 +238,7 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
243238
PipelineWorkflowReference(
244239
name=compute_communities,
245240
config={
246-
"cluster_graph": {
247-
"strategy": settings.cluster_graph.resolved_strategy()
248-
},
241+
"cluster_graph": settings.cluster_graph,
249242
"snapshot_transient": settings.snapshots.transient,
250243
},
251244
),
@@ -260,9 +253,8 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
260253
PipelineWorkflowReference(
261254
name=create_final_nodes,
262255
config={
263-
"layout_graph_enabled": settings.umap.enabled,
264-
"embed_graph_enabled": settings.embed_graph.enabled,
265-
"embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
256+
"layout_enabled": settings.umap.enabled,
257+
"embed_graph": settings.embed_graph,
266258
},
267259
),
268260
]

graphrag/index/flows/compute_communities.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33

44
"""All the steps to create the base entity graph."""
55

6-
from typing import Any
7-
86
import pandas as pd
97

108
from graphrag.index.operations.cluster_graph import cluster_graph
@@ -13,14 +11,18 @@
1311

1412
def compute_communities(
1513
base_relationship_edges: pd.DataFrame,
16-
clustering_strategy: dict[str, Any],
14+
max_cluster_size: int,
15+
use_lcc: bool,
16+
seed: int | None = None,
1717
) -> pd.DataFrame:
1818
"""All the steps to create the base entity graph."""
1919
graph = create_graph(base_relationship_edges)
2020

2121
communities = cluster_graph(
2222
graph,
23-
strategy=clustering_strategy,
23+
max_cluster_size,
24+
use_lcc,
25+
seed=seed,
2426
)
2527

2628
base_communities = pd.DataFrame(

0 commit comments

Comments
 (0)