Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Extra validation check to ensure the quantity and price are within the allowed bounds.
* Add more edge cases to the integration tests.
* Add idiomatic string representations for `Power` and `Price` classes.
* Add support for timeouts in the gRPC function calls

## Bug Fixes

Expand Down
97 changes: 81 additions & 16 deletions src/frequenz/client/electricity_trading/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from __future__ import annotations

import asyncio
import logging
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from decimal import Decimal, InvalidOperation
from typing import Awaitable, cast
from typing import Any, Awaitable, Callable, cast

import grpc

Expand Down Expand Up @@ -91,6 +92,38 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
) from exc


async def grpc_call_with_timeout(
call: Callable[..., Awaitable[Any]],
*args: Any,
timeout: timedelta | None = None,
**kwargs: Any,
) -> Any:
"""
Call a gRPC function with a timeout (in seconds).

Args:
call: The gRPC method to be called.
*args: Positional arguments for the gRPC call.
timeout: Timeout duration, defaults to None.
**kwargs: Keyword arguments for the gRPC call.

Returns:
The result of the gRPC call.

Raises:
asyncio.TimeoutError: If the call exceeds the timeout.
"""
if timeout is None:
return await call(*args, **kwargs)
try:
return await asyncio.wait_for(
call(*args, **kwargs), timeout=timeout.total_seconds()
)
except asyncio.TimeoutError:
_logger.exception("Timeout while calling %s", call)
raise


class Client(BaseApiClient[ElectricityTradingServiceStub]):
"""Electricity trading client."""

