Skip to content

Commit 5713205

Browse files
AlonsoGuevara, snehitgajjar (Snehit Gajjar), and Copilot
authored
Feat/additional context (#2021)
* Users/snehitgajjar/add optional api param for pipeline state (#2019) * Add support for additional context for PipelineState * Clean up * Clean up * Clean up * Nit --------- Co-authored-by: Snehit Gajjar <[email protected]> * Semver * Update graphrag/api/index.py Co-authored-by: Copilot <[email protected]> * Remove additional_context from serialization --------- Co-authored-by: Snehit Gajjar <[email protected]> Co-authored-by: Snehit Gajjar <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 1da1380 commit 5713205

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+ {
2+     "type": "minor",
3+     "description": "Add additional context variable to build index signature for custom parameter bag"
4+ }

graphrag/api/index.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import logging
12+
from typing import Any
1213

1314
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
1415
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -30,6 +31,7 @@ async def build_index(
3031
is_update_run: bool = False,
3132
memory_profile: bool = False,
3233
callbacks: list[WorkflowCallbacks] | None = None,
34+
additional_context: dict[str, Any] | None = None,
3335
) -> list[PipelineRunResult]:
3436
"""Run the pipeline with the given configuration.
3537
@@ -43,6 +45,8 @@ async def build_index(
4345
Whether to enable memory profiling.
4446
callbacks : list[WorkflowCallbacks] | None default=None
4547
A list of callbacks to register.
48+
additional_context : dict[str, Any] | None default=None
49+
Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
4650
4751
Returns
4852
-------
@@ -73,6 +77,7 @@ async def build_index(
7377
config,
7478
callbacks=workflow_callbacks,
7579
is_update_run=is_update_run,
80+
additional_context=additional_context,
7681
):
7782
outputs.append(output)
7883
if output.errors and len(output.errors) > 0:

graphrag/index/run/run_pipeline.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
from collections.abc import AsyncIterable
1111
from dataclasses import asdict
12+
from typing import Any
1213

1314
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1415
from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -28,6 +29,7 @@ async def run_pipeline(
2829
config: GraphRagConfig,
2930
callbacks: WorkflowCallbacks,
3031
is_update_run: bool = False,
32+
additional_context: dict[str, Any] | None = None,
3133
) -> AsyncIterable[PipelineRunResult]:
3234
"""Run all workflows using a simplified pipeline."""
3335
root_dir = config.root_dir
@@ -40,6 +42,9 @@ async def run_pipeline(
4042
state_json = await output_storage.get("context.json")
4143
state = json.loads(state_json) if state_json else {}
4244

45+
if additional_context:
46+
state.setdefault("additional_context", {}).update(additional_context)
47+
4348
if is_update_run:
4449
logger.info("Running incremental indexing.")
4550

@@ -126,9 +131,17 @@ async def _dump_json(context: PipelineRunContext) -> None:
126131
await context.output_storage.set(
127132
"stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
128133
)
129-
await context.output_storage.set(
130-
"context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
131-
)
134+
# Dump context state, excluding additional_context
135+
temp_context = context.state.pop(
136+
"additional_context", None
137+
) # Remove reference only, as object size is uncertain
138+
try:
139+
state_blob = json.dumps(context.state, indent=4, ensure_ascii=False)
140+
finally:
141+
if temp_context:
142+
context.state["additional_context"] = temp_context
143+
144+
await context.output_storage.set("context.json", state_blob)
132145

133146

134147
async def _copy_previous_output(

0 commit comments

Comments (0)