Skip to content

Commit 91c0c6f

Browse files
committed
remove config from flows to fix llm arg mapping
1 parent 19537c3 commit 91c0c6f

30 files changed

+27
-167
lines changed

graphrag/config/models/graph_rag_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
212212
if model_id not in self.models:
213213
err_msg = f"Model ID {model_id} not found in configuration."
214214
raise ValueError(err_msg)
215-
return self.models[model_id]
215+
# TODO: shouldn't self.models be validated already?
216+
return LanguageModelConfig.model_construct(**dict(self.models[model_id])) # type: ignore
216217

217218
@model_validator(mode="after")
218219
def _validate_model(self):

graphrag/index/flows/create_final_community_reports.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from graphrag.cache.pipeline_cache import PipelineCache
1111
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1212
from graphrag.config.enums import AsyncType
13-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1413
from graphrag.index.operations.summarize_communities import (
1514
prepare_community_reports,
1615
restore_community_hierarchy,
@@ -47,7 +46,6 @@ async def create_final_community_reports(
4746
callbacks: WorkflowCallbacks,
4847
cache: PipelineCache,
4948
summarization_strategy: dict,
50-
config: GraphRagConfig,
5149
async_mode: AsyncType = AsyncType.AsyncIO,
5250
num_threads: int = 4,
5351
) -> pd.DataFrame:
@@ -80,7 +78,6 @@ async def create_final_community_reports(
8078
strategy=summarization_strategy,
8179
async_mode=async_mode,
8280
num_threads=num_threads,
83-
config=config,
8481
)
8582

8683
community_reports["community"] = community_reports["community"].astype(int)

graphrag/index/flows/create_final_covariates.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from graphrag.cache.pipeline_cache import PipelineCache
1212
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1313
from graphrag.config.enums import AsyncType
14-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1514
from graphrag.index.operations.extract_covariates.extract_covariates import (
1615
extract_covariates,
1716
)
@@ -23,7 +22,6 @@ async def create_final_covariates(
2322
cache: PipelineCache,
2423
covariate_type: str,
2524
extraction_strategy: dict[str, Any] | None,
26-
config: GraphRagConfig,
2725
async_mode: AsyncType = AsyncType.AsyncIO,
2826
entity_types: list[str] | None = None,
2927
num_threads: int = 4,
@@ -42,7 +40,6 @@ async def create_final_covariates(
4240
async_mode=async_mode,
4341
entity_types=entity_types,
4442
num_threads=num_threads,
45-
config=config,
4643
)
4744
text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global
4845
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))

