Skip to content

Commit ff28c7a

Browse files
committed
fix state management
1 parent 153f6ac commit ff28c7a

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

graphrag/index/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
from .prune_graph import (
4343
run_workflow as run_prune_graph,
4444
)
45+
from .update_clean_state import (
46+
run_workflow as run_update_clean_state,
47+
)
4548
from .update_communities import (
4649
run_workflow as run_update_communities,
4750
)
@@ -85,4 +88,5 @@
8588
"update_communities": run_update_communities,
8689
"update_covariates": run_update_covariates,
8790
"update_text_units": run_update_text_units,
91+
"update_clean_state": run_update_clean_state,
8892
})

graphrag/index/workflows/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _get_workflows_list(
5353
"update_communities",
5454
"update_community_reports",
5555
"update_text_embeddings",
56+
"update_clean_state",
5657
]
5758
if config.workflows:
5859
return config.workflows + (update_workflows if is_update_run else [])
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""A module containing run_workflow method definition."""
5+
6+
import logging
7+
8+
from graphrag.config.models.graph_rag_config import GraphRagConfig
9+
from graphrag.index.typing.context import PipelineRunContext
10+
from graphrag.index.typing.workflow import WorkflowFunctionOutput
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
async def run_workflow(
16+
config: GraphRagConfig,
17+
context: PipelineRunContext,
18+
) -> WorkflowFunctionOutput:
19+
"""Clean the state after the update."""
20+
logger.info("Cleaning State")
21+
keys_to_delete = [key_name for key_name in context.state if key_name.startswith("incremental_update_")]
22+
23+
for key_name in keys_to_delete:
24+
del context.state[key_name]
25+
26+
return WorkflowFunctionOutput(result=None)

0 commit comments

Comments
 (0)