|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
| 8 | +import asyncio |
8 | 9 | import logging |
9 | 10 | from datetime import datetime, timezone |
10 | 11 | from decimal import Decimal, InvalidOperation |
11 | | -from typing import Awaitable, cast |
| 12 | +from typing import Any, Awaitable, Callable, cast |
12 | 13 |
|
13 | 14 | import grpc |
14 | 15 |
|
@@ -87,6 +88,31 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N |
87 | 88 | ) from exc |
88 | 89 |
|
89 | 90 |
|
| 91 | +async def grpc_call_with_timeout( |
| 92 | + call: Callable[..., Awaitable[Any]], *args: Any, timeout: float = 300, **kwargs: Any |
| 93 | +) -> Any: |
| 94 | + """ |
| 95 | + Call a gRPC function with a timeout (in seconds). |
| 96 | +
|
| 97 | + Args: |
| 98 | + call: The gRPC method to be called. |
| 99 | + *args: Positional arguments for the gRPC call. |
| 100 | + timeout: Timeout duration in seconds. Defaults to 300. |
| 101 | + **kwargs: Keyword arguments for the gRPC call. |
| 102 | +
|
| 103 | + Returns: |
| 104 | + The result of the gRPC call. |
| 105 | +
|
| 106 | + Raises: |
| 107 | + asyncio.TimeoutError: If the call exceeds the timeout. |
| 108 | + """ |
| 109 | + try: |
| 110 | + return await asyncio.wait_for(call(*args, **kwargs), timeout=timeout) |
| 111 | + except asyncio.TimeoutError: |
| 112 | + _logger.error("Timeout while calling %s", call) |
| 113 | + raise |
| 114 | + |
| 115 | + |
90 | 116 | class Client(BaseApiClient[ElectricityTradingServiceStub]): |
91 | 117 | """Electricity trading client.""" |
92 | 118 |
|
@@ -493,10 +519,10 @@ async def create_gridpool_order( |
493 | 519 | try: |
494 | 520 | response = await cast( |
495 | 521 | Awaitable[electricity_trading_pb2.CreateGridpoolOrderResponse], |
496 | | - self.stub.CreateGridpoolOrder( |
| 522 | + grpc_call_with_timeout( |
| 523 | + self.stub.CreateGridpoolOrder, |
497 | 524 | electricity_trading_pb2.CreateGridpoolOrderRequest( |
498 | | - gridpool_id=gridpool_id, |
499 | | - order=order.to_pb(), |
| 525 | + gridpool_id=gridpool_id, order=order.to_pb() |
500 | 526 | ), |
501 | 527 | metadata=self._metadata, |
502 | 528 | ), |
@@ -605,7 +631,8 @@ async def update_gridpool_order( |
605 | 631 | try: |
606 | 632 | response = await cast( |
607 | 633 | Awaitable[electricity_trading_pb2.UpdateGridpoolOrderResponse], |
608 | | - self.stub.UpdateGridpoolOrder( |
| 634 | + grpc_call_with_timeout( |
| 635 | + self.stub.UpdateGridpoolOrder, |
609 | 636 | electricity_trading_pb2.UpdateGridpoolOrderRequest( |
610 | 637 | gridpool_id=gridpool_id, |
611 | 638 | order_id=order_id, |
@@ -640,7 +667,8 @@ async def cancel_gridpool_order( |
640 | 667 | try: |
641 | 668 | response = await cast( |
642 | 669 | Awaitable[electricity_trading_pb2.CancelGridpoolOrderResponse], |
643 | | - self.stub.CancelGridpoolOrder( |
| 670 | + grpc_call_with_timeout( |
| 671 | + self.stub.CancelGridpoolOrder, |
644 | 672 | electricity_trading_pb2.CancelGridpoolOrderRequest( |
645 | 673 | gridpool_id=gridpool_id, order_id=order_id |
646 | 674 | ), |
@@ -668,7 +696,8 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int: |
668 | 696 | try: |
669 | 697 | response = await cast( |
670 | 698 | Awaitable[electricity_trading_pb2.CancelAllGridpoolOrdersResponse], |
671 | | - self.stub.CancelAllGridpoolOrders( |
| 699 | + grpc_call_with_timeout( |
| 700 | + self.stub.CancelAllGridpoolOrders, |
672 | 701 | electricity_trading_pb2.CancelAllGridpoolOrdersRequest( |
673 | 702 | gridpool_id=gridpool_id |
674 | 703 | ), |
@@ -700,7 +729,8 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta |
700 | 729 | try: |
701 | 730 | response = await cast( |
702 | 731 | Awaitable[electricity_trading_pb2.GetGridpoolOrderResponse], |
703 | | - self.stub.GetGridpoolOrder( |
| 732 | + grpc_call_with_timeout( |
| 733 | + self.stub.GetGridpoolOrder, |
704 | 734 | electricity_trading_pb2.GetGridpoolOrderRequest( |
705 | 735 | gridpool_id=gridpool_id, order_id=order_id |
706 | 736 | ), |
@@ -760,7 +790,8 @@ async def list_gridpool_orders( |
760 | 790 | try: |
761 | 791 | response = await cast( |
762 | 792 | Awaitable[electricity_trading_pb2.ListGridpoolOrdersResponse], |
763 | | - self.stub.ListGridpoolOrders( |
| 793 | + grpc_call_with_timeout( |
| 794 | + self.stub.ListGridpoolOrders, |
764 | 795 | electricity_trading_pb2.ListGridpoolOrdersRequest( |
765 | 796 | gridpool_id=gridpool_id, |
766 | 797 | filter=gridpool_order_filer.to_pb(), |
@@ -832,7 +863,8 @@ async def list_gridpool_trades( |
832 | 863 | try: |
833 | 864 | response = await cast( |
834 | 865 | Awaitable[electricity_trading_pb2.ListGridpoolTradesResponse], |
835 | | - self.stub.ListGridpoolTrades( |
| 866 | + grpc_call_with_timeout( |
| 867 | + self.stub.ListGridpoolTrades, |
836 | 868 | electricity_trading_pb2.ListGridpoolTradesRequest( |
837 | 869 | gridpool_id=gridpool_id, |
838 | 870 | filter=gridpool_trade_filter.to_pb(), |
@@ -889,7 +921,8 @@ async def list_public_trades( |
889 | 921 | try: |
890 | 922 | response = await cast( |
891 | 923 | Awaitable[electricity_trading_pb2.ListPublicTradesResponse], |
892 | | - self.stub.ListPublicTrades( |
| 924 | + grpc_call_with_timeout( |
| 925 | + self.stub.ListPublicTrades, |
893 | 926 | electricity_trading_pb2.ListPublicTradesRequest( |
894 | 927 | filter=public_trade_filter.to_pb(), |
895 | 928 | pagination_params=pagination_params.to_proto(), |
|
0 commit comments