Skip to content

Commit 61ec9c6

Browse files
Fetch subsequent paginated data on list_public_trades
Signed-off-by: camille-bouvy-frequenz <[email protected]>
1 parent 779b3a5 commit 61ec9c6

File tree

2 files changed

+57
-37
lines changed

2 files changed

+57
-37
lines changed

integration_tests/test_api.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,26 @@ async def test_list_gridpool_trades(set_up: dict[str, Any]) -> None:
347347

348348

349349
@pytest.mark.asyncio
350-
async def test_list_public_trades(set_up: dict[str, Any]) -> None:
350+
async def test_list_public_trades_pages(set_up: dict[str, Any]) -> None:
351351
"""Test listing public trades."""
352-
public_trades = await set_up["client"].list_public_trades(
353-
delivery_period=set_up["delivery_period"],
354-
max_nr_trades=10,
355-
)
356-
assert len(public_trades) >= 0
352+
public_trades = []
353+
page_limit = 3
354+
page_count = 0
357355

356+
delivery_period = DeliveryPeriod(
357+
start=datetime.fromisoformat("2024-06-10T10:00:00+00:00"),
358+
duration=timedelta(minutes=15)
359+
)
360+
async for page in set_up["client"].list_public_trades_pages(
361+
delivery_period=delivery_period,
362+
page_size=5,
363+
):
364+
public_trades.extend(page)
365+
page_count += 1
366+
if page_count >= page_limit:
367+
break
368+
369+
assert len(public_trades) == 15, "Public trades page count mismatch"
358370

359371
@pytest.mark.asyncio
360372
async def test_stream_gridpool_orders(set_up: dict[str, Any]) -> None:

src/frequenz/client/electricity_trading/_client.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import logging
99
from datetime import datetime, timezone
1010
from decimal import Decimal, InvalidOperation
11-
from typing import Awaitable, cast
11+
from typing import AsyncIterator, Awaitable, cast
1212

1313
import grpc
14+
from frequenz.api.common.v1.pagination.pagination_params_pb2 import PaginationParams
1415

1516
# pylint: disable=no-member
1617
from frequenz.api.electricity_trading.v1 import (
@@ -857,29 +858,27 @@ async def list_gridpool_trades(
857858
_logger.exception("Error occurred while listing gridpool trades: %s", e)
858859
raise
859860

860-
async def list_public_trades(
861+
async def list_public_trades_pages(
861862
# pylint: disable=too-many-arguments, too-many-positional-arguments
862863
self,
863864
states: list[TradeState] | None = None,
864865
delivery_period: DeliveryPeriod | None = None,
865866
buy_delivery_area: DeliveryArea | None = None,
866867
sell_delivery_area: DeliveryArea | None = None,
867-
max_nr_trades: int | None = None,
868-
page_token: str | None = None,
869-
) -> list[PublicTrade]:
868+
page_size: int | None = None,
869+
) -> AsyncIterator[list[PublicTrade]]:
870870
"""
871-
List all executed public orders with optional filters.
871+
List all executed public orders with optional filters and pagination.
872872
873873
Args:
874874
states: List of order states to filter by.
875875
delivery_period: The delivery period to filter by.
876876
buy_delivery_area: The buy delivery area to filter by.
877877
sell_delivery_area: The sell delivery area to filter by.
878-
max_nr_trades: The maximum number of trades to return.
879-
page_token: The page token to use for pagination.
878+
page_size: The number of public trades to return per page.
880879
881-
Returns:
882-
The list of public trades.
880+
Yields:
881+
The list of public trades for each page.
883882
884883
Raises:
885884
grpc.RpcError: If an error occurs while listing public trades.
@@ -891,27 +890,36 @@ async def list_public_trades(
891890
sell_delivery_area=sell_delivery_area,
892891
)
893892

894-
pagination_params = Params(
895-
page_size=max_nr_trades,
896-
page_token=page_token,
893+
request = electricity_trading_pb2.ListPublicTradesRequest(
894+
filter=public_trade_filter.to_pb(),
895+
pagination_params=(
896+
Params(page_size=page_size).to_proto() if page_size else None
897+
),
897898
)
898899

899-
try:
900-
response = await cast(
901-
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
902-
self.stub.ListPublicTrades(
903-
electricity_trading_pb2.ListPublicTradesRequest(
904-
filter=public_trade_filter.to_pb(),
905-
pagination_params=pagination_params.to_proto(),
906-
),
907-
metadata=self._metadata,
908-
),
909-
)
900+
while True:
901+
try:
902+
response = await cast(
903+
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
904+
self.stub.ListPublicTrades(request, metadata=self._metadata),
905+
)
910906

911-
return [
912-
PublicTrade.from_pb(public_trade)
913-
for public_trade in response.public_trades
914-
]
915-
except grpc.RpcError as e:
916-
_logger.exception("Error occurred while listing public trades: %s", e)
917-
raise
907+
public_trades = [
908+
PublicTrade.from_pb(public_trade)
909+
for public_trade in response.public_trades
910+
]
911+
912+
yield public_trades
913+
914+
if len(response.pagination_info.next_page_token):
915+
request.pagination_params.CopyFrom(
916+
PaginationParams(
917+
page_token=response.pagination_info.next_page_token
918+
)
919+
)
920+
else:
921+
break
922+
923+
except grpc.RpcError as e:
924+
_logger.exception("Error occurred while listing public trades: %s", e)
925+
raise

0 commit comments

Comments
 (0)