graphrag/index/flows/extract_graph.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from graphrag.cache.pipeline_cache import PipelineCache
1212
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1313
from graphrag.config.enums import AsyncType
14-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1514
from graphrag.index.operations.extract_entities import extract_entities
1615
from graphrag.index.operations.summarize_descriptions import (
1716
summarize_descriptions,
@@ -22,7 +21,6 @@ async def extract_graph(
2221
text_units: pd.DataFrame,
2322
callbacks: WorkflowCallbacks,
2423
cache: PipelineCache,
25-
config: GraphRagConfig,
2624
extraction_strategy: dict[str, Any] | None = None,
2725
extraction_num_threads: int = 4,
2826
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
@@ -42,7 +40,6 @@ async def extract_graph(
4240
async_mode=extraction_async_mode,
4341
entity_types=entity_types,
4442
num_threads=extraction_num_threads,
45-
config=config,
4643
)
4744

4845
if not _validate_data(entities):
@@ -64,7 +61,6 @@ async def extract_graph(
6461
cache=cache,
6562
strategy=summarization_strategy,
6663
num_threads=summarization_num_threads,
67-
config=config,
6864
)
6965

7066
base_relationship_edges = _prep_edges(relationships, relationship_summaries)

graphrag/index/flows/generate_text_embeddings.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from graphrag.cache.pipeline_cache import PipelineCache
1111
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
12-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1312
from graphrag.index.config.embeddings import (
1413
community_full_content_embedding,
1514
community_summary_embedding,
@@ -38,7 +37,6 @@ async def generate_text_embeddings(
3837
storage: PipelineStorage,
3938
text_embed_config: dict,
4039
embedded_fields: set[str],
41-
config: GraphRagConfig,
4240
snapshot_embeddings_enabled: bool = False,
4341
) -> None:
4442
"""All the steps to generate all embeddings."""
@@ -104,7 +102,6 @@ async def generate_text_embeddings(
104102
storage=storage,
105103
text_embed_config=text_embed_config,
106104
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
107-
config=config,
108105
**embedding_param_map[field],
109106
)
110107

@@ -118,7 +115,6 @@ async def _run_and_snapshot_embeddings(
118115
storage: PipelineStorage,
119116
text_embed_config: dict,
120117
snapshot_embeddings_enabled: bool,
121-
config: GraphRagConfig,
122118
) -> None:
123119
"""All the steps to generate single embedding."""
124120
if text_embed_config:
@@ -129,7 +125,6 @@ async def _run_and_snapshot_embeddings(
129125
embed_column=embed_column,
130126
embedding_name=name,
131127
strategy=text_embed_config["strategy"],
132-
config=config,
133128
)
134129

135130
if snapshot_embeddings_enabled is True:

graphrag/index/operations/embed_text/embed_text.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from graphrag.cache.pipeline_cache import PipelineCache
1414
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
15-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1615
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
1716
from graphrag.utils.embeddings import create_collection_name
1817
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
@@ -43,7 +42,6 @@ async def embed_text(
4342
embed_column: str,
4443
strategy: dict,
4544
embedding_name: str,
46-
config: GraphRagConfig,
4745
id_column: str = "id",
4846
title_column: str | None = None,
4947
):
@@ -98,7 +96,6 @@ async def embed_text(
9896
vector_store_config=vector_store_workflow_config,
9997
id_column=id_column,
10098
title_column=title_column,
101-
config=config,
10299
)
103100

104101
return await _text_embed_in_memory(
@@ -107,7 +104,6 @@ async def embed_text(
107104
cache=cache,
108105
embed_column=embed_column,
109106
strategy=strategy,
110-
config=config,
111107
)
112108

113109

@@ -117,14 +113,13 @@ async def _text_embed_in_memory(
117113
cache: PipelineCache,
118114
embed_column: str,
119115
strategy: dict,
120-
config: GraphRagConfig,
121116
):
122117
strategy_type = strategy["type"]
123118
strategy_exec = load_strategy(strategy_type)
124119
strategy_args = {**strategy}
125120

126121
texts: list[str] = input[embed_column].to_numpy().tolist()
127-
result = await strategy_exec(texts, callbacks, cache, strategy_args, config)
122+
result = await strategy_exec(texts, callbacks, cache, strategy_args)
128123

129124
return result.embeddings
130125

@@ -137,7 +132,6 @@ async def _text_embed_with_vector_store(
137132
strategy: dict[str, Any],
138133
vector_store: BaseVectorStore,
139134
vector_store_config: dict,
140-
config: GraphRagConfig,
141135
id_column: str = "id",
142136
title_column: str | None = None,
143137
):
@@ -182,7 +176,7 @@ async def _text_embed_with_vector_store(
182176
texts: list[str] = batch[embed_column].to_numpy().tolist()
183177
titles: list[str] = batch[title].to_numpy().tolist()
184178
ids: list[str] = batch[id_column].to_numpy().tolist()
185-
result = await strategy_exec(texts, callbacks, cache, strategy_args, config)
179+
result = await strategy_exec(texts, callbacks, cache, strategy_args)
186180
if result.embeddings:
187181
embeddings = [
188182
embedding for embedding in result.embeddings if embedding is not None

graphrag/index/operations/embed_text/strategies/mock.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from graphrag.cache.pipeline_cache import PipelineCache
1111
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
12-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1312
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
1413
from graphrag.logger.progress import ProgressTicker, progress_ticker
1514

@@ -19,7 +18,6 @@ async def run( # noqa RUF029 async is required for interface
1918
callbacks: WorkflowCallbacks,
2019
cache: PipelineCache,
2120
_args: dict[str, Any],
22-
_config: GraphRagConfig,
2321
) -> TextEmbeddingResult:
2422
"""Run the mock text embedding strategy."""
2523
input = input if isinstance(input, Iterable) else [input]

graphrag/index/operations/embed_text/strategies/openai.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from graphrag.cache.pipeline_cache import PipelineCache
1414
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
15-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1615
from graphrag.config.models.language_model_config import LanguageModelConfig
1716
from graphrag.index.llm.load_llm import load_llm_embeddings
1817
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
@@ -28,19 +27,16 @@ async def run(
2827
callbacks: WorkflowCallbacks,
2928
cache: PipelineCache,
3029
args: dict[str, Any],
31-
config: GraphRagConfig,
3230
) -> TextEmbeddingResult:
3331
"""Run the OpenAI text embedding strategy."""
3432
if is_null(input):
3533
return TextEmbeddingResult(embeddings=None)
3634

3735
batch_size = args.get("batch_size", 16)
3836
batch_max_tokens = args.get("batch_max_tokens", 8191)
39-
embeddings_llm_settings = config.get_language_model_config(
40-
config.embeddings.model_id
41-
)
42-
splitter = _get_splitter(embeddings_llm_settings, batch_max_tokens)
43-
llm = _get_llm(embeddings_llm_settings, callbacks, cache)
37+
llm_config = args["llm"]
38+
splitter = _get_splitter(llm_config, batch_max_tokens)
39+
llm = _get_llm(llm_config, callbacks, cache)
4440
semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4))
4541

4642
# Break up the input texts. The sizes here indicate how many snippets are in each input text

graphrag/index/operations/embed_text/strategies/typing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from graphrag.cache.pipeline_cache import PipelineCache
1010
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
11-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1211

1312

1413
@dataclass
@@ -24,7 +23,6 @@ class TextEmbeddingResult:
2423
WorkflowCallbacks,
2524
PipelineCache,
2625
dict,
27-
GraphRagConfig,
2826
],
2927
Awaitable[TextEmbeddingResult],
3028
]

graphrag/index/operations/extract_covariates/extract_covariates.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from graphrag.cache.pipeline_cache import PipelineCache
1515
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1616
from graphrag.config.enums import AsyncType
17-
from graphrag.config.models.graph_rag_config import GraphRagConfig
17+
from graphrag.config.models.language_model_config import LanguageModelConfig
1818
from graphrag.index.llm.load_llm import load_llm
1919
from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
2020
from graphrag.index.operations.extract_covariates.typing import (
@@ -35,7 +35,6 @@ async def extract_covariates(
3535
cache: PipelineCache,
3636
column: str,
3737
covariate_type: str,
38-
config: GraphRagConfig,
3938
strategy: dict[str, Any] | None,
4039
async_mode: AsyncType = AsyncType.AsyncIO,
4140
entity_types: list[str] | None = None,
@@ -60,7 +59,6 @@ async def run_strategy(row):
6059
callbacks=callbacks,
6160
cache=cache,
6261
strategy_config=strategy_config,
63-
config=config,
6462
)
6563
return [
6664
create_row_from_claim_data(row, item, covariate_type)
@@ -89,15 +87,12 @@ async def run_claim_extraction(
8987
callbacks: WorkflowCallbacks,
9088
cache: PipelineCache,
9189
strategy_config: dict[str, Any],
92-
config: GraphRagConfig,
9390
) -> CovariateExtractionResult:
9491
"""Run the Claim extraction chain."""
95-
claim_extraction_llm_settings = config.get_language_model_config(
96-
config.claim_extraction.model_id
97-
)
92+
llm_config = LanguageModelConfig.model_construct(**strategy_config["llm"])
9893
llm = load_llm(
9994
"claim_extraction",
100-
claim_extraction_llm_settings,
95+
llm_config,
10196
callbacks=callbacks,
10297
cache=cache,
10398
)

0 commit comments

Comments
 (0)