Skip to content

Commit 8385158

Browse files
Fetch subsequent paginated data on list_* functions (#82)
This PR addresses #80.
2 parents b57adab + 93888bc commit 8385158

File tree

4 files changed

+146
-112
lines changed

4 files changed

+146
-112
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* Add idiomatic string representations for `Power` and `Price` classes.
1717
* Add support for timeouts in the gRPC function calls
1818
* Export Client constants for external use
19+
* Fetch subsequent paginated data for `list_*` methods
1920

2021
## Bug Fixes
2122

integration_tests/test_api.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,12 @@ async def test_list_gridpool_orders(set_up: dict[str, Any]) -> None:
225225
created_orders_id = [(await create_test_order(set_up)).order_id for _ in range(10)]
226226

227227
# List the orders and check they are present
228-
orders = await set_up["client"].list_gridpool_orders(
229-
gridpool_id=GRIDPOOL_ID, delivery_period=set_up["delivery_period"]
230-
) # filter by delivery period to avoid fetching too many orders
231-
228+
# filter by delivery period to avoid fetching too many orders
229+
orders = [
230+
order async for order in set_up["client"].list_gridpool_orders(
231+
gridpool_id=GRIDPOOL_ID,
232+
delivery_period=set_up["delivery_period"])
233+
]
232234
listed_orders_id = [order.order_id for order in orders]
233235
for order_id in created_orders_id:
234236
assert order_id in listed_orders_id, f"Order ID {order_id} not found"
@@ -327,7 +329,11 @@ async def test_cancel_all_orders(set_up: dict[str, Any]) -> None:
327329
# Cancel all orders and check that did indeed get cancelled
328330
await set_up["client"].cancel_all_gridpool_orders(GRIDPOOL_ID)
329331

330-
orders = await set_up["client"].list_gridpool_orders(gridpool_id=GRIDPOOL_ID)
332+
orders = [
333+
order async for order in set_up["client"].list_gridpool_orders(
334+
gridpool_id=GRIDPOOL_ID,
335+
)
336+
]
331337

332338
for order in orders:
333339
assert (
@@ -339,22 +345,32 @@ async def test_cancel_all_orders(set_up: dict[str, Any]) -> None:
339345
async def test_list_gridpool_trades(set_up: dict[str, Any]) -> None:
340346
"""Test listing gridpool trades."""
341347
buy_order, sell_order = await create_test_trade(set_up)
342-
trades = await set_up["client"].list_gridpool_trades(
343-
GRIDPOOL_ID,
344-
delivery_period=buy_order.order.delivery_period,
345-
)
348+
trades = [
349+
trade async for trade in set_up["client"].list_gridpool_trades(
350+
GRIDPOOL_ID,
351+
delivery_period=buy_order.order.delivery_period,
352+
)
353+
]
346354
assert len(trades) >= 1
347355

348356

349357
@pytest.mark.asyncio
350358
async def test_list_public_trades(set_up: dict[str, Any]) -> None:
351359
"""Test listing public trades."""
352-
public_trades = await set_up["client"].list_public_trades(
353-
delivery_period=set_up["delivery_period"],
354-
max_nr_trades=10,
355-
)
356-
assert len(public_trades) >= 0
360+
delivery_period = DeliveryPeriod(
361+
start=datetime.fromisoformat("2024-06-10T10:00:00+00:00"),
362+
duration=timedelta(minutes=15)
363+
)
364+
365+
public_trades = []
366+
counter = 0
367+
async for trade in set_up["client"].list_public_trades(delivery_period=delivery_period):
368+
public_trades.append(trade)
369+
counter += 1
370+
if counter == 10:
371+
break
357372

373+
assert len(public_trades) == 10, "Failed to retrieve 10 public trades"
358374

359375
@pytest.mark.asyncio
360376
async def test_stream_gridpool_orders(set_up: dict[str, Any]) -> None:

src/frequenz/client/electricity_trading/_client.py

Lines changed: 108 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import logging
1010
from datetime import datetime, timedelta, timezone
1111
from decimal import Decimal, InvalidOperation
12-
from typing import Any, Awaitable, Callable, cast
12+
from typing import Any, AsyncIterator, Awaitable, Callable, cast
1313

1414
import grpc
15+
from frequenz.api.common.v1.pagination.pagination_params_pb2 import PaginationParams
1516

1617
# pylint: disable=no-member
1718
from frequenz.api.electricity_trading.v1 import (
@@ -785,10 +786,9 @@ async def list_gridpool_orders(
785786
delivery_period: DeliveryPeriod | None = None,
786787
delivery_area: DeliveryArea | None = None,
787788
tag: str | None = None,
788-
max_nr_orders: int | None = None,
789-
page_token: str | None = None,
789+
page_size: int | None = None,
790790
timeout: timedelta | None = None,
791-
) -> list[OrderDetail]:
791+
) -> AsyncIterator[OrderDetail]:
792792
"""
793793
List orders for a specific Gridpool with optional filters.
794794
@@ -799,58 +799,57 @@ async def list_gridpool_orders(
799799
delivery_period: The delivery period to filter by.
800800
delivery_area: The delivery area to filter by.
801801
tag: The tag to filter by.
802-
max_nr_orders: The maximum number of orders to return.
803-
page_token: The page token to use for pagination.
802+
page_size: The number of orders to return per page.
804803
timeout: Timeout duration, defaults to None.
805804
806-
Returns:
807-
The list of orders for that gridpool.
805+
Yields:
806+
The list of orders for the given gridpool.
808807
809808
Raises:
810809
grpc.RpcError: If an error occurs while listing the orders.
811810
"""
812-
gridpool_order_filer = GridpoolOrderFilter(
811+
gridpool_order_filter = GridpoolOrderFilter(
813812
order_states=order_states,
814813
side=side,
815814
delivery_period=delivery_period,
816815
delivery_area=delivery_area,
817816
tag=tag,
818817
)
819818

820-
pagination_params = Params(
821-
page_size=max_nr_orders,
822-
page_token=page_token,
819+
request = electricity_trading_pb2.ListGridpoolOrdersRequest(
820+
gridpool_id=gridpool_id,
821+
filter=gridpool_order_filter.to_pb(),
822+
pagination_params=(
823+
Params(page_size=page_size).to_proto() if page_size else None
824+
),
823825
)
824-
825-
try:
826-
response = await cast(
827-
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
828-
grpc_call_with_timeout(
829-
self.stub.ListGridpoolOrders,
830-
electricity_trading_pb2.ListGridpoolOrdersRequest(
831-
gridpool_id=gridpool_id,
832-
filter=gridpool_order_filer.to_pb(),
833-
pagination_params=pagination_params.to_proto(),
826+
while True:
827+
try:
828+
response = await cast(
829+
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
830+
grpc_call_with_timeout(
831+
self.stub.ListGridpoolOrders,
832+
request,
833+
metadata=self._metadata,
834+
timeout=timeout,
834835
),
835-
metadata=self._metadata,
836-
timeout=timeout,
837-
),
838-
)
836+
)
839837

840-
orders: list[OrderDetail] = []
841-
for order_detail in response.order_details:
842-
try:
843-
orders.append(OrderDetail.from_pb(order_detail))
844-
except InvalidOperation:
845-
_logger.error(
846-
"Failed to convert order details for order: %s",
847-
str(order_detail).replace("\n", ""),
838+
for order_detail in response.order_details:
839+
yield OrderDetail.from_pb(order_detail)
840+
841+
if response.pagination_info.next_page_token:
842+
request.pagination_params.CopyFrom(
843+
PaginationParams(
844+
page_token=response.pagination_info.next_page_token
845+
)
848846
)
847+
else:
848+
break
849849

850-
return orders
851-
except grpc.RpcError as e:
852-
_logger.exception("Error occurred while listing gridpool orders: %s", e)
853-
raise
850+
except grpc.RpcError as e:
851+
_logger.exception("Error occurred while listing gridpool orders: %s", e)
852+
raise
854853

855854
async def list_gridpool_trades(
856855
# pylint: disable=too-many-arguments, too-many-positional-arguments
@@ -861,10 +860,9 @@ async def list_gridpool_trades(
861860
market_side: MarketSide | None = None,
862861
delivery_period: DeliveryPeriod | None = None,
863862
delivery_area: DeliveryArea | None = None,
864-
max_nr_trades: int | None = None,
865-
page_token: str | None = None,
863+
page_size: int | None = None,
866864
timeout: timedelta | None = None,
867-
) -> list[Trade]:
865+
) -> AsyncIterator[Trade]:
868866
"""
869867
List trades for a specific Gridpool with optional filters.
870868
@@ -875,11 +873,10 @@ async def list_gridpool_trades(
875873
market_side: The side of the market to filter by.
876874
delivery_period: The delivery period to filter by.
877875
delivery_area: The delivery area to filter by.
878-
max_nr_trades: The maximum number of trades to return.
879-
page_token: The page token to use for pagination.
876+
page_size: The number of trades to return per page.
880877
timeout: Timeout duration, defaults to None.
881878
882-
Returns:
879+
Yields:
883880
The list of trades for the given gridpool.
884881
885882
Raises:
@@ -893,30 +890,41 @@ async def list_gridpool_trades(
893890
delivery_area=delivery_area,
894891
)
895892

896-
pagination_params = Params(
897-
page_size=max_nr_trades,
898-
page_token=page_token,
893+
request = electricity_trading_pb2.ListGridpoolTradesRequest(
894+
gridpool_id=gridpool_id,
895+
filter=gridpool_trade_filter.to_pb(),
896+
pagination_params=(
897+
Params(page_size=page_size).to_proto() if page_size else None
898+
),
899899
)
900900

901-
try:
902-
response = await cast(
903-
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
904-
grpc_call_with_timeout(
905-
self.stub.ListGridpoolTrades,
906-
electricity_trading_pb2.ListGridpoolTradesRequest(
907-
gridpool_id=gridpool_id,
908-
filter=gridpool_trade_filter.to_pb(),
909-
pagination_params=pagination_params.to_proto(),
901+
while True:
902+
try:
903+
response = await cast(
904+
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
905+
grpc_call_with_timeout(
906+
self.stub.ListGridpoolTrades,
907+
request,
908+
metadata=self._metadata,
909+
timeout=timeout,
910910
),
911-
metadata=self._metadata,
912-
timeout=timeout,
913-
),
914-
)
911+
)
915912

916-
return [Trade.from_pb(trade) for trade in response.trades]
917-
except grpc.RpcError as e:
918-
_logger.exception("Error occurred while listing gridpool trades: %s", e)
919-
raise
913+
for trade in response.trades:
914+
yield Trade.from_pb(trade)
915+
916+
if response.pagination_info.next_page_token:
917+
request.pagination_params.CopyFrom(
918+
PaginationParams(
919+
page_token=response.pagination_info.next_page_token
920+
)
921+
)
922+
else:
923+
break
924+
925+
except grpc.RpcError as e:
926+
_logger.exception("Error occurred while listing gridpool trades: %s", e)
927+
raise
920928

921929
async def list_public_trades(
922930
# pylint: disable=too-many-arguments, too-many-positional-arguments
@@ -925,24 +933,22 @@ async def list_public_trades(
925933
delivery_period: DeliveryPeriod | None = None,
926934
buy_delivery_area: DeliveryArea | None = None,
927935
sell_delivery_area: DeliveryArea | None = None,
928-
max_nr_trades: int | None = None,
929-
page_token: str | None = None,
936+
page_size: int | None = None,
930937
timeout: timedelta | None = None,
931-
) -> list[PublicTrade]:
938+
) -> AsyncIterator[PublicTrade]:
932939
"""
933-
List all executed public orders with optional filters.
940+
List all executed public orders with optional filters and pagination.
934941
935942
Args:
936943
states: List of order states to filter by.
937944
delivery_period: The delivery period to filter by.
938945
buy_delivery_area: The buy delivery area to filter by.
939946
sell_delivery_area: The sell delivery area to filter by.
940-
max_nr_trades: The maximum number of trades to return.
941-
page_token: The page token to use for pagination.
947+
page_size: The number of public trades to return per page.
942948
timeout: Timeout duration, defaults to None.
943949
944-
Returns:
945-
The list of public trades.
950+
Yields:
951+
The list of public trades for each page.
946952
947953
Raises:
948954
grpc.RpcError: If an error occurs while listing public trades.
@@ -954,29 +960,37 @@ async def list_public_trades(
954960
sell_delivery_area=sell_delivery_area,
955961
)
956962

957-
pagination_params = Params(
958-
page_size=max_nr_trades,
959-
page_token=page_token,
963+
request = electricity_trading_pb2.ListPublicTradesRequest(
964+
filter=public_trade_filter.to_pb(),
965+
pagination_params=(
966+
Params(page_size=page_size).to_proto() if page_size else None
967+
),
960968
)
961969

962-
try:
963-
response = await cast(
964-
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
965-
grpc_call_with_timeout(
966-
self.stub.ListPublicTrades,
967-
electricity_trading_pb2.ListPublicTradesRequest(
968-
filter=public_trade_filter.to_pb(),
969-
pagination_params=pagination_params.to_proto(),
970+
while True:
971+
try:
972+
response = await cast(
973+
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
974+
grpc_call_with_timeout(
975+
self.stub.ListPublicTrades,
976+
request,
977+
metadata=self._metadata,
978+
timeout=timeout,
970979
),
971-
metadata=self._metadata,
972-
timeout=timeout,
973-
),
974-
)
980+
)
975981

976-
return [
977-
PublicTrade.from_pb(public_trade)
978-
for public_trade in response.public_trades
979-
]
980-
except grpc.RpcError as e:
981-
_logger.exception("Error occurred while listing public trades: %s", e)
982-
raise
982+
for public_trade in response.public_trades:
983+
yield PublicTrade.from_pb(public_trade)
984+
985+
if response.pagination_info.next_page_token:
986+
request.pagination_params.CopyFrom(
987+
PaginationParams(
988+
page_token=response.pagination_info.next_page_token
989+
)
990+
)
991+
else:
992+
break
993+
994+
except grpc.RpcError as e:
995+
_logger.exception("Error occurred while listing public trades: %s", e)
996+
raise

0 commit comments

Comments
 (0)