@@ -37,15 +37,32 @@ class Channel:
37
37
prevent one step from overwhelming another step with too many records.
38
38
"""
39
39
40
- __slots__ = ("queue" , "input_dropped" , "metric " )
40
+ __slots__ = ("queue" , "input_dropped" , "_metric" , "input_name" , "output_name " )
41
41
42
- def __init__ (self , size : int , input_name : str , output_name : str ) -> None :
42
+ def __init__ (self , size : int ) -> None :
43
43
self .queue = Queue (maxsize = size )
44
44
self .input_dropped = False
45
- self .metric = Metric (
46
- f"buffered_{ input_name } _to_{ output_name } " ,
47
- f"Records buffered: { input_name } → { output_name } " ,
48
- )
45
+ self .input_name = "Void"
46
+ self .output_name = "Void"
47
+ self ._metric = None
48
+
49
+ @property
50
+ def metric (self ) -> Metric :
51
+ """Get the metric for the channel."""
52
+ if self ._metric is None :
53
+ self ._metric = Metric (
54
+ f"buffered_{ self .output_name } _to_{ self .input_name } " ,
55
+ f"Records buffered: { self .output_name } → { self .input_name } " ,
56
+ )
57
+ return self ._metric
58
+
59
+ def register_input (self , name : str ) -> None :
60
+ """Register the name of the step that will consume from this channel."""
61
+ self .input_name = name
62
+
63
+ def register_output (self , name : str ) -> None :
64
+ """Register the name of the step that will produce to this channel."""
65
+ self .output_name = name
49
66
50
67
async def get (self ):
51
68
"""Get an object from the channel.
@@ -58,6 +75,7 @@ async def get(self):
58
75
object: The object that was retrieved from the channel.
59
76
"""
60
77
object = await self .queue .get ()
78
+ Metrics .get ().decrement (self .metric )
61
79
return object
62
80
63
81
async def put (self , obj ) -> bool :
@@ -92,6 +110,10 @@ class StepOutput:
92
110
def __init__ (self , channel : Channel ) -> None :
93
111
self .channel = channel
94
112
113
+ def register (self , name : str ) -> None :
114
+ """Register the name of the step that will produce to this channel."""
115
+ self .channel .register_output (name )
116
+
95
117
async def done (self ):
96
118
"""Mark the output channel as done.
97
119
@@ -138,6 +160,10 @@ class StepInput:
138
160
def __init__ (self , channel : Channel ) -> None :
139
161
self .channel = channel
140
162
163
+ def register (self , name : str ) -> None :
164
+ """Register the name of the step that will consume from this channel."""
165
+ self .channel .register_input (name )
166
+
141
167
async def get (self ) -> Optional [object ]:
142
168
"""Get an object from the input channel.
143
169
@@ -161,15 +187,13 @@ def done(self):
161
187
self .channel .input_dropped = True
162
188
163
189
164
- def channel (
165
- size : int , input_name : str , output_name : str
166
- ) -> Tuple [StepInput , StepOutput ]:
190
+ def channel (size : int ) -> Tuple [StepInput , StepOutput ]:
167
191
"""Create a new input and output channel.
168
192
169
193
Args:
170
194
size: The size of the channel.
171
195
input_name: The name of the input step.
172
196
output_name: The name of the output step.
173
197
"""
174
- channel = Channel (size , input_name , output_name )
198
+ channel = Channel (size )
175
199
return StepInput (channel ), StepOutput (channel )
0 commit comments