Skip to content

Commit c644338

Browse files
dworthen, natoverse, AlonsoGuevara
authored
Refactor config (#1593)
* Refactor config - Add new ModelConfig to represent LLM settings - Combines LLMParameters, ParallelizationParameters, encoding_model, and async_mode - Add top level models config that is a list of available LLM ModelConfigs - Remove LLMConfig inheritance and delete LLMConfig - Replace the inheritance with a model_id reference to the ModelConfig listed in the top level models config - Remove all fallbacks and hydration logic from create_graphrag_config - This removes the automatic env variable overrides - Support env variables within config files using Templating - This requires "$" to be escaped with extra "$" so ".*\\.txt$" becomes ".*\\.txt$$" - Update init content to initialize new config file with the ModelConfig structure * Use dict of ModelConfig instead of list * Add model validations and unit tests * Fix ruff checks * Add semversioner change * Fix unit tests * validate root_dir in pydantic model * Rename ModelConfig to LanguageModelConfig * Rename ModelConfigMissingError to LanguageModelConfigMissingError * Add validating for unexpected API keys * Allow skipping pydantic validation for testing/mocking purposes. * Add default lm configs to verb tests * smoke test * remove config from flows to fix llm arg mapping * Fix embedding llm arg mapping * Remove timestamp from smoke test outputs * Remove unused "subworkflows" smoke test properties * Add models to smoke test configs * Update smoke test output path * Send logs to logs folder * Fix output path * Fix csv test file pattern * Update placeholder * Format * Instantiate default model configs * Fix unit tests for config defaults * Fix migration notebook * Remove create_pipeline_config * Remove several unused config models * Remove indexing embedding and input configs * Move embeddings function to config * Remove skip_workflows * Remove skip embeddings in favor of explicit naming * fix unit test spelling mistake * self.models[model_id] is already a language model. Remove redundant casting. 
* update validation errors to instruct users to rerun graphrag init * instantiate LanguageModelConfigs with validation * skip validation in unit tests * update verb tests to use default model settings instead of skipping validation * test using llm settings * cleanup verb tests * remove unsafe default model config * remove the ability to skip pydantic validation * remove None union types when default values are set * move vector_store from embeddings to top level of config and delete resolve_paths * update vector store settings * fix vector store and smoke tests * fix serializing vector_store settings * fix vector_store usage * fix vector_store type * support cli overrides for loading graphrag config * rename storage to output * Add --force flag to init * Remove run_id and resume, fix Drift config assignment * Ruff --------- Co-authored-by: Nathan Evans <[email protected]> Co-authored-by: Alonso Guevara <[email protected]>
1 parent 47adfe1 commit c644338

File tree

104 files changed

+2253
-3610
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+2253
-3610
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "Remove config inheritance, hydration, and automatic env var overlays."
4+
}

