Skip to content

Commit 59b0c59

Browse files
camille-bouvy-frequenzcambouvy
authored andcommitted
Add a timeout in the gRPC function calls
Signed-off-by: camille-bouvy-frequenz <[email protected]>
1 parent 5da75cf commit 59b0c59

File tree

1 file changed

+120
-14
lines changed

1 file changed

+120
-14
lines changed

src/frequenz/client/electricity_trading/_client.py

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33

44
"""Module to define the client class."""
55

6+
# pylint: disable=too-many-lines
7+
68
from __future__ import annotations
79

10+
import asyncio
811
import logging
912
from datetime import datetime, timezone
1013
from decimal import Decimal, InvalidOperation
11-
from typing import Awaitable, cast
14+
from typing import Any, Awaitable, Callable, cast
1215

1316
import grpc
1417

@@ -87,7 +90,9 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
8790
) from exc
8891

8992

90-
class Client(BaseApiClient[ElectricityTradingServiceStub]):
93+
class Client(
94+
BaseApiClient[ElectricityTradingServiceStub]
95+
): # pylint:disable=too-many-instance-attributes
9196
"""Electricity trading client."""
9297

9398
_instances: dict[tuple[str, str | None], "Client"] = {}
@@ -116,15 +121,28 @@ def __new__(
116121

117122
return cls._instances[key]
118123

119-
def __init__(
120-
self, server_url: str, connect: bool = True, auth_key: str | None = None
124+
def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
125+
self,
126+
server_url: str,
127+
connect: bool = True,
128+
auth_key: str | None = None,
129+
initial_timeout: int = 10,
130+
max_timeout: int = 300,
131+
timeout_increment: int = 20,
132+
max_timeout_retries: int = 30,
133+
reset_interval: int = 600,
121134
) -> None:
122135
"""Initialize the client.
123136
124137
Args:
125138
server_url: The URL of the Electricity Trading service.
126139
connect: Whether to connect to the server immediately.
127140
auth_key: The API key for the authorization.
141+
initial_timeout: Initial timeout duration for gRPC calls (in seconds).
142+
max_timeout: Maximum timeout duration for gRPC calls (in seconds).
143+
timeout_increment: Increment value for timeout on each retry (in seconds).
144+
max_timeout_retries: Maximum number of retry attempts when a timeout is reached.
145+
reset_interval: Time (in seconds) before resetting timeout if everything behaves fine.
128146
"""
129147
if not hasattr(
130148
self, "_initialized"
@@ -159,6 +177,87 @@ def __init__(
159177

160178
self._metadata = (("key", auth_key),) if auth_key else ()
161179

180+
# Timeout configuration
181+
self._initial_timeout = initial_timeout
182+
self._timeout = initial_timeout
183+
self._timeout_increment = timeout_increment
184+
self._max_timeout = max_timeout
185+
self._max_timeout_retries = max_timeout_retries
186+
self._reset_interval = reset_interval
187+
self._last_success_time = datetime.now(timezone.utc)
188+
189+
def _increase_timeout(self) -> None:
190+
"""Increase the global timeout within limits."""
191+
if self._timeout < self._max_timeout:
192+
self._timeout = min(
193+
self._max_timeout, self._timeout + self._timeout_increment
194+
)
195+
_logger.info("Increased timeout to %ds.", self._timeout)
196+
else:
197+
_logger.warning(
198+
"Timeout is already at the maximum (%ds). No further increments.",
199+
self._max_timeout,
200+
)
201+
202+
def _reset_timeout_if_needed(self) -> None:
203+
"""Reset the timeout to the initial value if the reset interval has passed."""
204+
now = datetime.now(timezone.utc)
205+
if (now - self._last_success_time).total_seconds() >= self._reset_interval:
206+
self._timeout = self._initial_timeout
207+
_logger.info(
208+
"Timeout reset to initial value %ds after %ds seconds without timeout errors.",
209+
self._initial_timeout,
210+
self._reset_interval,
211+
)
212+
213+
async def grpc_call_with_timeout(
214+
self,
215+
call: Callable[..., Awaitable[Any]],
216+
*args: Any,
217+
**kwargs: Any,
218+
) -> Any:
219+
"""
220+
Call a gRPC function with a timeout.
221+
222+
This timeout is increased dynamically on every none-successful reconnection attempts.
223+
224+
Args:
225+
call: The gRPC method to be called.
226+
*args: Positional arguments for the gRPC call.
227+
**kwargs: Keyword arguments for the gRPC call.
228+
229+
Returns:
230+
The result of the gRPC call.
231+
232+
Raises:
233+
asyncio.TimeoutError: If all retries are exhausted.
234+
"""
235+
for attempt in range(1, self._max_timeout_retries + 1):
236+
try:
237+
self._reset_timeout_if_needed()
238+
result = await asyncio.wait_for(
239+
call(*args, **kwargs), timeout=self._timeout
240+
)
241+
# Update last success time on successful call
242+
self._last_success_time = datetime.now(timezone.utc)
243+
return result
244+
except asyncio.TimeoutError:
245+
if attempt == self._max_timeout_retries:
246+
_logger.error(
247+
"Timeout after %d retries (timeout=%ds): %s",
248+
self._max_timeout_retries,
249+
self._timeout,
250+
call,
251+
)
252+
raise
253+
_logger.warning(
254+
"Timeout on attempt %d/%d (timeout=%ds). Retrying with increased timeout.",
255+
attempt,
256+
self._max_timeout_retries,
257+
self._timeout,
258+
)
259+
self._increase_timeout()
260+
162261
@property
163262
def stub(self) -> electricity_trading_pb2_grpc.ElectricityTradingServiceAsyncStub:
164263
"""
@@ -493,10 +592,10 @@ async def create_gridpool_order(
493592
try:
494593
response = await cast(
495594
Awaitable[electricity_trading_pb2.CreateGridpoolOrderResponse],
496-
self.stub.CreateGridpoolOrder(
595+
self.grpc_call_with_timeout(
596+
self.stub.CreateGridpoolOrder,
497597
electricity_trading_pb2.CreateGridpoolOrderRequest(
498-
gridpool_id=gridpool_id,
499-
order=order.to_pb(),
598+
gridpool_id=gridpool_id, order=order.to_pb()
500599
),
501600
metadata=self._metadata,
502601
),
@@ -605,7 +704,8 @@ async def update_gridpool_order(
605704
try:
606705
response = await cast(
607706
Awaitable[electricity_trading_pb2.UpdateGridpoolOrderResponse],
608-
self.stub.UpdateGridpoolOrder(
707+
self.grpc_call_with_timeout(
708+
self.stub.UpdateGridpoolOrder,
609709
electricity_trading_pb2.UpdateGridpoolOrderRequest(
610710
gridpool_id=gridpool_id,
611711
order_id=order_id,
@@ -640,7 +740,8 @@ async def cancel_gridpool_order(
640740
try:
641741
response = await cast(
642742
Awaitable[electricity_trading_pb2.CancelGridpoolOrderResponse],
643-
self.stub.CancelGridpoolOrder(
743+
self.grpc_call_with_timeout(
744+
self.stub.CancelGridpoolOrder,
644745
electricity_trading_pb2.CancelGridpoolOrderRequest(
645746
gridpool_id=gridpool_id, order_id=order_id
646747
),
@@ -668,7 +769,8 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
668769
try:
669770
response = await cast(
670771
Awaitable[electricity_trading_pb2.CancelAllGridpoolOrdersResponse],
671-
self.stub.CancelAllGridpoolOrders(
772+
self.grpc_call_with_timeout(
773+
self.stub.CancelAllGridpoolOrders,
672774
electricity_trading_pb2.CancelAllGridpoolOrdersRequest(
673775
gridpool_id=gridpool_id
674776
),
@@ -700,7 +802,8 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta
700802
try:
701803
response = await cast(
702804
Awaitable[electricity_trading_pb2.GetGridpoolOrderResponse],
703-
self.stub.GetGridpoolOrder(
805+
self.grpc_call_with_timeout(
806+
self.stub.GetGridpoolOrder,
704807
electricity_trading_pb2.GetGridpoolOrderRequest(
705808
gridpool_id=gridpool_id, order_id=order_id
706809
),
@@ -760,7 +863,8 @@ async def list_gridpool_orders(
760863
try:
761864
response = await cast(
762865
Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse],
763-
self.stub.ListGridpoolOrders(
866+
self.grpc_call_with_timeout(
867+
self.stub.ListGridpoolOrders,
764868
electricity_trading_pb2.ListGridpoolOrdersRequest(
765869
gridpool_id=gridpool_id,
766870
filter=gridpool_order_filer.to_pb(),
@@ -832,7 +936,8 @@ async def list_gridpool_trades(
832936
try:
833937
response = await cast(
834938
Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse],
835-
self.stub.ListGridpoolTrades(
939+
self.grpc_call_with_timeout(
940+
self.stub.ListGridpoolTrades,
836941
electricity_trading_pb2.ListGridpoolTradesRequest(
837942
gridpool_id=gridpool_id,
838943
filter=gridpool_trade_filter.to_pb(),
@@ -889,7 +994,8 @@ async def list_public_trades(
889994
try:
890995
response = await cast(
891996
Awaitable[electricity_trading_pb2.ListPublicTradesResponse],
892-
self.stub.ListPublicTrades(
997+
self.grpc_call_with_timeout(
998+
self.stub.ListPublicTrades,
893999
electricity_trading_pb2.ListPublicTradesRequest(
8941000
filter=public_trade_filter.to_pb(),
8951001
pagination_params=pagination_params.to_proto(),

0 commit comments

Comments
 (0)