Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Add idiomatic string representations for `Power` and `Price` classes.
* Add support for timeouts in the gRPC function calls
* Export Client constants for external use
* Fetch subsequent paginated data for `list_*` methods

## Bug Fixes

Expand Down
44 changes: 30 additions & 14 deletions integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,12 @@ async def test_list_gridpool_orders(set_up: dict[str, Any]) -> None:
created_orders_id = [(await create_test_order(set_up)).order_id for _ in range(10)]

# List the orders and check they are present
orders = await set_up["client"].list_gridpool_orders(
gridpool_id=GRIDPOOL_ID, delivery_period=set_up["delivery_period"]
) # filter by delivery period to avoid fetching too many orders

# filter by delivery period to avoid fetching too many orders
orders = [
order async for order in set_up["client"].list_gridpool_orders(
gridpool_id=GRIDPOOL_ID,
delivery_period=set_up["delivery_period"])
]
listed_orders_id = [order.order_id for order in orders]
for order_id in created_orders_id:
assert order_id in listed_orders_id, f"Order ID {order_id} not found"
Expand Down Expand Up @@ -327,7 +329,11 @@ async def test_cancel_all_orders(set_up: dict[str, Any]) -> None:
# Cancel all orders and check that did indeed get cancelled
await set_up["client"].cancel_all_gridpool_orders(GRIDPOOL_ID)

orders = await set_up["client"].list_gridpool_orders(gridpool_id=GRIDPOOL_ID)
orders = [
order async for order in set_up["client"].list_gridpool_orders(
gridpool_id=GRIDPOOL_ID,
)
]

