Commit ef0f3d7 (parent: 5b3092f)

Ensuring error handling works. Delegating ticking of metrics to the ProgressIndicator.
6 files changed: +78, -51 lines

6 files changed

+78
-51
lines changed

nodestream/cli/operations/run_pipeline.py

Lines changed: 30 additions & 3 deletions
@@ -9,6 +9,7 @@
 from ...utils import StringSuggester
 from ..commands.nodestream_command import NodestreamCommand
 from .operation import Operation
+from logging import getLogger

 STATS_TABLE_COLS = ["Statistic", "Value"]

@@ -113,7 +114,7 @@ def get_progress_indicator(
         self, command: NodestreamCommand, pipeline_name: str
     ) -> "ProgressIndicator":
         if command.has_json_logging_set:
-            return ProgressIndicator(command, pipeline_name)
+            return JsonProgressIndicator(command, pipeline_name)

         return SpinnerProgressIndicator(command, pipeline_name)

@@ -122,6 +123,7 @@ def create_progress_reporter(
     ) -> PipelineProgressReporter:
         indicator = self.get_progress_indicator(command, pipeline_name)
         return PipelineProgressReporter(
+            logger=indicator.logger,
             reporting_frequency=int(command.option("reporting-frequency")),
             callback=indicator.progress_callback,
             on_start_callback=indicator.on_start,
@@ -134,6 +136,7 @@ class ProgressIndicator:
     def __init__(self, command: NodestreamCommand, pipeline_name: str) -> None:
         self.command = command
         self.pipeline_name = pipeline_name
+        self.logger = getLogger()

     def on_start(self):
         pass
@@ -160,14 +163,15 @@ def on_start(self):
         self.progress = self.command.progress_indicator()
         self.progress.start(f"Running pipeline: '{self.pipeline_name}'")

-    def progress_callback(self, index, _):
+    def progress_callback(self, index, metrics: Metrics):
         self.progress.set_message(
             f"Currently processing record at index: <info>{index}</info>"
         )
+        metrics.tick()

     def on_finish(self, metrics: Metrics):
         self.progress.finish(f"Finished running pipeline: '{self.pipeline_name}'")
-
+        metrics.tick()
         if self.exception:
             raise self.exception

@@ -176,3 +180,26 @@ def on_fatal_error(self, exception: Exception):
             "<error>Encountered a fatal error while running pipeline</error>"
         )
         self.exception = exception
+
+
+class JsonProgressIndicator(ProgressIndicator):
+    def __init__(self, command: NodestreamCommand, pipeline_name: str) -> None:
+        super().__init__(command, pipeline_name)
+        self.exception = None
+
+    def on_start(self):
+        self.logger.info("Starting Pipeline")
+
+    def progress_callback(self, index, metrics: Metrics):
+        self.logger.info("Processing Record", extra={"index": index})
+        metrics.tick()
+
+    def on_finish(self, metrics: Metrics):
+        self.logger.info("Pipeline Completed")
+        metrics.tick()
+        if self.exception:
+            raise self.exception
+
+    def on_fatal_error(self, exception: Exception):
+        self.logger.error("Pipeline Failed", exc_info=exception)
+        self.exception = exception
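
A rough sketch of the ticking flow this change sets up, using simplified stand-in classes (ToyMetrics and ToyIndicator are illustrative names, not nodestream APIs): the reporter hands the Metrics object to the indicator's callback, and the indicator decides when to tick.

# Illustrative stand-ins only; the real Metrics and ProgressIndicator live in nodestream.
class ToyMetrics:
    def __init__(self) -> None:
        self.ticks = 0

    def tick(self) -> None:
        # In nodestream, tick() would ask each registered metric handler to render.
        self.ticks += 1


class ToyIndicator:
    def progress_callback(self, index: int, metrics: ToyMetrics) -> None:
        print(f"Currently processing record at index: {index}")
        metrics.tick()  # ticking now happens in the indicator, not the reporter


metrics = ToyMetrics()
indicator = ToyIndicator()
for index in (0, 10000, 20000):  # the reporter only calls back at its reporting frequency
    indicator.progress_callback(index, metrics)
assert metrics.ticks == 3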

nodestream/metrics.py

Lines changed: 3 additions & 2 deletions
@@ -113,6 +113,7 @@ class NodestreamMetricRegistry(MetricRegistry):
     # Core metrics
     RECORDS = Metric("records", "Number of records processed")
     NON_FATAL_ERRORS = Metric("non_fatal_errors", "Number of non-fatal errors")
+    FATAL_ERRORS = Metric("fatal_errors", "Number of fatal errors")
     NODES_UPSERTED = Metric("nodes_upserted", "Number of nodes upserted to the graph")
     RELATIONSHIPS_UPSERTED = Metric(
         "relationships_upserted", "Number of relationships upserted to the graph"
@@ -241,7 +242,7 @@ def tick(self):
         self.render()

     def stop(self):
-        self.render()
+        pass


 class JsonLogMetricHandler(MetricHandler):
@@ -265,7 +266,7 @@ def render(self):
         )

     def stop(self):
-        self.render()
+        pass

     def tick(self):
         self.render()

nodestream/pipeline/channel.py

Lines changed: 34 additions & 10 deletions
@@ -37,15 +37,32 @@ class Channel:
     prevent one step from overwhelming another step with too many records.
     """

-    __slots__ = ("queue", "input_dropped", "metric")
+    __slots__ = ("queue", "input_dropped", "_metric", "input_name", "output_name")

-    def __init__(self, size: int, input_name: str, output_name: str) -> None:
+    def __init__(self, size: int) -> None:
         self.queue = Queue(maxsize=size)
         self.input_dropped = False
-        self.metric = Metric(
-            f"buffered_{input_name}_to_{output_name}",
-            f"Records buffered: {input_name} → {output_name}",
-        )
+        self.input_name = "Void"
+        self.output_name = "Void"
+        self._metric = None
+
+    @property
+    def metric(self) -> Metric:
+        """Get the metric for the channel."""
+        if self._metric is None:
+            self._metric = Metric(
+                f"buffered_{self.output_name}_to_{self.input_name}",
+                f"Records buffered: {self.output_name} → {self.input_name}",
+            )
+        return self._metric
+
+    def register_input(self, name: str) -> None:
+        """Register the name of the step that will consume from this channel."""
+        self.input_name = name
+
+    def register_output(self, name: str) -> None:
+        """Register the name of the step that will produce to this channel."""
+        self.output_name = name

     async def get(self):
         """Get an object from the channel.
@@ -58,6 +75,7 @@ async def get(self):
             object: The object that was retrieved from the channel.
         """
         object = await self.queue.get()
+        Metrics.get().decrement(self.metric)
         return object

     async def put(self, obj) -> bool:
@@ -92,6 +110,10 @@ class StepOutput:
     def __init__(self, channel: Channel) -> None:
         self.channel = channel

+    def register(self, name: str) -> None:
+        """Register the name of the step that will produce to this channel."""
+        self.channel.register_output(name)
+
     async def done(self):
         """Mark the output channel as done.

@@ -138,6 +160,10 @@ class StepInput:
     def __init__(self, channel: Channel) -> None:
         self.channel = channel

+    def register(self, name: str) -> None:
+        """Register the name of the step that will consume from this channel."""
+        self.channel.register_input(name)
+
     async def get(self) -> Optional[object]:
         """Get an object from the input channel.

@@ -161,15 +187,13 @@ def done(self):
         self.channel.input_dropped = True


-def channel(
-    size: int, input_name: str, output_name: str
-) -> Tuple[StepInput, StepOutput]:
+def channel(size: int) -> Tuple[StepInput, StepOutput]:
     """Create a new input and output channel.

     Args:
         size: The size of the channel.
         input_name: The name of the input step.
         output_name: The name of the output step.
     """
-    channel = Channel(size, input_name, output_name)
+    channel = Channel(size)
     return StepInput(channel), StepOutput(channel)
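
A standalone sketch of the new registration pattern (ToyMetric and ToyChannel are simplified stand-ins, not the real classes): step names are attached after construction, and the buffering metric is only built on first access, so it picks up whatever names were registered.

from typing import Optional


class ToyMetric:
    # Stand-in for nodestream's Metric: just a name and a description.
    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description


class ToyChannel:
    def __init__(self, size: int) -> None:
        self.size = size
        self.input_name = "Void"
        self.output_name = "Void"
        self._metric: Optional[ToyMetric] = None

    def register_input(self, name: str) -> None:
        self.input_name = name

    def register_output(self, name: str) -> None:
        self.output_name = name

    @property
    def metric(self) -> ToyMetric:
        # Built lazily so registration can happen after __init__.
        if self._metric is None:
            self._metric = ToyMetric(
                f"buffered_{self.output_name}_to_{self.input_name}",
                f"Records buffered: {self.output_name} -> {self.input_name}",
            )
        return self._metric


ch = ToyChannel(size=100)
ch.register_output("Extractor")   # the step producing into the channel
ch.register_input("Transformer")  # the step consuming from the channel
print(ch.metric.name)  # buffered_Extractor_to_Transformer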

nodestream/pipeline/pipeline.py

Lines changed: 6 additions & 27 deletions
@@ -113,9 +113,9 @@ async def run(self):
         self.call_handling_errors(self.reporter.on_start_callback)

         index = 0
-        while (obj := await self.input.get()) is not None:
+        while (await self.input.get()) is not None:
             metrics.increment(NodestreamMetricRegistry.RECORDS)
-            self.call_handling_errors(self.reporter.report, index, obj)
+            self.call_handling_errors(self.reporter.report, index, metrics)
             index += 1

         self.call_handling_errors(self.reporter.on_finish_callback, metrics)
@@ -161,8 +161,6 @@ async def run(self, reporter: PipelineProgressReporter):
         `PipelineProgressReporter`.

         Args:
-            channel_size: The size of the channels used to pass records between
-                steps in the pipeline.
             reporter: The `PipelineProgressReporter` used to report on the
                 progress of the pipeline.
         """
@@ -173,15 +171,9 @@ async def run(self, reporter: PipelineProgressReporter):
         # the steps in the pipeline. The channels have a fixed size to control
         # the flow of records between the steps.
         executors: List[StepExecutor] = []
-        current_input, current_output = channel(
-            self.step_outbox_size,
-            input_name=self.steps[-1].__class__.__name__ + f"_{len(self.steps) - 1}",
-            output_name="Void",
-        )
+        current_input, current_output = channel(self.step_outbox_size)
         pipeline_output = PipelineOutput(current_input, reporter)

-        print([step.__class__.__name__ for step in self.steps])
-
         # Create the executors for the steps in the pipeline. The executors
         # will be used to run the steps concurrently. The steps are created in
         # reverse order so that the output of each step is connected to the
@@ -190,20 +182,9 @@ async def run(self, reporter: PipelineProgressReporter):
             index = len(self.steps) - reversed_index - 1
             storage = self.object_store.namespaced(str(index))
             context = StepContext(step.__class__.__name__, index, reporter, storage)
-
-            previous_step = (
-                self.steps[reversed_index - 1] if reversed_index > 0 else None
-            )
-
-            current_input, next_output = channel(
-                self.step_outbox_size,
-                input_name=(
-                    previous_step.__class__.__name__ + f"_{reversed_index - 1}"
-                    if previous_step
-                    else "Void"
-                ),
-                output_name=step.__class__.__name__ + f"_{reversed_index}",
-            )
+            current_input, next_output = channel(self.step_outbox_size)
+            current_input.register(step.__class__.__name__)
+            current_output.register(step.__class__.__name__)
             exec = StepExecutor(step, current_input, current_output, context)
             current_output = next_output
             executors.append(exec)
@@ -218,6 +199,4 @@ async def run(self, reporter: PipelineProgressReporter):
         # concurrently. This will block until all steps are finished.
         running_steps = (create_task(executor.run()) for executor in executors)

-        self.logger.info("Starting Pipeline")
         await gather(*running_steps, create_task(pipeline_output.run()))
-        self.logger.info("Pipeline Completed")

nodestream/pipeline/progress_reporter.py

Lines changed: 4 additions & 9 deletions
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass, field
 from logging import Logger, getLogger
-from typing import Any, Callable
+from typing import Callable

 from psutil import Process

@@ -33,7 +33,7 @@ class PipelineProgressReporter:

     reporting_frequency: int = 10000
     logger: Logger = field(default_factory=getLogger)
-    callback: Callable[[int, Any], None] = field(default=no_op)
+    callback: Callable[[int, Metrics], None] = field(default=no_op)
     on_start_callback: Callable[[], None] = field(default=no_op)
     on_finish_callback: Callable[[Metrics], None] = field(default=no_op)
     on_fatal_error_callback: Callable[[Exception], None] = field(default=no_op)
@@ -66,11 +66,6 @@ def for_testing(cls, results_list: list) -> "PipelineProgressReporter":
             on_fatal_error_callback=raise_exception,
         )

-    def report(self, index, record):
+    def report(self, index, metrics: Metrics):
         if index % 10000 == 0:
-            # self.logger.info(
-            #     "Records Processed",
-            #     extra={"index": index, "max_memory": get_max_mem_mb()},
-            # )
-            # self.callback(index, record)
-            Metrics.get().tick()
+            self.callback(index, metrics)
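
A minimal sketch of what report() now does (the no_op default, the 10000 threshold, and the callback forwarding mirror the diff above, but ToyReporter and ToyMetrics are hypothetical stand-ins): instead of ticking a global Metrics itself, the reporter forwards the Metrics object to whichever callback was configured.

from dataclasses import dataclass, field
from typing import Callable


class ToyMetrics:
    # Stand-in for nodestream's Metrics.
    def __init__(self) -> None:
        self.ticks = 0

    def tick(self) -> None:
        self.ticks += 1


def no_op(*args, **kwargs) -> None:
    pass


@dataclass
class ToyReporter:
    callback: Callable[[int, ToyMetrics], None] = field(default=no_op)

    def report(self, index: int, metrics: ToyMetrics) -> None:
        # The reporter no longer ticks; it only forwards to the callback.
        if index % 10000 == 0:
            self.callback(index, metrics)


metrics = ToyMetrics()
reporter = ToyReporter(callback=lambda index, m: m.tick())
for index in range(20001):
    reporter.report(index, metrics)
print(metrics.ticks)  # 3 (indices 0, 10000, 20000)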

nodestream/pipeline/step.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def report_error(
         )
         if fatal:
             self.reporter.on_fatal_error(exception)
+            Metrics.get().increment(NodestreamMetricRegistry.FATAL_ERRORS)
         else:
             Metrics.get().increment(NodestreamMetricRegistry.NON_FATAL_ERRORS)
