Skip to content

Commit 9cf5ae1

Browse files
Unify Public Trades streaming and listing
- Removed `list_public_trades` - Replaced `public_trades_stream` with `receive_public_trades` - Updated the method to support streaming with optional time range (`start_time`, `end_time`) - Update the unit tests with the new function name Signed-off-by: camille-bouvy-frequenz <[email protected]>
1 parent e92212c commit 9cf5ae1

File tree

2 files changed

+51
-105
lines changed

2 files changed

+51
-105
lines changed

src/frequenz/client/electricity_trading/_client.py

Lines changed: 48 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from frequenz.client.base.streaming import GrpcStreamBroadcaster
3030
from frequenz.client.common.pagination import Params
3131
from google.protobuf import field_mask_pb2, struct_pb2
32+
from google.protobuf.timestamp_pb2 import Timestamp
3233

3334
from ._types import (
3435
DeliveryArea,
@@ -345,62 +346,6 @@ def gridpool_trades_stream(
345346
raise
346347
return self._gridpool_trades_streams[stream_key]
347348

348-
def public_trades_stream(
349-
# pylint: disable=too-many-arguments, too-many-positional-arguments
350-
self,
351-
states: list[TradeState] | None = None,
352-
delivery_period: DeliveryPeriod | None = None,
353-
buy_delivery_area: DeliveryArea | None = None,
354-
sell_delivery_area: DeliveryArea | None = None,
355-
) -> GrpcStreamBroadcaster[
356-
electricity_trading_pb2.ReceivePublicTradesStreamResponse, PublicTrade
357-
]:
358-
"""
359-
Stream public trades.
360-
361-
Args:
362-
states: List of order states to filter for.
363-
delivery_period: Delivery period to filter for.
364-
buy_delivery_area: Buy delivery area to filter for.
365-
sell_delivery_area: Sell delivery area to filter for.
366-
367-
Returns:
368-
Async generator of orders.
369-
370-
Raises:
371-
grpc.RpcError: If an error occurs while streaming public trades.
372-
"""
373-
self.validate_params(delivery_period=delivery_period)
374-
375-
public_trade_filter = PublicTradeFilter(
376-
states=states,
377-
delivery_period=delivery_period,
378-
buy_delivery_area=buy_delivery_area,
379-
sell_delivery_area=sell_delivery_area,
380-
)
381-
382-
if (
383-
public_trade_filter not in self._public_trades_streams
384-
or not self._public_trades_streams[public_trade_filter].is_running
385-
):
386-
try:
387-
self._public_trades_streams[public_trade_filter] = (
388-
GrpcStreamBroadcaster(
389-
f"electricity-trading-{public_trade_filter}",
390-
lambda: self.stub.ReceivePublicTradesStream(
391-
electricity_trading_pb2.ReceivePublicTradesStreamRequest(
392-
filter=public_trade_filter.to_pb(),
393-
),
394-
metadata=self._metadata,
395-
),
396-
lambda response: PublicTrade.from_pb(response.public_trade),
397-
)
398-
)
399-
except grpc.RpcError as e:
400-
_logger.exception("Error occurred while streaming public trades: %s", e)
401-
raise
402-
return self._public_trades_streams[public_trade_filter]
403-
404349
def validate_params(
405350
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-branches
406351
self,
@@ -947,73 +892,74 @@ async def list_gridpool_trades(
947892
_logger.exception("Error occurred while listing gridpool trades: %s", e)
948893
raise
949894

950-
async def list_public_trades(
895+
def receive_public_trades(
951896
# pylint: disable=too-many-arguments, too-many-positional-arguments
952897
self,
953898
states: list[TradeState] | None = None,
954899
delivery_period: DeliveryPeriod | None = None,
955900
buy_delivery_area: DeliveryArea | None = None,
956901
sell_delivery_area: DeliveryArea | None = None,
957-
page_size: int | None = None,
958-
timeout: timedelta | None = None,
959-
) -> AsyncIterator[PublicTrade]:
902+
start_time: datetime | None = None,
903+
end_time: datetime | None = None,
904+
) -> GrpcStreamBroadcaster[
905+
electricity_trading_pb2.ReceivePublicTradesStreamResponse, PublicTrade
906+
]:
960907
"""
961-
List all executed public orders with optional filters and pagination.
908+
Stream public trades with optional filters and time range.
962909
963910
Args:
964-
states: List of order states to filter by.
965-
delivery_period: The delivery period to filter by.
966-
buy_delivery_area: The buy delivery area to filter by.
967-
sell_delivery_area: The sell delivery area to filter by.
968-
page_size: The number of public trades to return per page.
969-
timeout: Timeout duration, defaults to None.
911+
states: List of order states to filter for.
912+
delivery_period: Delivery period to filter for.
913+
buy_delivery_area: Buy delivery area to filter for.
914+
sell_delivery_area: Sell delivery area to filter for.
915+
start_time: The starting timestamp to stream trades from. If None, streams from now.
916+
end_time: The ending timestamp to stop streaming trades. If None, streams indefinitely.
970917
971-
Yields:
972-
The list of public trades for each page.
918+
Returns:
919+
Async generator of orders.
973920
974921
Raises:
975-
grpc.RpcError: If an error occurs while listing public trades.
922+
grpc.RpcError: If an error occurs while streaming public trades.
976923
"""
924+
925+
def dt_to_pb_timestamp(dt: datetime) -> Timestamp:
926+
ts = Timestamp()
927+
ts.FromDatetime(dt)
928+
return ts
929+
977930
public_trade_filter = PublicTradeFilter(
978931
states=states,
979932
delivery_period=delivery_period,
980933
buy_delivery_area=buy_delivery_area,
981934
sell_delivery_area=sell_delivery_area,
982935
)
983936

984-
request = electricity_trading_pb2.ListPublicTradesRequest(
985-
filter=public_trade_filter.to_pb(),
986-
pagination_params=(
987-
Params(page_size=page_size, page_token="").to_proto()
988-
if page_size
989-
else None
990-
),
991-
)
992-
993-
while True:
937+
if (
938+
public_trade_filter not in self._public_trades_streams
939+
or not self._public_trades_streams[public_trade_filter].is_running
940+
):
994941
try:
995-
response = await cast(
996-
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
997-
grpc_call_with_timeout(
998-
self.stub.ListPublicTrades,
999-
request,
1000-
metadata=self._metadata,
1001-
timeout=timeout,
1002-
),
1003-
)
1004-
1005-
for public_trade in response.public_trades:
1006-
yield PublicTrade.from_pb(public_trade)
1007-
1008-
if response.pagination_info.next_page_token:
1009-
request.pagination_params.CopyFrom(
1010-
PaginationParams(
1011-
page_token=response.pagination_info.next_page_token
1012-
)
942+
self._public_trades_streams[public_trade_filter] = (
943+
GrpcStreamBroadcaster(
944+
f"electricity-trading-{public_trade_filter}",
945+
lambda: self.stub.ReceivePublicTradesStream(
946+
electricity_trading_pb2.ReceivePublicTradesStreamRequest(
947+
filter=public_trade_filter.to_pb(),
948+
start_time=(
949+
dt_to_pb_timestamp(start_time)
950+
if start_time
951+
else None
952+
),
953+
end_time=(
954+
dt_to_pb_timestamp(end_time) if end_time else None
955+
),
956+
),
957+
metadata=self._metadata,
958+
),
959+
lambda response: PublicTrade.from_pb(response.public_trade),
1013960
)
1014-
else:
1015-
break
1016-
961+
)
1017962
except grpc.RpcError as e:
1018-
_logger.exception("Error occurred while listing public trades: %s", e)
963+
_logger.exception("Error occurred while streaming public trades: %s", e)
1019964
raise
965+
return self._public_trades_streams[public_trade_filter]

tests/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,14 @@ async def test_stream_gridpool_trades(
172172
assert args[0].filter.side == set_up.side.to_pb()
173173

174174

175-
async def test_stream_public_trades(
175+
async def test_receive_public_trades(
176176
set_up: SetupParams,
177177
) -> None:
178-
"""Test the method streaming public trades."""
178+
"""Test the method receiving public trades."""
179179
# Fields to filter for
180180
trade_states = [TradeState.ACTIVE]
181181

182-
set_up.client.public_trades_stream(states=trade_states)
182+
set_up.client.receive_public_trades(states=trade_states)
183183
await asyncio.sleep(0)
184184

185185
set_up.mock_stub.ReceivePublicTradesStream.assert_called_once()

0 commit comments

Comments
 (0)