Commit 1e93419

Fixed the channel implementation to be non-lazy
1 parent 4a181cd commit 1e93419

8 files changed: +57 -119 lines changed

nodestream/databases/query_executor_with_statistics.py

Lines changed: 4 additions & 6 deletions
@@ -1,12 +1,12 @@
 from typing import Iterable

 from ..metrics import (
-    Metric,
-    Metrics,
+    INGEST_HOOKS_EXECUTED,
     NODES_UPSERTED,
     RELATIONSHIPS_UPSERTED,
     TIME_TO_LIVE_OPERATIONS,
-    INGEST_HOOKS_EXECUTED,
+    Metric,
+    Metrics,
 )
 from ..model import IngestionHook, Node, RelationshipWithNodes, TimeToLiveConfiguration
 from .query_executor import (
@@ -82,9 +82,7 @@ async def upsert_relationships_in_bulk_of_same_operation(

         # Increment metrics in bulk
         metrics = Metrics.get()
-        metrics.increment(
-            RELATIONSHIPS_UPSERTED, total_relationships
-        )
+        metrics.increment(RELATIONSHIPS_UPSERTED, total_relationships)
         for rel_type, count in relationship_type_counts.items():
             metric = self._get_or_create_relationship_metric(rel_type)
             metrics.increment(metric, count)
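
The bulk pattern above counts relationships per type first and then issues one increment per distinct type. A minimal standalone sketch of that idea, with plain strings standing in for nodestream's relationship objects and metric names (all names below are hypothetical):

from collections import Counter

# Hypothetical stand-ins: the real code derives these types from the
# RelationshipWithNodes objects being upserted.
rel_types = ["KNOWS", "KNOWS", "WORKS_AT"]

counts = Counter(rel_types)                 # Counter({'KNOWS': 2, 'WORKS_AT': 1})
total_relationships = sum(counts.values())  # 3

# One increment for the grand total, then one per distinct type:
# O(distinct types) metric calls rather than O(records).
print("RELATIONSHIPS_UPSERTED +=", total_relationships)
for rel_type, count in counts.items():
    print(f"{rel_type} +=", count)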

nodestream/metrics.py

Lines changed: 0 additions & 1 deletion
@@ -149,7 +149,6 @@ def get_gauge(self, metric: Metric) -> Gauge:
         return self.instruments_by_metric[metric]

     def increment(self, metric: Metric, value: Number):
-        print(f"Incrementing {metric.name} by {value}")
         self.get_gauge(metric).inc(value)

     def decrement(self, metric: Metric, value: Number):

nodestream/pipeline/channel.py

Lines changed: 17 additions & 29 deletions
@@ -37,32 +37,22 @@ class Channel:
     prevent one step from overwhelming another step with too many records.
     """

-    __slots__ = ("queue", "input_dropped", "_metric", "input_name", "output_name")
+    __slots__ = ("queue", "input_dropped", "metric")

-    def __init__(self, size: int) -> None:
+    def __init__(self, size: int, metric: Metric) -> None:
         self.queue = Queue(maxsize=size)
         self.input_dropped = False
-        self.input_name = "Void"
-        self.output_name = "Void"
-        self._metric = None
-
-    @property
-    def metric(self) -> Metric:
-        """Get the metric for the channel."""
-        if self._metric is None:
-            self._metric = Metric(
-                f"buffered_{self.output_name}_to_{self.input_name}",
-                f"Records buffered: {self.output_name}{self.input_name}",
-            )
-        return self._metric
-
-    def register_input(self, name: str) -> None:
-        """Register the name of the step that will consume from this channel."""
-        self.input_name = name
-
-    def register_output(self, name: str) -> None:
-        """Register the name of the step that will produce to this channel."""
-        self.output_name = name
+        self.metric = metric
+
+    @classmethod
+    def create_with_naming(
+        cls, size: int, input_name: str = "Void", output_name: str = "Void"
+    ) -> "Channel":
+        metric = Metric(
+            f"buffered_{input_name}_to_{output_name}",
+            f"Records buffered: {input_name}{output_name}",
+        )
+        return cls(size, metric)

     async def get(self):
         """Get an object from the channel.
@@ -160,10 +150,6 @@ class StepInput:
     def __init__(self, channel: Channel) -> None:
         self.channel = channel

-    def register(self, name: str) -> None:
-        """Register the name of the step that will consume from this channel."""
-        self.channel.register_input(name)
-
     async def get(self) -> Optional[object]:
         """Get an object from the input channel.

@@ -187,13 +173,15 @@ def done(self):
         self.channel.input_dropped = True


-def channel(size: int) -> Tuple[StepInput, StepOutput]:
+def channel(
+    size: int, input_name: str | None = None, output_name: str | None = None
+) -> Tuple[StepInput, StepOutput]:
     """Create a new input and output channel.

     Args:
         size: The size of the channel.
         input_name: The name of the input step.
         output_name: The name of the output step.
     """
-    channel = Channel(size)
+    channel = Channel.create_with_naming(size, input_name, output_name)
     return StepInput(channel), StepOutput(channel)
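
The net effect of this diff is that a channel's metric is created eagerly at construction time instead of lazily on first property access, which previously depended on register_input/register_output having been called beforehand. A minimal sketch of the resulting API, assuming only the signatures visible in this diff:

from nodestream.metrics import Metric
from nodestream.pipeline.channel import Channel, channel

# Construct a channel with an explicit, eagerly created metric.
ch = Channel(size=100, metric=Metric("buffered_a_to_b", "Records buffered: ab"))

# Or let the factory derive the metric from step names up front.
named = Channel.create_with_naming(100, input_name="Writer", output_name="Reader")

# The module-level helper threads the names through to the factory and
# returns the two ends of a single channel.
step_input, step_output = channel(100, "Writer", "Reader")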

nodestream/pipeline/pipeline.py

Lines changed: 17 additions & 7 deletions
@@ -3,9 +3,9 @@
 from typing import Iterable, List, Tuple

 from ..metrics import (
-    Metrics,
     RECORDS,
     STEPS_RUNNING,
+    Metrics,
 )
 from ..schema import ExpandsSchema, ExpandsSchemaFromChildren
 from .channel import StepInput, StepOutput, channel
@@ -62,7 +62,6 @@ async def emit_record(self, record):
     async def drive_step(self):
         try:
             while (next_record := await self.input.get()) is not None:
-                print(f"Driving step with record {next_record}")
                 results = self.step.process_record(next_record, self.context)
                 async for record in results:
                     if not await self.emit_record(record):
@@ -179,8 +178,12 @@ async def run(self, reporter: PipelineProgressReporter):
         # the steps in the pipeline. The channels have a fixed size to control
         # the flow of records between the steps.
         executors: List[StepExecutor] = []
-        current_input, current_output = channel(self.step_outbox_size)
-        print(current_input, current_output)
+        current_input_name = None
+        current_output_name = self.steps[-1].__class__.__name__ + f"_{len(self.steps)}"
+
+        current_input, current_output = channel(
+            self.step_outbox_size, current_output_name, current_input_name
+        )
         pipeline_output = PipelineOutput(current_input, reporter)

         # Create the executors for the steps in the pipeline. The executors
@@ -191,9 +194,16 @@ async def run(self, reporter: PipelineProgressReporter):
             index = len(self.steps) - reversed_index - 1
             storage = self.object_store.namespaced(str(index))
             context = StepContext(step.__class__.__name__, index, reporter, storage)
-            current_input, next_output = channel(self.step_outbox_size)
-            current_input.register(step.__class__.__name__)
-            current_output.register(step.__class__.__name__)
+            current_output_name = (
+                self.steps[reversed_index - 1].__class__.__name__
+                + f"_{reversed_index - 1}"
+                if reversed_index - 1 >= 0
+                else None
+            )
+            current_input_name = step.__class__.__name__ + f"_{reversed_index}"
+            current_input, next_output = channel(
+                self.step_outbox_size, current_output_name, current_input_name
+            )
             exec = StepExecutor(step, current_input, current_output, context)
             current_output = next_output
             executors.append(exec)
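
To make the naming arithmetic concrete, here is a small trace of the logic above with plain strings standing in for step.__class__.__name__ (the Extract/Transform/Load step names are hypothetical, not nodestream APIs):

# Illustrative mirror of the naming logic in run(); not nodestream code.
steps = ["Extract", "Transform", "Load"]

# The channel feeding the pipeline output is named after the last step.
current_output_name = steps[-1] + f"_{len(steps)}"  # "Load_3"

for reversed_index, step in enumerate(reversed(steps)):
    current_output_name = (
        steps[reversed_index - 1] + f"_{reversed_index - 1}"
        if reversed_index - 1 >= 0
        else None
    )
    current_input_name = step + f"_{reversed_index}"
    print(current_input_name, current_output_name)

# Prints:
#   Load_0 None
#   Transform_1 Extract_0
#   Extract_2 Transform_1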

nodestream/pipeline/step.py

Lines changed: 1 addition & 1 deletion
@@ -1,9 +1,9 @@
 from typing import AsyncGenerator, Optional

 from ..metrics import (
-    Metrics,
     FATAL_ERRORS,
     NON_FATAL_ERRORS,
+    Metrics,
 )
 from .object_storage import ObjectStore
 from .progress_reporter import PipelineProgressReporter

tests/unit/databases/test_query_executor_with_statistics.py

Lines changed: 17 additions & 26 deletions
@@ -6,11 +6,11 @@
     QueryExecutorWithStatistics,
 )
 from nodestream.metrics import (
-    Metrics,
+    INGEST_HOOKS_EXECUTED,
     NODES_UPSERTED,
     RELATIONSHIPS_UPSERTED,
     TIME_TO_LIVE_OPERATIONS,
-    INGEST_HOOKS_EXECUTED,
+    Metrics,
 )
 from nodestream.model import Node, Relationship, RelationshipWithNodes

@@ -24,44 +24,42 @@ def query_executor_with_statistics(mocker):
 async def test_upsert_nodes_in_bulk_with_same_operation_increments_counter_by_size_of_list(
     query_executor_with_statistics, mocker
 ):
+    nodes = [
+        Node("node_type", "node1", "id1"),
+        Node("node_type", "node2", "id2"),
+    ]
     with Metrics.capture() as metrics:
         metrics.increment = mocker.Mock()
         await query_executor_with_statistics.upsert_nodes_in_bulk_with_same_operation(
             "operation",
-            [Node("node_type", "node1", "id1"), Node("node_type", "node2", "id2")],
+            nodes,
         )
         query_executor_with_statistics.inner.upsert_nodes_in_bulk_with_same_operation.assert_awaited_once_with(
             "operation",
-            [Node("node_type", "node1", "id1"), Node("node_type", "node2", "id2")],
+            nodes,
         )

         assert "node_type" in query_executor_with_statistics.node_metric_by_type
         assert (
             call(query_executor_with_statistics.node_metric_by_type["node_type"], 2)
             in metrics.increment.call_args_list
         )
-        assert (
-            call(NODES_UPSERTED, 2)
-            in metrics.increment.call_args_list
-        )
+        assert call(NODES_UPSERTED, 2) in metrics.increment.call_args_list


 @pytest.mark.asyncio
 async def test_upsert_relationships_in_bulk_of_same_operation_increments_counter_by_size_of_list(
     query_executor_with_statistics, mocker
 ):
+    relationships = [
+        RelationshipWithNodes("node1", "node2", Relationship("relationship_type")),
+        RelationshipWithNodes("node3", "node4", Relationship("relationship_type")),
+    ]
     with Metrics.capture() as metrics:
         metrics.increment = mocker.Mock()
         await query_executor_with_statistics.upsert_relationships_in_bulk_of_same_operation(
             "operation",
-            [
-                RelationshipWithNodes(
-                    "node1", "node2", Relationship("relationship_type")
-                ),
-                RelationshipWithNodes(
-                    "node3", "node4", Relationship("relationship_type")
-                ),
-            ],
+            relationships,
         )
         query_executor_with_statistics.inner.upsert_relationships_in_bulk_of_same_operation.assert_awaited_once_with(
             "operation",
@@ -87,10 +85,7 @@ async def test_upsert_relationships_in_bulk_of_same_operation_increments_counter
             )
             in metrics.increment.call_args_list
         )
-        assert (
-            call(RELATIONSHIPS_UPSERTED, 2)
-            in metrics.increment.call_args_list
-        )
+        assert call(RELATIONSHIPS_UPSERTED, 2) in metrics.increment.call_args_list


 @pytest.mark.asyncio
@@ -103,9 +98,7 @@ async def test_perform_ttl_op_increments_counter_by_one(
         query_executor_with_statistics.inner.perform_ttl_op.assert_awaited_once_with(
             "config"
         )
-        metrics.increment.assert_called_once_with(
-            TIME_TO_LIVE_OPERATIONS
-        )
+        metrics.increment.assert_called_once_with(TIME_TO_LIVE_OPERATIONS)


 @pytest.mark.asyncio
@@ -118,9 +111,7 @@ async def test_execute_hook_increments_counter_by_one(
         query_executor_with_statistics.inner.execute_hook.assert_awaited_once_with(
             "hook"
         )
-        metrics.increment.assert_called_once_with(
-            INGEST_HOOKS_EXECUTED
-        )
+        metrics.increment.assert_called_once_with(INGEST_HOOKS_EXECUTED)


 @pytest.mark.asyncio

tests/unit/pipeline/test_pipeline.py

Lines changed: 0 additions & 44 deletions
@@ -1,10 +1,6 @@
 import pytest

-from nodestream.pipeline import Extractor
-from nodestream.pipeline.channel import Channel
-from nodestream.pipeline.object_storage import NullObjectStore
 from nodestream.pipeline.pipeline import (
-    Pipeline,
     PipelineOutput,
     PipelineProgressReporter,
     Step,
@@ -140,43 +136,3 @@ def on_start_callback():

     output.call_handling_errors(output.reporter.on_start_callback)
     output.reporter.logger.exception.assert_called_once()
-
-
-class TestExtractor(Extractor):
-    async def extract_records(self):
-        return None
-
-
-class TestStep2(Step):
-    async def process_record(self, record, context: StepContext):
-        yield record
-
-
-@pytest.mark.asyncio
-async def test_pipeline_channels_obtain_input_and_output_names_from_steps(mocker):
-    channel = mocker.Mock(Channel)
-    mock_input = mocker.Mock(StepInput, channel=channel)
-    mock_output = mocker.Mock(StepOutput, channel=channel)
-    mock_input.get.return_value = None
-    mock_output.put.return_value = True
-    with mocker.patch(
-        "nodestream.pipeline.pipeline.channel", return_value=(mock_input, mock_output)
-    ):
-        step1 = TestExtractor()
-        step2 = TestStep2()
-        pipeline = Pipeline(
-            steps=[step1, step2], step_outbox_size=1, object_store=NullObjectStore()
-        )
-        await pipeline.run(
-            reporter=PipelineProgressReporter(
-                on_start_callback=lambda: None, logger=mocker.Mock()
-            )
-        )
-        assert mock_input.register.call_args_list == [
-            mocker.call("TestStep2"),
-            mocker.call("TestExtractor"),
-        ]
-        assert mock_output.register.call_args_list == [
-            mocker.call("TestStep2"),
-            mocker.call("TestExtractor"),
-        ]

tests/unit/test_metrics.py

Lines changed: 1 addition & 5 deletions
@@ -3,16 +3,12 @@
 import pytest

 from nodestream.metrics import (
+    RECORDS,
     AggregateHandler,
     ConsoleMetricHandler,
     JsonLogMetricHandler,
     Metric,
     Metrics,
-    RECORDS,
-    NODES_UPSERTED,
-    RELATIONSHIPS_UPSERTED,
-    TIME_TO_LIVE_OPERATIONS,
-    INGEST_HOOKS_EXECUTED,
     NullMetricHandler,
     PrometheusMetricHandler,
 )
