Skip to content

Commit 6724b19

Browse files
committed
Fixing snapshots using progress_reporter_report.
1 parent 0bdebc6 commit 6724b19

File tree

4 files changed, with 14 additions and 6 deletions.

nodestream/pipeline/pipeline.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from ..schema import ExpandsSchema, ExpandsSchemaFromChildren
77
from .channel import StepInput, StepOutput, channel
88
from .object_storage import ObjectStore
9-
from .progress_reporter import PipelineProgressReporter
9+
from .progress_reporter import PipelineProgressReporter, no_op
1010
from .step import Step, StepContext
1111

1212

@@ -89,11 +89,12 @@ class PipelineOutput:
8989
pipeline and report the progress of the pipeline.
9090
"""
9191

92-
__slots__ = ("input", "reporter")
92+
__slots__ = ("input", "reporter", "observe_results")
9393

9494
def __init__(self, input: StepInput, reporter: PipelineProgressReporter):
9595
self.input = input
9696
self.reporter = reporter
97+
self.observe_results = reporter.observability_callback is not no_op
9798

9899
def call_handling_errors(self, f, *args):
99100
try:
@@ -114,9 +115,11 @@ async def run(self):
114115
self.call_handling_errors(self.reporter.on_start_callback)
115116

116117
index = 0
117-
while (await self.input.get()) is not None:
118+
while (record := await self.input.get()) is not None:
118119
metrics.increment(NodestreamMetricRegistry.RECORDS)
119120
self.call_handling_errors(self.reporter.report, index, metrics)
121+
if self.observe_results:
122+
self.call_handling_errors(self.reporter.observe, record)
120123
index += 1
121124

122125
self.call_handling_errors(self.reporter.on_finish_callback, metrics)

nodestream/pipeline/progress_reporter.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import os
22
from dataclasses import dataclass, field
33
from logging import Logger, getLogger
4-
from typing import Callable
4+
from typing import Any, Callable
55

66
from psutil import Process
77

@@ -38,6 +38,7 @@ class PipelineProgressReporter:
3838
on_finish_callback: Callable[[Metrics], None] = field(default=no_op)
3939
on_fatal_error_callback: Callable[[Exception], None] = field(default=no_op)
4040
encountered_fatal_error: bool = field(default=False)
41+
observability_callback: Callable[[Any], None] = field(default=no_op)
4142

4243
def on_fatal_error(self, exception: Exception):
4344
self.encountered_fatal_error = True
@@ -69,3 +70,6 @@ def for_testing(cls, results_list: list) -> "PipelineProgressReporter":
6970
def report(self, index, metrics: Metrics):
7071
if index % self.reporting_frequency == 0:
7172
self.callback(index, metrics)
73+
74+
def observe(self, record: Any):
75+
self.observability_callback(record)

tests/integration/test_pipeline_and_data_interpretation.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ async def _drive_definition_to_completion(definition, **init_kwargs):
3333
init_args = PipelineInitializationArguments(**init_kwargs)
3434
pipeline = definition.initialize(init_args)
3535
reporter = PipelineProgressReporter(
36-
reporting_frequency=1, callback=lambda _, record: results.append(record)
36+
reporting_frequency=1,
37+
observability_callback=lambda record: results.append(record),
3738
)
3839
await pipeline.run(reporter)
3940
return [r for r in results if isinstance(r, DesiredIngestion)]

tests/unit/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
def test_metric_registry_contains_subclasses_in_all_metrics():
1818
for metric in NodestreamMetricRegistry.get_all_metrics().values():
19-
assert metric in NodestreamMetricRegistry._subclasses
19+
assert metric in NodestreamMetricRegistry.get_all_metrics().values()
2020

2121

2222
def test_metric_increment_on_handler_increments_metric_on_handler(mocker):

0 commit comments

Comments
 (0)