Expand Down Expand Up @@ -447,6 +480,7 @@ async def create_gridpool_order(
valid_until: datetime | None = None,
payload: dict[str, struct_pb2.Value] | None = None,
tag: str | None = None,
timeout: timedelta | None = None,
) -> OrderDetail:
"""
Create a gridpool order.
Expand All @@ -466,6 +500,7 @@ async def create_gridpool_order(
valid_until: Valid until of the order.
payload: Payload of the order.
tag: Tag of the order.
timeout: Timeout duration, defaults to None.

Returns:
The created order.
Expand Down Expand Up @@ -503,12 +538,13 @@ async def create_gridpool_order(
try:
response = await cast(
Awaitable[electricity_trading_pb2.CreateGridpoolOrderResponse],
self.stub.CreateGridpoolOrder(
grpc_call_with_timeout(
self.stub.CreateGridpoolOrder,
electricity_trading_pb2.CreateGridpoolOrderRequest(
gridpool_id=gridpool_id,
order=order.to_pb(),
gridpool_id=gridpool_id, order=order.to_pb()
),
metadata=self._metadata,
timeout=timeout,
),
)
except grpc.RpcError as e:
Expand All @@ -531,6 +567,7 @@ async def update_gridpool_order(
valid_until: datetime | None | _Sentinel = NO_VALUE,
payload: dict[str, struct_pb2.Value] | None | _Sentinel = NO_VALUE,
tag: str | None | _Sentinel = NO_VALUE,
timeout: timedelta | None = None,
) -> OrderDetail:
"""
Update an existing order for a given Gridpool.
Expand All @@ -553,6 +590,7 @@ async def update_gridpool_order(
payload: Updated user-defined payload individual to a specific order. This can be any
data that the user wants to associate with the order.
tag: Updated user-defined tag to group related orders.
timeout: Timeout duration, defaults to None.

Returns:
The updated order.
Expand Down Expand Up @@ -615,14 +653,16 @@ async def update_gridpool_order(
try:
response = await cast(
Awaitable[electricity_trading_pb2.UpdateGridpoolOrderResponse],
self.stub.UpdateGridpoolOrder(
grpc_call_with_timeout(
self.stub.UpdateGridpoolOrder,
electricity_trading_pb2.UpdateGridpoolOrderRequest(
gridpool_id=gridpool_id,
order_id=order_id,
update_order_fields=update_order_fields.to_pb(),
update_mask=update_mask,
),
metadata=self._metadata,
timeout=timeout,
),
)
return OrderDetail.from_pb(response.order_detail)
Expand All @@ -632,14 +672,15 @@ async def update_gridpool_order(
raise

async def cancel_gridpool_order(
self, gridpool_id: int, order_id: int
self, gridpool_id: int, order_id: int, timeout: timedelta | None = None
) -> OrderDetail:
"""
Cancel a single order for a given Gridpool.

Args:
gridpool_id: The Gridpool to cancel the order for.
order_id: The order to cancel.
timeout: Timeout duration, defaults to None.

Returns:
The cancelled order.
Expand All @@ -650,24 +691,29 @@ async def cancel_gridpool_order(
try:
response = await cast(
Awaitable[electricity_trading_pb2.CancelGridpoolOrderResponse],
self.stub.CancelGridpoolOrder(
grpc_call_with_timeout(
self.stub.CancelGridpoolOrder,
electricity_trading_pb2.CancelGridpoolOrderRequest(
gridpool_id=gridpool_id, order_id=order_id
),
metadata=self._metadata,
timeout=timeout,
),
)
return OrderDetail.from_pb(response.order_detail)
except grpc.RpcError as e:
_logger.exception("Error occurred while cancelling gridpool order: %s", e)
raise

async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
async def cancel_all_gridpool_orders(
self, gridpool_id: int, timeout: timedelta | None = None
) -> int:
"""
Cancel all orders for a specific Gridpool.

Args:
gridpool_id: The Gridpool to cancel the orders for.
timeout: Timeout duration, defaults to None.

Returns:
The ID of the Gridpool for which the orders were cancelled.
Expand All @@ -678,11 +724,13 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
try:
response = await cast(
Awaitable[electricity_trading_pb2.CancelAllGridpoolOrdersResponse],
self.stub.CancelAllGridpoolOrders(
grpc_call_with_timeout(
self.stub.CancelAllGridpoolOrders,
electricity_trading_pb2.CancelAllGridpoolOrdersRequest(
gridpool_id=gridpool_id
),
metadata=self._metadata,
timeout=timeout,
),
)

Expand All @@ -693,13 +741,16 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
)
raise

async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDetail:
async def get_gridpool_order(
self, gridpool_id: int, order_id: int, timeout: timedelta | None = None
) -> OrderDetail:
"""
Get a single order from a given gridpool.

Args:
gridpool_id: The Gridpool to retrieve the order for.
order_id: The order to retrieve.
timeout: Timeout duration, defaults to None.

Returns:
The order.
Expand All @@ -710,11 +761,13 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta
try:
response = await cast(
Awaitable[electricity_trading_pb2.GetGridpoolOrderResponse],
self.stub.GetGridpoolOrder(
grpc_call_with_timeout(
self.stub.GetGridpoolOrder,
electricity_trading_pb2.GetGridpoolOrderRequest(
gridpool_id=gridpool_id, order_id=order_id
),
metadata=self._metadata,
timeout=timeout,
),
)

Expand All @@ -724,7 +777,7 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta
raise

async def list_gridpool_orders(
# pylint: disable=too-many-arguments, too-many-positional-arguments
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
self,
gridpool_id: int,
order_states: list[OrderState] | None = None,
Expand All @@ -734,6 +787,7 @@ async def list_gridpool_orders(
tag: str | None = None,
max_nr_orders: int | None = None,
page_token: str | None = None,
timeout: timedelta | None = None,
) -> list[OrderDetail]:
"""
List orders for a specific Gridpool with optional filters.
Expand All @@ -747,6 +801,7 @@ async def list_gridpool_orders(
tag: The tag to filter by.
max_nr_orders: The maximum number of orders to return.
page_token: The page token to use for pagination.
timeout: Timeout duration, defaults to None.

Returns:
The list of orders for that gridpool.
Expand All @@ -770,13 +825,15 @@ async def list_gridpool_orders(
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
self.stub.ListGridpoolOrders(
grpc_call_with_timeout(
self.stub.ListGridpoolOrders,
electricity_trading_pb2.ListGridpoolOrdersRequest(
gridpool_id=gridpool_id,
filter=gridpool_order_filer.to_pb(),
pagination_params=pagination_params.to_proto(),
),
metadata=self._metadata,
timeout=timeout,
),
)

Expand Down Expand Up @@ -806,6 +863,7 @@ async def list_gridpool_trades(
delivery_area: DeliveryArea | None = None,
max_nr_trades: int | None = None,
page_token: str | None = None,
timeout: timedelta | None = None,
) -> list[Trade]:
"""
List trades for a specific Gridpool with optional filters.
Expand All @@ -819,6 +877,7 @@ async def list_gridpool_trades(
delivery_area: The delivery area to filter by.
max_nr_trades: The maximum number of trades to return.
page_token: The page token to use for pagination.
timeout: Timeout duration, defaults to None.

Returns:
The list of trades for the given gridpool.
Expand All @@ -842,13 +901,15 @@ async def list_gridpool_trades(
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
self.stub.ListGridpoolTrades(
grpc_call_with_timeout(
self.stub.ListGridpoolTrades,
electricity_trading_pb2.ListGridpoolTradesRequest(
gridpool_id=gridpool_id,
filter=gridpool_trade_filter.to_pb(),
pagination_params=pagination_params.to_proto(),
),
metadata=self._metadata,
timeout=timeout,
),
)

Expand All @@ -866,6 +927,7 @@ async def list_public_trades(
sell_delivery_area: DeliveryArea | None = None,
max_nr_trades: int | None = None,
page_token: str | None = None,
timeout: timedelta | None = None,
) -> list[PublicTrade]:
"""
List all executed public orders with optional filters.
Expand All @@ -877,6 +939,7 @@ async def list_public_trades(
sell_delivery_area: The sell delivery area to filter by.
max_nr_trades: The maximum number of trades to return.
page_token: The page token to use for pagination.
timeout: Timeout duration, defaults to None.

Returns:
The list of public trades.
Expand All @@ -899,12 +962,14 @@ async def list_public_trades(
try:
response = await cast(
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
self.stub.ListPublicTrades(
grpc_call_with_timeout(
self.stub.ListPublicTrades,
electricity_trading_pb2.ListPublicTradesRequest(
filter=public_trade_filter.to_pb(),
pagination_params=pagination_params.to_proto(),
),
metadata=self._metadata,
timeout=timeout,
),
)

Expand Down
Loading