Skip to content

Commit e14024d

Browse files
committed
Return GrpcStreamBroadcaster instances from the streaming methods
This makes it easier to close the streamers when they are no longer needed. Also rename the methods from `stream_*` to `*_stream`. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent e3ccfec commit e14024d

File tree

3 files changed

+26
-39
lines changed

3 files changed

+26
-39
lines changed

src/frequenz/client/electricity_trading/_client.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from frequenz.api.electricity_trading.v1.electricity_trading_pb2_grpc import (
2525
ElectricityTradingServiceStub,
2626
)
27-
from frequenz.channels import Receiver
2827
from frequenz.client.base.client import BaseApiClient
2928
from frequenz.client.base.exception import ClientNotConnected
3029
from frequenz.client.base.streaming import GrpcStreamBroadcaster
@@ -218,7 +217,7 @@ def stub(self) -> electricity_trading_pb2_grpc.ElectricityTradingServiceAsyncStu
218217
# type-checker, so it can only be used for type hints.
219218
return self._stub # type: ignore
220219

221-
def stream_gridpool_orders(
220+
def gridpool_orders_stream(
222221
# pylint: disable=too-many-arguments, too-many-positional-arguments
223222
self,
224223
gridpool_id: int,
@@ -227,9 +226,9 @@ def stream_gridpool_orders(
227226
delivery_area: DeliveryArea | None = None,
228227
delivery_period: DeliveryPeriod | None = None,
229228
tag: str | None = None,
230-
max_size: int = 50,
231-
warn_on_overflow: bool = False,
232-
) -> Receiver[OrderDetail]:
229+
) -> GrpcStreamBroadcaster[
230+
electricity_trading_pb2.ReceiveGridpoolOrdersStreamResponse, OrderDetail
231+
]:
233232
"""
234233
Stream gridpool orders.
235234
@@ -240,10 +239,6 @@ def stream_gridpool_orders(
240239
delivery_area: Delivery area to filter for.
241240
delivery_period: Delivery period to filter for.
242241
tag: Tag to filter for.
243-
max_size: The maximum number of messages to buffer.
244-
warn_on_overflow: Whether to log a warning when the receiver's
245-
buffer is full and a message is dropped.
246-
247242
248243
Returns:
249244
Async generator of orders.
@@ -281,11 +276,9 @@ def stream_gridpool_orders(
281276
"Error occurred while streaming gridpool orders: %s", e
282277
)
283278
raise
284-
return self._gridpool_orders_streams[stream_key].new_receiver(
285-
warn_on_overflow=warn_on_overflow, maxsize=max_size
286-
)
279+
return self._gridpool_orders_streams[stream_key]
287280

288-
def stream_gridpool_trades(
281+
def gridpool_trades_stream(
289282
# pylint: disable=too-many-arguments, too-many-positional-arguments
290283
self,
291284
gridpool_id: int,
@@ -294,9 +287,9 @@ def stream_gridpool_trades(
294287
market_side: MarketSide | None = None,
295288
delivery_period: DeliveryPeriod | None = None,
296289
delivery_area: DeliveryArea | None = None,
297-
max_size: int = 50,
298-
warn_on_overflow: bool = False,
299-
) -> Receiver[Trade]:
290+
) -> GrpcStreamBroadcaster[
291+
electricity_trading_pb2.ReceiveGridpoolTradesStreamResponse, Trade
292+
]:
300293
"""
301294
Stream gridpool trades.
302295
@@ -307,9 +300,6 @@ def stream_gridpool_trades(
307300
market_side: The market side to filter for.
308301
delivery_period: The delivery period to filter for.
309302
delivery_area: The delivery area to filter for.
310-
max_size: The maximum number of messages to buffer.
311-
warn_on_overflow: Whether to log a warning when the receiver's
312-
buffer is full and a message is dropped.
313303
314304
Returns:
315305
The gridpool trades streamer.
@@ -347,20 +337,18 @@ def stream_gridpool_trades(
347337
"Error occurred while streaming gridpool trades: %s", e
348338
)
349339
raise
350-
return self._gridpool_trades_streams[stream_key].new_receiver(
351-
warn_on_overflow=warn_on_overflow, maxsize=max_size
352-
)
340+
return self._gridpool_trades_streams[stream_key]
353341

354-
def stream_public_trades(
342+
def public_trades_stream(
355343
# pylint: disable=too-many-arguments, too-many-positional-arguments
356344
self,
357345
states: list[TradeState] | None = None,
358346
delivery_period: DeliveryPeriod | None = None,
359347
buy_delivery_area: DeliveryArea | None = None,
360348
sell_delivery_area: DeliveryArea | None = None,
361-
max_size: int = 50,
362-
warn_on_overflow: bool = False,
363-
) -> Receiver[PublicTrade]:
349+
) -> GrpcStreamBroadcaster[
350+
electricity_trading_pb2.ReceivePublicTradesStreamResponse, PublicTrade
351+
]:
364352
"""
365353
Stream public trades.
366354
@@ -369,9 +357,6 @@ def stream_public_trades(
369357
delivery_period: Delivery period to filter for.
370358
buy_delivery_area: Buy delivery area to filter for.
371359
sell_delivery_area: Sell delivery area to filter for.
372-
max_size: The maximum number of messages to buffer.
373-
warn_on_overflow: Whether to log a warning when the receiver's
374-
buffer is full and a message is dropped.
375360
376361
Returns:
377362
Async generator of orders.
@@ -405,9 +390,7 @@ def stream_public_trades(
405390
except grpc.RpcError as e:
406391
_logger.exception("Error occurred while streaming public trades: %s", e)
407392
raise
408-
return self._public_trades_streams[public_trade_filter].new_receiver(
409-
warn_on_overflow=warn_on_overflow, maxsize=max_size
410-
)
393+
return self._public_trades_streams[public_trade_filter]
411394

412395
def validate_params(
413396
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-branches

src/frequenz/client/electricity_trading/cli/etrading.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def list_public_trades(url: str, key: str, *, delivery_start: datetime) ->
7373
if delivery_start <= datetime.now(timezone.utc):
7474
return
7575

76-
stream = client.stream_public_trades(delivery_period=delivery_period)
76+
stream = client.public_trades_stream(delivery_period=delivery_period).new_receiver()
7777
async for trade in stream:
7878
print_public_trade(trade)
7979

@@ -111,7 +111,9 @@ async def list_gridpool_trades(
111111
if delivery_start and delivery_start <= datetime.now(timezone.utc):
112112
return
113113

114-
stream = client.stream_gridpool_trades(gid, delivery_period=delivery_period)
114+
stream = client.gridpool_trades_stream(
115+
gid, delivery_period=delivery_period
116+
).new_receiver()
115117
async for trade in stream:
116118
print_trade(trade)
117119

@@ -154,7 +156,9 @@ async def list_gridpool_orders(
154156
if delivery_start and delivery_start <= datetime.now(timezone.utc):
155157
return
156158

157-
stream = client.stream_gridpool_orders(gid, delivery_period=delivery_period)
159+
stream = client.gridpool_orders_stream(
160+
gid, delivery_period=delivery_period
161+
).new_receiver()
158162
async for order in stream:
159163
print_order(order)
160164

tests/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def set_up_order_detail_response(
133133

134134
async def test_stream_gridpool_orders(set_up: SetupParams) -> None:
135135
"""Test the method streaming gridpool orders."""
136-
set_up.client.stream_gridpool_orders(set_up.gridpool_id)
136+
set_up.client.gridpool_orders_stream(set_up.gridpool_id)
137137
await asyncio.sleep(0)
138138

139139
set_up.mock_stub.ReceiveGridpoolOrdersStream.assert_called_once()
@@ -146,7 +146,7 @@ async def test_stream_gridpool_orders_with_optional_inputs(set_up: SetupParams)
146146
# Fields to filter for
147147
order_states = [OrderState.ACTIVE]
148148

149-
set_up.client.stream_gridpool_orders(set_up.gridpool_id, order_states=order_states)
149+
set_up.client.gridpool_orders_stream(set_up.gridpool_id, order_states=order_states)
150150
await asyncio.sleep(0)
151151

152152
set_up.mock_stub.ReceiveGridpoolOrdersStream.assert_called_once()
@@ -161,7 +161,7 @@ async def test_stream_gridpool_trades(
161161
set_up: SetupParams,
162162
) -> None:
163163
"""Test the method streaming gridpool trades."""
164-
set_up.client.stream_gridpool_trades(
164+
set_up.client.gridpool_trades_stream(
165165
gridpool_id=set_up.gridpool_id, market_side=set_up.side
166166
)
167167
await asyncio.sleep(0)
@@ -179,7 +179,7 @@ async def test_stream_public_trades(
179179
# Fields to filter for
180180
trade_states = [TradeState.ACTIVE]
181181

182-
set_up.client.stream_public_trades(states=trade_states)
182+
set_up.client.public_trades_stream(states=trade_states)
183183
await asyncio.sleep(0)
184184

185185
set_up.mock_stub.ReceivePublicTradesStream.assert_called_once()

0 commit comments

Comments
 (0)