Skip to content

Commit e33bd3f

Browse files
Unify Public Trades streaming and listing
- Removed `list_public_trades` - Replaced `public_trades_stream` with `receive_public_trades` - Updated the method to support streaming with optional time range (`start_time`, `end_time`) - Update the unit tests with the new function name Signed-off-by: camille-bouvy-frequenz <[email protected]>
1 parent e534476 commit e33bd3f

File tree

2 files changed

+53
-103
lines changed

2 files changed

+53
-103
lines changed

src/frequenz/client/electricity_trading/_client.py

Lines changed: 50 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from frequenz.client.base.streaming import GrpcStreamBroadcaster
3030
from frequenz.client.common.pagination import Params
3131
from google.protobuf import field_mask_pb2, struct_pb2
32+
from google.protobuf.timestamp_pb2 import Timestamp
3233

3334
from ._types import (
3435
DeliveryArea,
@@ -345,62 +346,6 @@ def gridpool_trades_stream(
345346
raise
346347
return self._gridpool_trades_streams[stream_key]
347348

348-
def public_trades_stream(
349-
# pylint: disable=too-many-arguments, too-many-positional-arguments
350-
self,
351-
states: list[TradeState] | None = None,
352-
delivery_period: DeliveryPeriod | None = None,
353-
buy_delivery_area: DeliveryArea | None = None,
354-
sell_delivery_area: DeliveryArea | None = None,
355-
) -> GrpcStreamBroadcaster[
356-
electricity_trading_pb2.ReceivePublicTradesStreamResponse, PublicTrade
357-
]:
358-
"""
359-
Stream public trades.
360-
361-
Args:
362-
states: List of order states to filter for.
363-
delivery_period: Delivery period to filter for.
364-
buy_delivery_area: Buy delivery area to filter for.
365-
sell_delivery_area: Sell delivery area to filter for.
366-
367-
Returns:
368-
Async generator of orders.
369-
370-
Raises:
371-
grpc.RpcError: If an error occurs while streaming public trades.
372-
"""
373-
self.validate_params(delivery_period=delivery_period)
374-
375-
public_trade_filter = PublicTradeFilter(
376-
states=states,
377-
delivery_period=delivery_period,
378-
buy_delivery_area=buy_delivery_area,
379-
sell_delivery_area=sell_delivery_area,
380-
)
381-
382-
if (
383-
public_trade_filter not in self._public_trades_streams
384-
or not self._public_trades_streams[public_trade_filter].is_running
385-
):
386-
try:
387-
self._public_trades_streams[public_trade_filter] = (
388-
GrpcStreamBroadcaster(
389-
f"electricity-trading-{public_trade_filter}",
390-
lambda: self.stub.ReceivePublicTradesStream(
391-
electricity_trading_pb2.ReceivePublicTradesStreamRequest(
392-
filter=public_trade_filter.to_pb(),
393-
),
394-
metadata=self._metadata,
395-
),
396-
lambda response: PublicTrade.from_pb(response.public_trade),
397-
)
398-
)
399-
except grpc.RpcError as e:
400-
_logger.exception("Error occurred while streaming public trades: %s", e)
401-
raise
402-
return self._public_trades_streams[public_trade_filter]
403-
404349
def validate_params(
405350
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-branches
406351
self,
@@ -943,71 +888,76 @@ async def list_gridpool_trades(
943888
_logger.exception("Error occurred while listing gridpool trades: %s", e)
944889
raise
945890

946-
async def list_public_trades(
891+
def receive_public_trades(
947892
# pylint: disable=too-many-arguments, too-many-positional-arguments
948893
self,
949894
states: list[TradeState] | None = None,
950895
delivery_period: DeliveryPeriod | None = None,
951896
buy_delivery_area: DeliveryArea | None = None,
952897
sell_delivery_area: DeliveryArea | None = None,
953-
page_size: int | None = None,
954-
timeout: timedelta | None = None,
955-
) -> AsyncIterator[PublicTrade]:
898+
start_time: datetime | None = None,
899+
end_time: datetime | None = None,
900+
) -> GrpcStreamBroadcaster[
901+
electricity_trading_pb2.ReceivePublicTradesStreamResponse, PublicTrade
902+
]:
956903
"""
957-
List all executed public orders with optional filters and pagination.
904+
Stream public trades with optional filters and time range.
958905
959906
Args:
960-
states: List of order states to filter by.
961-
delivery_period: The delivery period to filter by.
962-
buy_delivery_area: The buy delivery area to filter by.
963-
sell_delivery_area: The sell delivery area to filter by.
964-
page_size: The number of public trades to return per page.
965-
timeout: Timeout duration, defaults to None.
907+
states: List of order states to filter for.
908+
delivery_period: Delivery period to filter for.
909+
buy_delivery_area: Buy delivery area to filter for.
910+
sell_delivery_area: Sell delivery area to filter for.
911+
start_time: The starting timestamp to stream trades from. If None, streams from now.
912+
end_time: The ending timestamp to stop streaming trades. If None, streams indefinitely.
966913
967-
Yields:
968-
The list of public trades for each page.
914+
Returns:
915+
Async generator of orders.
969916
970917
Raises:
971-
grpc.RpcError: If an error occurs while listing public trades.
918+
grpc.RpcError: If an error occurs while streaming public trades.
972919
"""
920+
921+
def dt_to_pb_timestamp(dt: datetime) -> Timestamp:
922+
ts = Timestamp()
923+
ts.FromDatetime(dt)
924+
return ts
925+
926+
self.validate_params(delivery_period=delivery_period)
927+
973928
public_trade_filter = PublicTradeFilter(
974929
states=states,
975930
delivery_period=delivery_period,
976931
buy_delivery_area=buy_delivery_area,
977932
sell_delivery_area=sell_delivery_area,
978933
)
979934

980-
request = electricity_trading_pb2.ListPublicTradesRequest(
981-
filter=public_trade_filter.to_pb(),
982-
pagination_params=(
983-
Params(page_size=page_size).to_proto() if page_size else None
984-
),
985-
)
986-
987-
while True:
935+
if (
936+
public_trade_filter not in self._public_trades_streams
937+
or not self._public_trades_streams[public_trade_filter].is_running
938+
):
988939
try:
989-
response = await cast(
990-
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
991-
grpc_call_with_timeout(
992-
self.stub.ListPublicTrades,
993-
request,
994-
metadata=self._metadata,
995-
timeout=timeout,
996-
),
997-
)
998-
999-
for public_trade in response.public_trades:
1000-
yield PublicTrade.from_pb(public_trade)
1001-
1002-
if response.pagination_info.next_page_token:
1003-
request.pagination_params.CopyFrom(
1004-
PaginationParams(
1005-
page_token=response.pagination_info.next_page_token
1006-
)
940+
self._public_trades_streams[public_trade_filter] = (
941+
GrpcStreamBroadcaster(
942+
f"electricity-trading-{public_trade_filter}",
943+
lambda: self.stub.ReceivePublicTradesStream(
944+
electricity_trading_pb2.ReceivePublicTradesStreamRequest(
945+
filter=public_trade_filter.to_pb(),
946+
start_time=(
947+
dt_to_pb_timestamp(start_time)
948+
if start_time
949+
else None
950+
),
951+
end_time=(
952+
dt_to_pb_timestamp(end_time) if end_time else None
953+
),
954+
),
955+
metadata=self._metadata,
956+
),
957+
lambda response: PublicTrade.from_pb(response.public_trade),
1007958
)
1008-
else:
1009-
break
1010-
959+
)
1011960
except grpc.RpcError as e:
1012-
_logger.exception("Error occurred while listing public trades: %s", e)
961+
_logger.exception("Error occurred while streaming public trades: %s", e)
1013962
raise
963+
return self._public_trades_streams[public_trade_filter]

tests/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,14 @@ async def test_stream_gridpool_trades(
172172
assert args[0].filter.side == set_up.side.to_pb()
173173

174174

175-
async def test_stream_public_trades(
175+
async def test_receive_public_trades(
176176
set_up: SetupParams,
177177
) -> None:
178-
"""Test the method streaming public trades."""
178+
"""Test the method receiving public trades."""
179179
# Fields to filter for
180180
trade_states = [TradeState.ACTIVE]
181181

182-
set_up.client.public_trades_stream(states=trade_states)
182+
set_up.client.receive_public_trades(states=trade_states)
183183
await asyncio.sleep(0)
184184

185185
set_up.mock_stub.ReceivePublicTradesStream.assert_called_once()

0 commit comments

Comments
 (0)