Skip to content

Commit 1df8972

Browse files
Pipeline registration (microsoft#1940)
* Move covariate run conditional * All pipeline registration * Fix method name construction * Rename context storage -> output_storage * Rename OutputConfig as generic StorageConfig * Reuse Storage model under InputConfig * Move input storage creation out of document loading * Move document loading into workflows * Semver * Fix smoke test config for new workflows * Fix unit tests --------- Co-authored-by: Alonso Guevara <[email protected]>
1 parent 17e431c commit 1df8972

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

60 files changed

+602
-424
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "Allow injection of custom pipelines."
4+
}

graphrag/api/index.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
async def build_index(
2929
config: GraphRagConfig,
30-
method: IndexingMethod = IndexingMethod.Standard,
30+
method: IndexingMethod | str = IndexingMethod.Standard,
3131
is_update_run: bool = False,
3232
memory_profile: bool = False,
3333
callbacks: list[WorkflowCallbacks] | None = None,
@@ -65,7 +65,9 @@ async def build_index(
6565
if memory_profile:
6666
log.warning("New pipeline does not yet support memory profiling.")
6767

68-
pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
68+
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
69+
method = _get_method(method, is_update_run)
70+
pipeline = PipelineFactory.create_pipeline(config, method)
6971

7072
workflow_callbacks.pipeline_start(pipeline.names())
7173

@@ -90,3 +92,8 @@ async def build_index(
9092
def register_workflow_function(name: str, workflow: WorkflowFunction):
9193
"""Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
9294
PipelineFactory.register(name, workflow)
95+
96+
97+
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
    """Resolve the pipeline name to request from the factory.

    Accepts either an IndexingMethod enum member or its raw string value,
    and appends the "-update" suffix when this is an incremental update run.
    """
    # Enum members carry their string value; plain strings pass through unchanged.
    if isinstance(method, IndexingMethod):
        name = method.value
    else:
        name = method
    # Update runs map onto the "-update" pipeline variants (e.g. "standard-update").
    if is_update_run:
        return f"{name}-update"
    return name

graphrag/api/prompt_tune.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
async def generate_indexing_prompts(
5353
config: GraphRagConfig,
5454
logger: ProgressLogger,
55-
root: str,
5655
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
5756
overlap: Annotated[
5857
int, annotated_types.Gt(-1)
@@ -93,7 +92,6 @@ async def generate_indexing_prompts(
9392
# Retrieve documents
9493
logger.info("Chunking documents...")
9594
doc_list = await load_docs_in_chunks(
96-
root=root,
9795
config=config,
9896
limit=limit,
9997
select_method=selection_method,

graphrag/cli/index.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def index_cli(
8080
cli_overrides["reporting.base_dir"] = str(output_dir)
8181
cli_overrides["update_index_output.base_dir"] = str(output_dir)
8282
config = load_config(root_dir, config_filepath, cli_overrides)
83-
8483
_run_index(
8584
config=config,
8685
method=method,

graphrag/cli/prompt_tune.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ async def prompt_tune(
8686

8787
prompts = await api.generate_indexing_prompts(
8888
config=graph_config,
89-
root=str(root_path),
9089
logger=progress_logger,
9190
chunk_size=chunk_size,
9291
overlap=overlap,

graphrag/config/defaults.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414
CacheType,
1515
ChunkStrategyType,
1616
InputFileType,
17-
InputType,
1817
ModelType,
1918
NounPhraseExtractorType,
20-
OutputType,
2119
ReportingType,
20+
StorageType,
2221
)
2322
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
2423
EN_STOP_WORDS,
@@ -234,16 +233,31 @@ class GlobalSearchDefaults:
234233
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
235234

236235

236+
@dataclass
237+
class StorageDefaults:
238+
"""Default values for storage."""
239+
240+
type = StorageType.file
241+
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
242+
connection_string: None = None
243+
container_name: None = None
244+
storage_account_blob_url: None = None
245+
cosmosdb_account_url: None = None
246+
247+
248+
@dataclass
249+
class InputStorageDefaults(StorageDefaults):
250+
"""Default values for input storage."""
251+
252+
base_dir: str = "input"
253+
254+
237255
@dataclass
238256
class InputDefaults:
239257
"""Default values for input."""
240258

241-
type = InputType.file
259+
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
242260
file_type = InputFileType.text
243-
base_dir: str = "input"
244-
connection_string: None = None
245-
storage_account_blob_url: None = None
246-
container_name: None = None
247261
encoding: str = "utf-8"
248262
file_pattern: str = ""
249263
file_filter: None = None
@@ -301,15 +315,10 @@ class LocalSearchDefaults:
301315

302316

303317
@dataclass
304-
class OutputDefaults:
318+
class OutputDefaults(StorageDefaults):
305319
"""Default values for output."""
306320

307-
type = OutputType.file
308321
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
309-
connection_string: None = None
310-
container_name: None = None
311-
storage_account_blob_url: None = None
312-
cosmosdb_account_url: None = None
313322

314323

315324
@dataclass
@@ -364,14 +373,10 @@ class UmapDefaults:
364373

365374

366375
@dataclass
367-
class UpdateIndexOutputDefaults:
376+
class UpdateIndexOutputDefaults(StorageDefaults):
368377
"""Default values for update index output."""
369378

370-
type = OutputType.file
371379
base_dir: str = "update_output"
372-
connection_string: None = None
373-
container_name: None = None
374-
storage_account_blob_url: None = None
375380

376381

377382
@dataclass
@@ -395,6 +400,7 @@ class GraphRagConfigDefaults:
395400
root_dir: str = ""
396401
models: dict = field(default_factory=dict)
397402
reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
403+
storage: StorageDefaults = field(default_factory=StorageDefaults)
398404
output: OutputDefaults = field(default_factory=OutputDefaults)
399405
outputs: None = None
400406
update_index_output: UpdateIndexOutputDefaults = field(

graphrag/config/enums.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,7 @@ def __repr__(self):
4242
return f'"{self.value}"'
4343

4444

45-
class InputType(str, Enum):
46-
"""The input type for the pipeline."""
47-
48-
file = "file"
49-
"""The file storage type."""
50-
blob = "blob"
51-
"""The blob storage type."""
52-
53-
def __repr__(self):
54-
"""Get a string representation."""
55-
return f'"{self.value}"'
56-
57-
58-
class OutputType(str, Enum):
45+
class StorageType(str, Enum):
5946
"""The output type for the pipeline."""
6047

6148
file = "file"
@@ -152,6 +139,10 @@ class IndexingMethod(str, Enum):
152139
"""Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model."""
153140
Fast = "fast"
154141
"""Fast indexing, using NLP for graph construction and language model for summarization."""
142+
StandardUpdate = "standard-update"
143+
"""Incremental update with standard indexing."""
144+
FastUpdate = "fast-update"
145+
"""Incremental update with fast indexing."""
155146

156147

157148
class NounPhraseExtractorType(str, Enum):

graphrag/config/init_content.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@
5858
### Input settings ###
5959
6060
input:
61-
type: {graphrag_config_defaults.input.type.value} # or blob
61+
storage:
62+
type: {graphrag_config_defaults.input.storage.type.value} # or blob
63+
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
6264
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
63-
base_dir: "{graphrag_config_defaults.input.base_dir}"
65+
6466
6567
chunks:
6668
size: {graphrag_config_defaults.chunks.size}

graphrag/config/models/graph_rag_config.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from graphrag.config.models.input_config import InputConfig
2727
from graphrag.config.models.language_model_config import LanguageModelConfig
2828
from graphrag.config.models.local_search_config import LocalSearchConfig
29-
from graphrag.config.models.output_config import OutputConfig
3029
from graphrag.config.models.prune_graph_config import PruneGraphConfig
3130
from graphrag.config.models.reporting_config import ReportingConfig
3231
from graphrag.config.models.snapshots_config import SnapshotsConfig
32+
from graphrag.config.models.storage_config import StorageConfig
3333
from graphrag.config.models.summarize_descriptions_config import (
3434
SummarizeDescriptionsConfig,
3535
)
@@ -102,29 +102,39 @@ def _validate_input_pattern(self) -> None:
102102
else:
103103
self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
104104

105+
def _validate_input_base_dir(self) -> None:
106+
"""Validate the input base directory."""
107+
if self.input.storage.type == defs.StorageType.file:
108+
if self.input.storage.base_dir.strip() == "":
109+
msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
110+
raise ValueError(msg)
111+
self.input.storage.base_dir = str(
112+
(Path(self.root_dir) / self.input.storage.base_dir).resolve()
113+
)
114+
105115
chunks: ChunkingConfig = Field(
106116
description="The chunking configuration to use.",
107117
default=ChunkingConfig(),
108118
)
109119
"""The chunking configuration to use."""
110120

111-
output: OutputConfig = Field(
121+
output: StorageConfig = Field(
112122
description="The output configuration.",
113-
default=OutputConfig(),
123+
default=StorageConfig(),
114124
)
115125
"""The output configuration."""
116126

117127
def _validate_output_base_dir(self) -> None:
118128
"""Validate the output base directory."""
119-
if self.output.type == defs.OutputType.file:
129+
if self.output.type == defs.StorageType.file:
120130
if self.output.base_dir.strip() == "":
121131
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
122132
raise ValueError(msg)
123133
self.output.base_dir = str(
124134
(Path(self.root_dir) / self.output.base_dir).resolve()
125135
)
126136

127-
outputs: dict[str, OutputConfig] | None = Field(
137+
outputs: dict[str, StorageConfig] | None = Field(
128138
description="A list of output configurations used for multi-index query.",
129139
default=graphrag_config_defaults.outputs,
130140
)
@@ -133,26 +143,25 @@ def _validate_multi_output_base_dirs(self) -> None:
133143
"""Validate the outputs dict base directories."""
134144
if self.outputs:
135145
for output in self.outputs.values():
136-
if output.type == defs.OutputType.file:
146+
if output.type == defs.StorageType.file:
137147
if output.base_dir.strip() == "":
138148
msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
139149
raise ValueError(msg)
140150
output.base_dir = str(
141151
(Path(self.root_dir) / output.base_dir).resolve()
142152
)
143153

144-
update_index_output: OutputConfig = Field(
154+
update_index_output: StorageConfig = Field(
145155
description="The output configuration for the updated index.",
146-
default=OutputConfig(
147-
type=graphrag_config_defaults.update_index_output.type,
156+
default=StorageConfig(
148157
base_dir=graphrag_config_defaults.update_index_output.base_dir,
149158
),
150159
)
151160
"""The output configuration for the updated index."""
152161

153162
def _validate_update_index_output_base_dir(self) -> None:
154163
"""Validate the update index output base directory."""
155-
if self.update_index_output.type == defs.OutputType.file:
164+
if self.update_index_output.type == defs.StorageType.file:
156165
if self.update_index_output.base_dir.strip() == "":
157166
msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
158167
raise ValueError(msg)
@@ -345,6 +354,7 @@ def _validate_model(self):
345354
self._validate_root_dir()
346355
self._validate_models()
347356
self._validate_input_pattern()
357+
self._validate_input_base_dir()
348358
self._validate_reporting_base_dir()
349359
self._validate_output_base_dir()
350360
self._validate_multi_output_base_dirs()

graphrag/config/models/input_config.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,23 @@
77

88
import graphrag.config.defaults as defs
99
from graphrag.config.defaults import graphrag_config_defaults
10-
from graphrag.config.enums import InputFileType, InputType
10+
from graphrag.config.enums import InputFileType
11+
from graphrag.config.models.storage_config import StorageConfig
1112

1213

1314
class InputConfig(BaseModel):
1415
"""The default configuration section for Input."""
1516

16-
type: InputType = Field(
17-
description="The input type to use.",
18-
default=graphrag_config_defaults.input.type,
17+
storage: StorageConfig = Field(
18+
description="The storage configuration to use for reading input documents.",
19+
default=StorageConfig(
20+
base_dir=graphrag_config_defaults.input.storage.base_dir,
21+
),
1922
)
2023
file_type: InputFileType = Field(
2124
description="The input file type to use.",
2225
default=graphrag_config_defaults.input.file_type,
2326
)
24-
base_dir: str = Field(
25-
description="The input base directory to use.",
26-
default=graphrag_config_defaults.input.base_dir,
27-
)
28-
connection_string: str | None = Field(
29-
description="The azure blob storage connection string to use.",
30-
default=graphrag_config_defaults.input.connection_string,
31-
)
32-
storage_account_blob_url: str | None = Field(
33-
description="The storage account blob url to use.",
34-
default=graphrag_config_defaults.input.storage_account_blob_url,
35-
)
36-
container_name: str | None = Field(
37-
description="The azure blob storage container name to use.",
38-
default=graphrag_config_defaults.input.container_name,
39-
)
4027
encoding: str = Field(
4128
description="The input file encoding to use.",
4229
default=defs.graphrag_config_defaults.input.encoding,

0 commit comments

Comments
 (0)