99import time
1010from collections .abc import AsyncIterable
1111from dataclasses import asdict
12+ from typing import Any
1213
1314from graphrag .callbacks .workflow_callbacks import WorkflowCallbacks
1415from graphrag .config .models .graph_rag_config import GraphRagConfig
@@ -28,7 +29,7 @@ async def run_pipeline(
2829 config : GraphRagConfig ,
2930 callbacks : WorkflowCallbacks ,
3031 is_update_run : bool = False ,
31- additional_context : dict | None = None ,
32+ additional_context : dict [ str , Any ] | None = None ,
3233) -> AsyncIterable [PipelineRunResult ]:
3334 """Run all workflows using a simplified pipeline."""
3435 root_dir = config .root_dir
@@ -41,8 +42,13 @@ async def run_pipeline(
4142 state_json = await output_storage .get ("context.json" )
4243 state = json .loads (state_json ) if state_json else {}
4344
44- for key , value in (additional_context or {}).items ():
45- state ["additional_context" ][key ] = value
45+ if additional_context is not None :
46+ if "additional_context" not in state :
47+ state ["additional_context" ] = {}
48+
49+ # add additional context to the state
50+ for key , value in (additional_context or {}).items ():
51+ state ["additional_context" ][key ] = value
4652
4753 if is_update_run :
4854 logger .info ("Running incremental indexing." )
0 commit comments