|
3 | 3 |
|
4 | 4 | """Module to define the client class.""" |
5 | 5 |
|
| 6 | +# pylint: disable=too-many-lines |
| 7 | + |
6 | 8 | from __future__ import annotations |
7 | 9 |
|
| 10 | +import asyncio |
8 | 11 | import logging |
9 | 12 | from datetime import datetime, timezone |
10 | 13 | from decimal import Decimal, InvalidOperation |
11 | | -from typing import Awaitable, cast |
| 14 | +from typing import Any, Awaitable, Callable, cast |
12 | 15 |
|
13 | 16 | import grpc |
14 | 17 |
|
@@ -87,7 +90,9 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N |
87 | 90 | ) from exc |
88 | 91 |
|
89 | 92 |
|
90 | | -class Client(BaseApiClient[ElectricityTradingServiceStub]): |
| 93 | +class Client( |
| 94 | + BaseApiClient[ElectricityTradingServiceStub] |
| 95 | +): # pylint:disable=too-many-instance-attributes |
91 | 96 | """Electricity trading client.""" |
92 | 97 |
|
93 | 98 | _instances: dict[tuple[str, str | None], "Client"] = {} |
@@ -116,15 +121,28 @@ def __new__( |
116 | 121 |
|
117 | 122 | return cls._instances[key] |
118 | 123 |
|
119 | | - def __init__( |
120 | | - self, server_url: str, connect: bool = True, auth_key: str | None = None |
| 124 | + def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments |
| 125 | + self, |
| 126 | + server_url: str, |
| 127 | + connect: bool = True, |
| 128 | + auth_key: str | None = None, |
| 129 | + initial_timeout: int = 10, |
| 130 | + max_timeout: int = 300, |
| 131 | + timeout_increment: int = 20, |
| 132 | + max_timeout_retries: int = 30, |
| 133 | + reset_interval: int = 600, |
121 | 134 | ) -> None: |
122 | 135 | """Initialize the client. |
123 | 136 |
|
124 | 137 | Args: |
125 | 138 | server_url: The URL of the Electricity Trading service. |
126 | 139 | connect: Whether to connect to the server immediately. |
127 | 140 | auth_key: The API key for the authorization. |
| 141 | + initial_timeout: Initial timeout duration for gRPC calls (in seconds). |
| 142 | + max_timeout: Maximum timeout duration for gRPC calls (in seconds). |
| 143 | + timeout_increment: Increment value for timeout on each retry (in seconds). |
| 144 | + max_timeout_retries: Maximum number of retry attempts when a timeout is reached. |
| 145 | + reset_interval: Time (in seconds) before resetting timeout if everything behaves fine. |
128 | 146 | """ |
129 | 147 | if not hasattr( |
130 | 148 | self, "_initialized" |
@@ -159,6 +177,87 @@ def __init__( |
159 | 177 |
|
160 | 178 | self._metadata = (("key", auth_key),) if auth_key else () |
161 | 179 |
|
| 180 | + # Timeout configuration |
| 181 | + self._initial_timeout = initial_timeout |
| 182 | + self._timeout = initial_timeout |
| 183 | + self._timeout_increment = timeout_increment |
| 184 | + self._max_timeout = max_timeout |
| 185 | + self._max_timeout_retries = max_timeout_retries |
| 186 | + self._reset_interval = reset_interval |
| 187 | + self._last_success_time = datetime.now(timezone.utc) |
| 188 | + |
| 189 | + def _increase_timeout(self) -> None: |
| 190 | + """Increase the global timeout within limits.""" |
| 191 | + if self._timeout < self._max_timeout: |
| 192 | + self._timeout = min( |
| 193 | + self._max_timeout, self._timeout + self._timeout_increment |
| 194 | + ) |
| 195 | + _logger.info("Increased timeout to %ds.", self._timeout) |
| 196 | + else: |
| 197 | + _logger.warning( |
| 198 | + "Timeout is already at the maximum (%ds). No further increments.", |
| 199 | + self._max_timeout, |
| 200 | + ) |
| 201 | + |
| 202 | + def _reset_timeout_if_needed(self) -> None: |
| 203 | + """Reset the timeout to the initial value if the reset interval has passed.""" |
| 204 | + now = datetime.now(timezone.utc) |
| 205 | + if (now - self._last_success_time).total_seconds() >= self._reset_interval: |
| 206 | + self._timeout = self._initial_timeout |
| 207 | + _logger.info( |
| 208 | + "Timeout reset to initial value %ds after %ds seconds without timeout errors.", |
| 209 | + self._initial_timeout, |
| 210 | + self._reset_interval, |
| 211 | + ) |
| 212 | + |
| 213 | + async def grpc_call_with_timeout( |
| 214 | + self, |
| 215 | + call: Callable[..., Awaitable[Any]], |
| 216 | + *args: Any, |
| 217 | + **kwargs: Any, |
| 218 | + ) -> Any: |
| 219 | + """ |
| 220 | + Call a gRPC function with a timeout. |
| 221 | +
|
| 222 | + This timeout is increased dynamically on every none-successful reconnection attempts. |
| 223 | +
|
| 224 | + Args: |
| 225 | + call: The gRPC method to be called. |
| 226 | + *args: Positional arguments for the gRPC call. |
| 227 | + **kwargs: Keyword arguments for the gRPC call. |
| 228 | +
|
| 229 | + Returns: |
| 230 | + The result of the gRPC call. |
| 231 | +
|
| 232 | + Raises: |
| 233 | + asyncio.TimeoutError: If all retries are exhausted. |
| 234 | + """ |
| 235 | + for attempt in range(1, self._max_timeout_retries + 1): |
| 236 | + try: |
| 237 | + self._reset_timeout_if_needed() |
| 238 | + result = await asyncio.wait_for( |
| 239 | + call(*args, **kwargs), timeout=self._timeout |
| 240 | + ) |
| 241 | + # Update last success time on successful call |
| 242 | + self._last_success_time = datetime.now(timezone.utc) |
| 243 | + return result |
| 244 | + except asyncio.TimeoutError: |
| 245 | + if attempt == self._max_timeout_retries: |
| 246 | + _logger.error( |
| 247 | + "Timeout after %d retries (timeout=%ds): %s", |
| 248 | + self._max_timeout_retries, |
| 249 | + self._timeout, |
| 250 | + call, |
| 251 | + ) |
| 252 | + raise |
| 253 | + _logger.warning( |
| 254 | + "Timeout on attempt %d/%d (timeout=%ds). Retrying with increased timeout.", |
| 255 | + attempt, |
| 256 | + self._max_timeout_retries, |
| 257 | + self._timeout, |
| 258 | + ) |
| 259 | + self._increase_timeout() |
| 260 | + |
162 | 261 | @property |
163 | 262 | def stub(self) -> electricity_trading_pb2_grpc.ElectricityTradingServiceAsyncStub: |
164 | 263 | """ |
@@ -493,10 +592,10 @@ async def create_gridpool_order( |
493 | 592 | try: |
494 | 593 | response = await cast( |
495 | 594 | Awaitable[electricity_trading_pb2.CreateGridpoolOrderResponse], |
496 | | - self.stub.CreateGridpoolOrder( |
| 595 | + self.grpc_call_with_timeout( |
| 596 | + self.stub.CreateGridpoolOrder, |
497 | 597 | electricity_trading_pb2.CreateGridpoolOrderRequest( |
498 | | - gridpool_id=gridpool_id, |
499 | | - order=order.to_pb(), |
| 598 | + gridpool_id=gridpool_id, order=order.to_pb() |
500 | 599 | ), |
501 | 600 | metadata=self._metadata, |
502 | 601 | ), |
@@ -605,7 +704,8 @@ async def update_gridpool_order( |
605 | 704 | try: |
606 | 705 | response = await cast( |
607 | 706 | Awaitable[electricity_trading_pb2.UpdateGridpoolOrderResponse], |
608 | | - self.stub.UpdateGridpoolOrder( |
| 707 | + self.grpc_call_with_timeout( |
| 708 | + self.stub.UpdateGridpoolOrder, |
609 | 709 | electricity_trading_pb2.UpdateGridpoolOrderRequest( |
610 | 710 | gridpool_id=gridpool_id, |
611 | 711 | order_id=order_id, |
@@ -640,7 +740,8 @@ async def cancel_gridpool_order( |
640 | 740 | try: |
641 | 741 | response = await cast( |
642 | 742 | Awaitable[electricity_trading_pb2.CancelGridpoolOrderResponse], |
643 | | - self.stub.CancelGridpoolOrder( |
| 743 | + self.grpc_call_with_timeout( |
| 744 | + self.stub.CancelGridpoolOrder, |
644 | 745 | electricity_trading_pb2.CancelGridpoolOrderRequest( |
645 | 746 | gridpool_id=gridpool_id, order_id=order_id |
646 | 747 | ), |
@@ -668,7 +769,8 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int: |
668 | 769 | try: |
669 | 770 | response = await cast( |
670 | 771 | Awaitable[electricity_trading_pb2.CancelAllGridpoolOrdersResponse], |
671 | | - self.stub.CancelAllGridpoolOrders( |
| 772 | + self.grpc_call_with_timeout( |
| 773 | + self.stub.CancelAllGridpoolOrders, |
672 | 774 | electricity_trading_pb2.CancelAllGridpoolOrdersRequest( |
673 | 775 | gridpool_id=gridpool_id |
674 | 776 | ), |
@@ -700,7 +802,8 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta |
700 | 802 | try: |
701 | 803 | response = await cast( |
702 | 804 | Awaitable[electricity_trading_pb2.GetGridpoolOrderResponse], |
703 | | - self.stub.GetGridpoolOrder( |
| 805 | + self.grpc_call_with_timeout( |
| 806 | + self.stub.GetGridpoolOrder, |
704 | 807 | electricity_trading_pb2.GetGridpoolOrderRequest( |
705 | 808 | gridpool_id=gridpool_id, order_id=order_id |
706 | 809 | ), |
@@ -760,7 +863,8 @@ async def list_gridpool_orders( |
760 | 863 | try: |
761 | 864 | response = await cast( |
762 | 865 | Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse], |
763 | | - self.stub.ListGridpoolOrders( |
| 866 | + self.grpc_call_with_timeout( |
| 867 | + self.stub.ListGridpoolOrders, |
764 | 868 | electricity_trading_pb2.ListGridpoolOrdersRequest( |
765 | 869 | gridpool_id=gridpool_id, |
766 | 870 | filter=gridpool_order_filer.to_pb(), |
@@ -832,7 +936,8 @@ async def list_gridpool_trades( |
832 | 936 | try: |
833 | 937 | response = await cast( |
834 | 938 | Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse], |
835 | | - self.stub.ListGridpoolTrades( |
| 939 | + self.grpc_call_with_timeout( |
| 940 | + self.stub.ListGridpoolTrades, |
836 | 941 | electricity_trading_pb2.ListGridpoolTradesRequest( |
837 | 942 | gridpool_id=gridpool_id, |
838 | 943 | filter=gridpool_trade_filter.to_pb(), |
@@ -889,7 +994,8 @@ async def list_public_trades( |
889 | 994 | try: |
890 | 995 | response = await cast( |
891 | 996 | Awaitable[electricity_trading_pb2.ListPublicTradesResponse], |
892 | | - self.stub.ListPublicTrades( |
| 997 | + self.grpc_call_with_timeout( |
| 998 | + self.stub.ListPublicTrades, |
893 | 999 | electricity_trading_pb2.ListPublicTradesRequest( |
894 | 1000 | filter=public_trade_filter.to_pb(), |
895 | 1001 | pagination_params=pagination_params.to_proto(), |
|
0 commit comments