for order in orders:
assert (
Expand All @@ -339,22 +345,32 @@ async def test_cancel_all_orders(set_up: dict[str, Any]) -> None:
async def test_list_gridpool_trades(set_up: dict[str, Any]) -> None:
"""Test listing gridpool trades."""
buy_order, sell_order = await create_test_trade(set_up)
trades = await set_up["client"].list_gridpool_trades(
GRIDPOOL_ID,
delivery_period=buy_order.order.delivery_period,
)
trades = [
trade async for trade in set_up["client"].list_gridpool_trades(
GRIDPOOL_ID,
delivery_period=buy_order.order.delivery_period,
)
]
assert len(trades) >= 1


@pytest.mark.asyncio
async def test_list_public_trades(set_up: dict[str, Any]) -> None:
"""Test listing public trades."""
public_trades = await set_up["client"].list_public_trades(
delivery_period=set_up["delivery_period"],
max_nr_trades=10,
)
assert len(public_trades) >= 0
delivery_period = DeliveryPeriod(
start=datetime.fromisoformat("2024-06-10T10:00:00+00:00"),
duration=timedelta(minutes=15)
)

public_trades = []
counter = 0
async for trade in set_up["client"].list_public_trades(delivery_period=delivery_period):
public_trades.append(trade)
counter += 1
if counter == 10:
break

assert len(public_trades) == 10, "Failed to retrieve 10 public trades"

@pytest.mark.asyncio
async def test_stream_gridpool_orders(set_up: dict[str, Any]) -> None:
Expand Down
202 changes: 108 additions & 94 deletions src/frequenz/client/electricity_trading/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import logging
from datetime import datetime, timedelta, timezone
from decimal import Decimal, InvalidOperation
from typing import Any, Awaitable, Callable, cast
from typing import Any, AsyncIterator, Awaitable, Callable, cast

import grpc
from frequenz.api.common.v1.pagination.pagination_params_pb2 import PaginationParams

# pylint: disable=no-member
from frequenz.api.electricity_trading.v1 import (
Expand Down Expand Up @@ -785,10 +786,9 @@ async def list_gridpool_orders(
delivery_period: DeliveryPeriod | None = None,
delivery_area: DeliveryArea | None = None,
tag: str | None = None,
max_nr_orders: int | None = None,
page_token: str | None = None,
page_size: int | None = None,
timeout: timedelta | None = None,
) -> list[OrderDetail]:
) -> AsyncIterator[OrderDetail]:
"""
List orders for a specific Gridpool with optional filters.

Expand All @@ -799,58 +799,57 @@ async def list_gridpool_orders(
delivery_period: The delivery period to filter by.
delivery_area: The delivery area to filter by.
tag: The tag to filter by.
max_nr_orders: The maximum number of orders to return.
page_token: The page token to use for pagination.
page_size: The number of orders to return per page.
timeout: Timeout duration, defaults to None.

Returns:
The list of orders for that gridpool.
Yields:
The list of orders for the given gridpool.

Raises:
grpc.RpcError: If an error occurs while listing the orders.
"""
gridpool_order_filer = GridpoolOrderFilter(
gridpool_order_filter = GridpoolOrderFilter(
order_states=order_states,
side=side,
delivery_period=delivery_period,
delivery_area=delivery_area,
tag=tag,
)

pagination_params = Params(
page_size=max_nr_orders,
page_token=page_token,
request = electricity_trading_pb2.ListGridpoolOrdersRequest(
gridpool_id=gridpool_id,
filter=gridpool_order_filter.to_pb(),
pagination_params=(
Params(page_size=page_size).to_proto() if page_size else None
),
)

try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
grpc_call_with_timeout(
self.stub.ListGridpoolOrders,
electricity_trading_pb2.ListGridpoolOrdersRequest(
gridpool_id=gridpool_id,
filter=gridpool_order_filer.to_pb(),
pagination_params=pagination_params.to_proto(),
while True:
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
grpc_call_with_timeout(
self.stub.ListGridpoolOrders,
request,
metadata=self._metadata,
timeout=timeout,
),
metadata=self._metadata,
timeout=timeout,
),
)
)

orders: list[OrderDetail] = []
for order_detail in response.order_details:
try:
orders.append(OrderDetail.from_pb(order_detail))
except InvalidOperation:
_logger.error(
"Failed to convert order details for order: %s",
str(order_detail).replace("\n", ""),
for order_detail in response.order_details:
yield OrderDetail.from_pb(order_detail)

if response.pagination_info.next_page_token:
request.pagination_params.CopyFrom(
PaginationParams(
page_token=response.pagination_info.next_page_token
)
)
else:
break

return orders
except grpc.RpcError as e:
_logger.exception("Error occurred while listing gridpool orders: %s", e)
raise
except grpc.RpcError as e:
_logger.exception("Error occurred while listing gridpool orders: %s", e)
raise

async def list_gridpool_trades(
# pylint: disable=too-many-arguments, too-many-positional-arguments
Expand All @@ -861,10 +860,9 @@ async def list_gridpool_trades(
market_side: MarketSide | None = None,
delivery_period: DeliveryPeriod | None = None,
delivery_area: DeliveryArea | None = None,
max_nr_trades: int | None = None,
page_token: str | None = None,
page_size: int | None = None,
timeout: timedelta | None = None,
) -> list[Trade]:
) -> AsyncIterator[Trade]:
"""
List trades for a specific Gridpool with optional filters.

Expand All @@ -875,11 +873,10 @@ async def list_gridpool_trades(
market_side: The side of the market to filter by.
delivery_period: The delivery period to filter by.
delivery_area: The delivery area to filter by.
max_nr_trades: The maximum number of trades to return.
page_token: The page token to use for pagination.
page_size: The number of trades to return per page.
timeout: Timeout duration, defaults to None.

Returns:
Yields:
The list of trades for the given gridpool.

Raises:
Expand All @@ -893,30 +890,41 @@ async def list_gridpool_trades(
delivery_area=delivery_area,
)

pagination_params = Params(
page_size=max_nr_trades,
page_token=page_token,
request = electricity_trading_pb2.ListGridpoolTradesRequest(
gridpool_id=gridpool_id,
filter=gridpool_trade_filter.to_pb(),
pagination_params=(
Params(page_size=page_size).to_proto() if page_size else None
),
)

try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
grpc_call_with_timeout(
self.stub.ListGridpoolTrades,
electricity_trading_pb2.ListGridpoolTradesRequest(
gridpool_id=gridpool_id,
filter=gridpool_trade_filter.to_pb(),
pagination_params=pagination_params.to_proto(),
while True:
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
grpc_call_with_timeout(
self.stub.ListGridpoolTrades,
request,
metadata=self._metadata,
timeout=timeout,
),
metadata=self._metadata,
timeout=timeout,
),
)
)

return [Trade.from_pb(trade) for trade in response.trades]
except grpc.RpcError as e:
_logger.exception("Error occurred while listing gridpool trades: %s", e)
raise
for trade in response.trades:
yield Trade.from_pb(trade)

if response.pagination_info.next_page_token:
request.pagination_params.CopyFrom(
PaginationParams(
page_token=response.pagination_info.next_page_token
)
)
else:
break

except grpc.RpcError as e:
_logger.exception("Error occurred while listing gridpool trades: %s", e)
raise

async def list_public_trades(
# pylint: disable=too-many-arguments, too-many-positional-arguments
Expand All @@ -925,24 +933,22 @@ async def list_public_trades(
delivery_period: DeliveryPeriod | None = None,
buy_delivery_area: DeliveryArea | None = None,
sell_delivery_area: DeliveryArea | None = None,
max_nr_trades: int | None = None,
page_token: str | None = None,
page_size: int | None = None,
timeout: timedelta | None = None,
) -> list[PublicTrade]:
) -> AsyncIterator[PublicTrade]:
"""
List all executed public orders with optional filters.
List all executed public orders with optional filters and pagination.

Args:
states: List of order states to filter by.
delivery_period: The delivery period to filter by.
buy_delivery_area: The buy delivery area to filter by.
sell_delivery_area: The sell delivery area to filter by.
max_nr_trades: The maximum number of trades to return.
page_token: The page token to use for pagination.
page_size: The number of public trades to return per page.
timeout: Timeout duration, defaults to None.

Returns:
The list of public trades.
Yields:
The list of public trades for each page.

Raises:
grpc.RpcError: If an error occurs while listing public trades.
Expand All @@ -954,29 +960,37 @@ async def list_public_trades(
sell_delivery_area=sell_delivery_area,
)

pagination_params = Params(
page_size=max_nr_trades,
page_token=page_token,
request = electricity_trading_pb2.ListPublicTradesRequest(
filter=public_trade_filter.to_pb(),
pagination_params=(
Params(page_size=page_size).to_proto() if page_size else None
),
)

try:
response = await cast(
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
grpc_call_with_timeout(
self.stub.ListPublicTrades,
electricity_trading_pb2.ListPublicTradesRequest(
filter=public_trade_filter.to_pb(),
pagination_params=pagination_params.to_proto(),
while True:
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
grpc_call_with_timeout(
self.stub.ListPublicTrades,
request,
metadata=self._metadata,
timeout=timeout,
),
metadata=self._metadata,
timeout=timeout,
),
)
)

return [
PublicTrade.from_pb(public_trade)
for public_trade in response.public_trades
]
except grpc.RpcError as e:
_logger.exception("Error occurred while listing public trades: %s", e)
raise
for public_trade in response.public_trades:
yield PublicTrade.from_pb(public_trade)

if response.pagination_info.next_page_token:
request.pagination_params.CopyFrom(
PaginationParams(
page_token=response.pagination_info.next_page_token
)
)
else:
break

except grpc.RpcError as e:
_logger.exception("Error occurred while listing public trades: %s", e)
raise
Loading
Loading