docs/examples_notebooks/index_migration.ipynb

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 66,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -25,7 +25,7 @@
2525
},
2626
{
2727
"cell_type": "code",
28-
"execution_count": 2,
28+
"execution_count": 67,
2929
"metadata": {},
3030
"outputs": [],
3131
"source": [
@@ -37,27 +37,28 @@
3737
},
3838
{
3939
"cell_type": "code",
40-
"execution_count": 3,
40+
"execution_count": null,
4141
"metadata": {},
4242
"outputs": [],
4343
"source": [
4444
"from pathlib import Path\n",
4545
"\n",
4646
"from graphrag.config.load_config import load_config\n",
4747
"from graphrag.config.resolve_path import resolve_paths\n",
48-
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
49-
"from graphrag.storage.factory import create_storage\n",
48+
"from graphrag.storage.factory import StorageFactory\n",
5049
"\n",
5150
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
5251
"config = load_config(Path(PROJECT_DIRECTORY))\n",
5352
"resolve_paths(config)\n",
54-
"pipeline_config = create_pipeline_config(config)\n",
55-
"storage = create_storage(pipeline_config.storage)"
53+
"storage_config = config.storage.model_dump() # type: ignore\n",
54+
"storage = StorageFactory().create_storage(\n",
55+
" storage_type=storage_config[\"type\"], kwargs=storage_config\n",
56+
")"
5657
]
5758
},
5859
{
5960
"cell_type": "code",
60-
"execution_count": 4,
61+
"execution_count": 69,
6162
"metadata": {},
6263
"outputs": [],
6364
"source": [
@@ -68,7 +69,7 @@
6869
},
6970
{
7071
"cell_type": "code",
71-
"execution_count": 63,
72+
"execution_count": 70,
7273
"metadata": {},
7374
"outputs": [],
7475
"source": [
@@ -97,7 +98,7 @@
9798
},
9899
{
99100
"cell_type": "code",
100-
"execution_count": 64,
101+
"execution_count": 71,
101102
"metadata": {},
102103
"outputs": [],
103104
"source": [
@@ -108,22 +109,16 @@
108109
"# First we'll go through any parquet files that had model changes and update them\n",
109110
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
110111
"\n",
111-
"final_documents = await load_table_from_storage(\n",
112-
" \"create_final_documents.parquet\", storage\n",
113-
")\n",
114-
"final_text_units = await load_table_from_storage(\n",
115-
" \"create_final_text_units.parquet\", storage\n",
116-
")\n",
117-
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
118-
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
112+
"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
113+
"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
114+
"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
115+
"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
119116
"final_relationships = await load_table_from_storage(\n",
120-
" \"create_final_relationships.parquet\", storage\n",
121-
")\n",
122-
"final_communities = await load_table_from_storage(\n",
123-
" \"create_final_communities.parquet\", storage\n",
117+
" \"create_final_relationships\", storage\n",
124118
")\n",
119+
"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
125120
"final_community_reports = await load_table_from_storage(\n",
126-
" \"create_final_community_reports.parquet\", storage\n",
121+
" \"create_final_community_reports\", storage\n",
127122
")\n",
128123
"\n",
129124
"\n",
@@ -183,44 +178,41 @@
183178
" parent_df, on=\"community\", how=\"left\"\n",
184179
" )\n",
185180
"\n",
186-
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
181+
"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
182+
"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
183+
"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
184+
"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
185+
"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
186+
"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
187187
"await write_table_to_storage(\n",
188-
" final_text_units, \"create_final_text_units.parquet\", storage\n",
189-
")\n",
190-
"await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
191-
"await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
192-
"await write_table_to_storage(\n",
193-
" final_relationships, \"create_final_relationships.parquet\", storage\n",
194-
")\n",
195-
"await write_table_to_storage(\n",
196-
" final_communities, \"create_final_communities.parquet\", storage\n",
197-
")\n",
198-
"await write_table_to_storage(\n",
199-
" final_community_reports, \"create_final_community_reports.parquet\", storage\n",
188+
" final_community_reports, \"create_final_community_reports\", storage\n",
200189
")"
201190
]
202191
},
203192
{
204193
"cell_type": "code",
205-
"execution_count": 7,
194+
"execution_count": null,
206195
"metadata": {},
207196
"outputs": [],
208197
"source": [
209-
"from graphrag.cache.factory import create_cache\n",
198+
"from graphrag.cache.factory import CacheFactory\n",
210199
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
200+
"from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings\n",
211201
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
212202
"\n",
213203
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
214204
"# We'll construct the context and run this function flow directly to avoid everything else\n",
215205
"\n",
216-
"workflow = next(\n",
217-
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
218-
")\n",
219-
"config = workflow.config\n",
220-
"text_embed = config.get(\"text_embed\", {})\n",
221-
"embedded_fields = config.get(\"embedded_fields\", {})\n",
206+
"\n",
207+
"embedded_fields = get_embedded_fields(config)\n",
208+
"text_embed = get_embedding_settings(config)\n",
222209
"callbacks = NoopWorkflowCallbacks()\n",
223-
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
210+
"cache_config = config.cache.model_dump() # type: ignore\n",
211+
"cache = CacheFactory().create_cache(\n",
212+
" cache_type=cache_config[\"type\"], # type: ignore\n",
213+
" root_dir=PROJECT_DIRECTORY,\n",
214+
" kwargs=cache_config,\n",
215+
")\n",
224216
"\n",
225217
"await generate_text_embeddings(\n",
226218
" final_documents=None,\n",

graphrag/api/index.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212

1313
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
14-
from graphrag.callbacks.factory import create_pipeline_reporter
14+
from graphrag.callbacks.reporting import create_pipeline_reporter
1515
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1616
from graphrag.config.enums import CacheType
1717
from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -24,8 +24,6 @@
2424

2525
async def build_index(
2626
config: GraphRagConfig,
27-
run_id: str = "",
28-
is_resume_run: bool = False,
2927
memory_profile: bool = False,
3028
callbacks: list[WorkflowCallbacks] | None = None,
3129
progress_logger: ProgressLogger | None = None,
@@ -36,10 +34,6 @@ async def build_index(
3634
----------
3735
config : GraphRagConfig
3836
The configuration.
39-
run_id : str
40-
The run id. Creates a output directory with this name.
41-
is_resume_run : bool default=False
42-
Whether to resume a previous index run.
4337
memory_profile : bool
4438
Whether to enable memory profiling.
4539
callbacks : list[WorkflowCallbacks] | None default=None
@@ -52,11 +46,7 @@ async def build_index(
5246
list[PipelineRunResult]
5347
The list of pipeline run results
5448
"""
55-
is_update_run = bool(config.update_index_storage)
56-
57-
if is_resume_run and is_update_run:
58-
msg = "Cannot resume and update a run at the same time."
59-
raise ValueError(msg)
49+
is_update_run = bool(config.update_index_output)
6050

6151
pipeline_cache = (
6252
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
@@ -78,7 +68,6 @@ async def build_index(
7868
cache=pipeline_cache,
7969
callbacks=callbacks,
8070
logger=progress_logger,
81-
run_id=run_id,
8271
is_update_run=is_update_run,
8372
):
8473
outputs.append(output)

graphrag/api/prompt_tune.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from graphrag.config.models.graph_rag_config import GraphRagConfig
1818
from graphrag.index.llm.load_llm import load_llm
1919
from graphrag.logger.print_progress import PrintProgressLogger
20-
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
20+
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
2121
from graphrag.prompt_tune.generator.community_report_rating import (
2222
generate_community_report_rating,
2323
)
@@ -95,9 +95,11 @@ async def generate_indexing_prompts(
9595
)
9696

9797
# Create LLM from config
98+
# TODO: Expose way to specify Prompt Tuning model ID through config
99+
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
98100
llm = load_llm(
99101
"prompt_tuning",
100-
config.llm,
102+
default_llm_settings,
101103
cache=None,
102104
callbacks=NoopWorkflowCallbacks(),
103105
)
@@ -120,14 +122,17 @@ async def generate_indexing_prompts(
120122
)
121123

122124
entity_types = None
125+
entity_extraction_llm_settings = config.get_language_model_config(
126+
config.entity_extraction.model_id
127+
)
123128
if discover_entity_types:
124129
logger.info("Generating entity types...")
125130
entity_types = await generate_entity_types(
126131
llm,
127132
domain=domain,
128133
persona=persona,
129134
docs=doc_list,
130-
json_mode=config.llm.model_supports_json or False,
135+
json_mode=entity_extraction_llm_settings.model_supports_json or False,
131136
)
132137

133138
logger.info("Generating entity relationship examples...")
@@ -147,7 +152,7 @@ async def generate_indexing_prompts(
147152
examples=examples,
148153
language=language,
149154
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
150-
encoding_model=config.encoding_model,
155+
encoding_model=entity_extraction_llm_settings.encoding_model,
151156
max_token_count=max_tokens,
152157
min_examples_required=min_examples_required,
153158
)

graphrag/api/query.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
import pandas as pd
2525
from pydantic import validate_call
2626

27-
from graphrag.config.models.graph_rag_config import GraphRagConfig
28-
from graphrag.index.config.embeddings import (
27+
from graphrag.config.embeddings import (
2928
community_full_content_embedding,
29+
create_collection_name,
3030
entity_description_embedding,
3131
text_unit_text_embedding,
3232
)
33+
from graphrag.config.models.graph_rag_config import GraphRagConfig
3334
from graphrag.logger.print_progress import PrintProgressLogger
3435
from graphrag.query.factory import (
3536
get_basic_search_engine,
@@ -47,7 +48,6 @@
4748
read_indexer_text_units,
4849
)
4950
from graphrag.utils.cli import redact
50-
from graphrag.utils.embeddings import create_collection_name
5151
from graphrag.vector_stores.base import BaseVectorStore
5252
from graphrag.vector_stores.factory import VectorStoreFactory
5353

@@ -244,7 +244,7 @@ async def local_search(
244244
------
245245
TODO: Document any exceptions to expect.
246246
"""
247-
vector_store_args = config.embeddings.vector_store
247+
vector_store_args = config.vector_store.model_dump()
248248
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
249249

250250
description_embedding_store = _get_embedding_store(
@@ -310,7 +310,7 @@ async def local_search_streaming(
310310
------
311311
TODO: Document any exceptions to expect.
312312
"""
313-
vector_store_args = config.embeddings.vector_store
313+
vector_store_args = config.vector_store.model_dump()
314314
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
315315

316316
description_embedding_store = _get_embedding_store(
@@ -381,7 +381,7 @@ async def drift_search_streaming(
381381
------
382382
TODO: Document any exceptions to expect.
383383
"""
384-
vector_store_args = config.embeddings.vector_store
384+
vector_store_args = config.vector_store.model_dump()
385385
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
386386

387387
description_embedding_store = _get_embedding_store(
@@ -465,7 +465,7 @@ async def drift_search(
465465
------
466466
TODO: Document any exceptions to expect.
467467
"""
468-
vector_store_args = config.embeddings.vector_store
468+
vector_store_args = config.vector_store.model_dump()
469469
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
470470

471471
description_embedding_store = _get_embedding_store(
@@ -531,7 +531,7 @@ async def basic_search(
531531
------
532532
TODO: Document any exceptions to expect.
533533
"""
534-
vector_store_args = config.embeddings.vector_store
534+
vector_store_args = config.vector_store.model_dump()
535535
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
536536

537537
description_embedding_store = _get_embedding_store(
@@ -576,7 +576,7 @@ async def basic_search_streaming(
576576
------
577577
TODO: Document any exceptions to expect.
578578
"""
579-
vector_store_args = config.embeddings.vector_store
579+
vector_store_args = config.vector_store.model_dump()
580580
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
581581

582582
description_embedding_store = _get_embedding_store(

graphrag/callbacks/factory.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)