Skip to content

Commit 7ec9ef0

Browse files
Refactor callbacks (#1583)
* Unify Workflow and Verb callbacks interfaces * Semver * Fix storage class instantiation (#1582) --------- Co-authored-by: Josh Bradley <[email protected]>
1 parent cbb8f87 commit 7ec9ef0

File tree

70 files changed

+193
-367
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+193
-367
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Simplify callbacks model."
4+
}

docs/examples_notebooks/index_migration.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@
207207
"outputs": [],
208208
"source": [
209209
"from graphrag.cache.factory import create_cache\n",
210-
"from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n",
210+
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
211211
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
212212
"\n",
213213
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
@@ -219,7 +219,7 @@
219219
"config = workflow.config\n",
220220
"text_embed = config.get(\"text_embed\", {})\n",
221221
"embedded_fields = config.get(\"embedded_fields\", {})\n",
222-
"callbacks = NoopVerbCallbacks()\n",
222+
"callbacks = NoopWorkflowCallbacks()\n",
223223
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
224224
"\n",
225225
"await generate_text_embeddings(\n",

graphrag/api/prompt_tune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from pydantic import PositiveInt, validate_call
1515

16-
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
16+
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
1717
from graphrag.config.models.graph_rag_config import GraphRagConfig
1818
from graphrag.index.llm.load_llm import load_llm
1919
from graphrag.logger.print_progress import PrintProgressLogger
@@ -99,7 +99,7 @@ async def generate_indexing_prompts(
9999
"prompt_tuning",
100100
config.llm,
101101
cache=None,
102-
callbacks=NoopVerbCallbacks(),
102+
callbacks=NoopWorkflowCallbacks(),
103103
)
104104

105105
if not domain:

graphrag/callbacks/blob_workflow_callbacks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _write_log(self, log: dict[str, Any]):
8484
# update the blob's block count
8585
self._num_blocks += 1
8686

87-
def on_error(
87+
def error(
8888
self,
8989
message: str,
9090
cause: BaseException | None = None,
@@ -100,10 +100,10 @@ def on_error(
100100
"details": details,
101101
})
102102

103-
def on_warning(self, message: str, details: dict | None = None):
103+
def warning(self, message: str, details: dict | None = None):
104104
"""Report a warning."""
105105
self._write_log({"type": "warning", "data": message, "details": details})
106106

107-
def on_log(self, message: str, details: dict | None = None):
107+
def log(self, message: str, details: dict | None = None):
108108
"""Report a generic log message."""
109109
self._write_log({"type": "log", "data": message, "details": details})

graphrag/callbacks/console_workflow_callbacks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
1010
"""A logger that writes to a console."""
1111

12-
def on_error(
12+
def error(
1313
self,
1414
message: str,
1515
cause: BaseException | None = None,
@@ -19,11 +19,11 @@ def on_error(
1919
"""Handle when an error occurs."""
2020
print(message, str(cause), stack, details) # noqa T201
2121

22-
def on_warning(self, message: str, details: dict | None = None):
22+
def warning(self, message: str, details: dict | None = None):
2323
"""Handle when a warning occurs."""
2424
_print_warning(message)
2525

26-
def on_log(self, message: str, details: dict | None = None):
26+
def log(self, message: str, details: dict | None = None):
2727
"""Handle when a log message is produced."""
2828
print(message, details) # noqa T201
2929

graphrag/callbacks/delegating_verb_callbacks.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

graphrag/callbacks/file_workflow_callbacks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, directory: str):
2525
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
2626
)
2727

28-
def on_error(
28+
def error(
2929
self,
3030
message: str,
3131
cause: BaseException | None = None,
@@ -50,7 +50,7 @@ def on_error(
5050
message = f"{message} details={details}"
5151
log.info(message)
5252

53-
def on_warning(self, message: str, details: dict | None = None):
53+
def warning(self, message: str, details: dict | None = None):
5454
"""Handle when a warning occurs."""
5555
self._out_stream.write(
5656
json.dumps(
@@ -61,7 +61,7 @@ def on_warning(self, message: str, details: dict | None = None):
6161
)
6262
_print_warning(message)
6363

64-
def on_log(self, message: str, details: dict | None = None):
64+
def log(self, message: str, details: dict | None = None):
6565
"""Handle when a log message is produced."""
6666
self._out_stream.write(
6767
json.dumps(

graphrag/callbacks/noop_verb_callbacks.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

graphrag/callbacks/noop_workflow_callbacks.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,23 @@
33

44
"""A no-op implementation of WorkflowCallbacks."""
55

6-
from typing import Any
7-
86
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
97
from graphrag.logger.progress import Progress
108

119

1210
class NoopWorkflowCallbacks(WorkflowCallbacks):
1311
"""A no-op implementation of WorkflowCallbacks."""
1412

15-
def on_workflow_start(self, name: str, instance: object) -> None:
13+
def workflow_start(self, name: str, instance: object) -> None:
1614
"""Execute this callback when a workflow starts."""
1715

18-
def on_workflow_end(self, name: str, instance: object) -> None:
16+
def workflow_end(self, name: str, instance: object) -> None:
1917
"""Execute this callback when a workflow ends."""
2018

21-
def on_step_start(self, step_name: str) -> None:
22-
"""Execute this callback every time a step starts."""
23-
24-
def on_step_end(self, step_name: str, result: Any) -> None:
25-
"""Execute this callback every time a step ends."""
26-
27-
def on_step_progress(self, step_name: str, progress: Progress) -> None:
19+
def progress(self, progress: Progress) -> None:
2820
"""Handle when progress occurs."""
2921

30-
def on_error(
22+
def error(
3123
self,
3224
message: str,
3325
cause: BaseException | None = None,
@@ -36,11 +28,8 @@ def on_error(
3628
) -> None:
3729
"""Handle when an error occurs."""
3830

39-
def on_warning(self, message: str, details: dict | None = None) -> None:
31+
def warning(self, message: str, details: dict | None = None) -> None:
4032
"""Handle when a warning occurs."""
4133

42-
def on_log(self, message: str, details: dict | None = None) -> None:
34+
def log(self, message: str, details: dict | None = None) -> None:
4335
"""Handle when a log message occurs."""
44-
45-
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
46-
"""Handle when a measurement occurs."""

graphrag/callbacks/progress_workflow_callbacks.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
"""A workflow callback manager that emits updates."""
55

6-
from typing import Any
7-
86
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
97
from graphrag.logger.base import ProgressLogger
108
from graphrag.logger.progress import Progress
@@ -31,23 +29,14 @@ def _push(self, name: str) -> None:
3129
def _latest(self) -> ProgressLogger:
3230
return self._progress_stack[-1]
3331

34-
def on_workflow_start(self, name: str, instance: object) -> None:
32+
def workflow_start(self, name: str, instance: object) -> None:
3533
"""Execute this callback when a workflow starts."""
3634
self._push(name)
3735

38-
def on_workflow_end(self, name: str, instance: object) -> None:
36+
def workflow_end(self, name: str, instance: object) -> None:
3937
"""Execute this callback when a workflow ends."""
4038
self._pop()
4139

42-
def on_step_start(self, step_name: str) -> None:
43-
"""Execute this callback every time a step starts."""
44-
self._push(f"Step {step_name}")
45-
self._latest(Progress(percent=0))
46-
47-
def on_step_end(self, step_name: str, result: Any) -> None:
48-
"""Execute this callback every time a step ends."""
49-
self._pop()
50-
51-
def on_step_progress(self, step_name: str, progress: Progress) -> None:
40+
def progress(self, progress: Progress) -> None:
5241
"""Handle when progress occurs."""
5342
self._latest(progress)

0 commit comments

Comments
 (0)