Skip to content

Commit 7748493

Browse files
committed
Streamline chunking config
1 parent a741bfb commit 7748493

File tree

13 files changed

+47
-51
lines changed

13 files changed

+47
-51
lines changed

packages/graphrag-chunking/graphrag_chunking/chunk_strategy_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from enum import StrEnum
77

88

9-
class ChunkStrategyType(StrEnum):
10-
"""ChunkStrategy class definition."""
9+
class ChunkerType(StrEnum):
10+
"""ChunkerType class definition."""
1111

1212
Tokens = "tokens"
1313
Sentence = "sentence"

packages/graphrag-chunking/graphrag_chunking/chunker_factory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from graphrag_common.factory.factory import Factory, ServiceScope
99

10-
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
10+
from graphrag_chunking.chunk_strategy_type import ChunkerType
1111
from graphrag_chunking.chunker import Chunker
1212
from graphrag_chunking.chunking_config import ChunkingConfig
1313

@@ -58,18 +58,18 @@ def create_chunker(
5858
config_model["encode"] = encode
5959
if decode is not None:
6060
config_model["decode"] = decode
61-
chunker_strategy = config.strategy
61+
chunker_strategy = config.type
6262

6363
if chunker_strategy not in chunker_factory:
6464
match chunker_strategy:
65-
case ChunkStrategyType.Tokens:
65+
case ChunkerType.Tokens:
6666
from graphrag_chunking.token_chunker import TokenChunker
6767

68-
register_chunker(ChunkStrategyType.Tokens, TokenChunker)
69-
case ChunkStrategyType.Sentence:
68+
register_chunker(ChunkerType.Tokens, TokenChunker)
69+
case ChunkerType.Sentence:
7070
from graphrag_chunking.sentence_chunker import SentenceChunker
7171

72-
register_chunker(ChunkStrategyType.Sentence, SentenceChunker)
72+
register_chunker(ChunkerType.Sentence, SentenceChunker)
7373
case _:
7474
msg = f"ChunkingConfig.type '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
7575
raise ValueError(msg)

packages/graphrag-chunking/graphrag_chunking/chunking_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pydantic import BaseModel, ConfigDict, Field
77

8-
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
8+
from graphrag_chunking.chunk_strategy_type import ChunkerType
99

1010

1111
class ChunkingConfig(BaseModel):
@@ -14,9 +14,9 @@ class ChunkingConfig(BaseModel):
1414
model_config = ConfigDict(extra="allow")
1515
"""Allow extra fields to support custom cache implementations."""
1616

17-
strategy: str = Field(
18-
description="The chunking strategy to use.",
19-
default=ChunkStrategyType.Tokens,
17+
type: str = Field(
18+
description="The chunker type to use.",
19+
default=ChunkerType.Tokens,
2020
)
2121
encoding_model: str | None = Field(
2222
description="The encoding model to use.",

packages/graphrag/graphrag/cli/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,14 @@ def _prompt_tune_cli(
306306
help="The minimum number of examples to generate/include in the entity extraction prompt.",
307307
),
308308
chunk_size: int = typer.Option(
309-
graphrag_config_defaults.chunks.size,
309+
graphrag_config_defaults.chunking.size,
310310
"--chunk-size",
311-
help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
311+
help="The size of each example text chunk. Overrides chunking.size in the configuration file.",
312312
),
313313
overlap: int = typer.Option(
314-
graphrag_config_defaults.chunks.overlap,
314+
graphrag_config_defaults.chunking.overlap,
315315
"--overlap",
316-
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
316+
help="The overlap size for chunking documents. Overrides chunking.overlap in the configuration file.",
317317
),
318318
language: str | None = typer.Option(
319319
None,

packages/graphrag/graphrag/cli/prompt_tune.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ async def prompt_tune(
6161
)
6262

6363
# override chunking config in the configuration
64-
if chunk_size != graph_config.chunks.size:
65-
graph_config.chunks.size = chunk_size
64+
if chunk_size != graph_config.chunking.size:
65+
graph_config.chunking.size = chunk_size
6666

67-
if overlap != graph_config.chunks.overlap:
68-
graph_config.chunks.overlap = overlap
67+
if overlap != graph_config.chunking.overlap:
68+
graph_config.chunking.overlap = overlap
6969

7070
# configure the root logger with the specified log level
7171
from graphrag.logger.standard_logging import init_loggers

packages/graphrag/graphrag/config/defaults.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import ClassVar
99

1010
from graphrag_cache import CacheType
11-
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
11+
from graphrag_chunking.chunk_strategy_type import ChunkerType
1212
from graphrag_storage import StorageType
1313

1414
from graphrag.config.embeddings import default_embeddings
@@ -57,10 +57,10 @@ class BasicSearchDefaults:
5757

5858

5959
@dataclass
60-
class ChunksDefaults:
61-
"""Default values for chunks."""
60+
class ChunkingDefaults:
61+
"""Default values for chunking."""
6262

63-
strategy: str = ChunkStrategyType.Tokens
63+
type: str = ChunkerType.Tokens
6464
size: int = 1200
6565
overlap: int = 100
6666
encoding_model: str = ENCODING_MODEL
@@ -126,7 +126,6 @@ class EmbedTextDefaults:
126126
batch_size: int = 16
127127
batch_max_tokens: int = 8191
128128
names: list[str] = field(default_factory=lambda: default_embeddings)
129-
strategy: None = None
130129

131130

132131
@dataclass
@@ -139,7 +138,6 @@ class ExtractClaimsDefaults:
139138
"Any claims or facts that could be relevant to information discovery."
140139
)
141140
max_gleanings: int = 1
142-
strategy: None = None
143141
model_id: str = DEFAULT_CHAT_MODEL_ID
144142
model_instance_name: str = "extract_claims"
145143

@@ -153,7 +151,6 @@ class ExtractGraphDefaults:
153151
default_factory=lambda: ["organization", "person", "geo", "event"]
154152
)
155153
max_gleanings: int = 1
156-
strategy: None = None
157154
model_id: str = DEFAULT_CHAT_MODEL_ID
158155
model_instance_name: str = "extract_graph"
159156

@@ -360,7 +357,6 @@ class SummarizeDescriptionsDefaults:
360357
prompt: None = None
361358
max_length: int = 500
362359
max_input_tokens: int = 4_000
363-
strategy: None = None
364360
model_id: str = DEFAULT_CHAT_MODEL_ID
365361
model_instance_name: str = "summarize_descriptions"
366362

@@ -401,7 +397,7 @@ class GraphRagConfigDefaults:
401397
cache: CacheDefaults = field(default_factory=CacheDefaults)
402398
input: InputDefaults = field(default_factory=InputDefaults)
403399
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
404-
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
400+
chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
405401
snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
406402
extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
407403
extract_graph_nlp: ExtractGraphNLPDefaults = field(

packages/graphrag/graphrag/config/init_content.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@
5454
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
5555
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
5656
57-
chunks:
58-
strategy: {graphrag_config_defaults.chunks.strategy}
59-
size: {graphrag_config_defaults.chunks.size}
60-
overlap: {graphrag_config_defaults.chunks.overlap}
61-
encoding_model: {graphrag_config_defaults.chunks.encoding_model}
57+
chunking:
58+
type: {graphrag_config_defaults.chunking.type}
59+
size: {graphrag_config_defaults.chunking.size}
60+
overlap: {graphrag_config_defaults.chunking.overlap}
61+
encoding_model: {graphrag_config_defaults.chunking.encoding_model}
6262
6363
### Output/storage settings ###
6464
## If blob storage is specified in the following four sections,

packages/graphrag/graphrag/config/models/graph_rag_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,14 @@ def _validate_input_base_dir(self) -> None:
125125
Path(self.input.storage.base_dir).resolve()
126126
)
127127

128-
chunks: ChunkingConfig = Field(
128+
chunking: ChunkingConfig = Field(
129129
description="The chunking configuration to use.",
130130
default=ChunkingConfig(
131-
strategy=graphrag_config_defaults.chunks.strategy,
132-
size=graphrag_config_defaults.chunks.size,
133-
overlap=graphrag_config_defaults.chunks.overlap,
134-
encoding_model=graphrag_config_defaults.chunks.encoding_model,
135-
prepend_metadata=graphrag_config_defaults.chunks.prepend_metadata,
131+
type=graphrag_config_defaults.chunking.type,
132+
size=graphrag_config_defaults.chunking.size,
133+
overlap=graphrag_config_defaults.chunking.overlap,
134+
encoding_model=graphrag_config_defaults.chunking.encoding_model,
135+
prepend_metadata=graphrag_config_defaults.chunking.prepend_metadata,
136136
),
137137
)
138138
"""The chunking configuration to use."""

packages/graphrag/graphrag/index/workflows/create_base_text_units.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ async def run_workflow(
3333
logger.info("Workflow started: create_base_text_units")
3434
documents = await load_table_from_storage("documents", context.output_storage)
3535

36-
tokenizer = get_tokenizer(encoding_model=config.chunks.encoding_model)
37-
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
36+
tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
37+
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
3838
output = create_base_text_units(
3939
documents,
4040
context.callbacks,
4141
tokenizer=tokenizer,
4242
chunker=chunker,
43-
prepend_metadata=config.chunks.prepend_metadata,
43+
prepend_metadata=config.chunking.prepend_metadata,
4444
)
4545

4646
await write_table_to_storage(output, "text_units", context.output_storage)

packages/graphrag/graphrag/prompt_tune/loader/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def load_docs_in_chunks(
6262
cache=NoopCache(),
6363
)
6464
tokenizer = get_tokenizer(embeddings_llm_settings)
65-
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
65+
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
6666
input_storage = create_storage(config.input.storage)
6767
input_reader = InputReaderFactory().create(
6868
config.input.file_type,

0 commit comments

Comments
 (0)