55
66from __future__ import annotations
77
8+ import asyncio
89import logging
9- from datetime import datetime , timezone
10+ from datetime import datetime , timedelta , timezone
1011from decimal import Decimal , InvalidOperation
11- from typing import Awaitable , cast
12+ from typing import Any , Awaitable , Callable , cast
1213
1314import grpc
1415
@@ -91,6 +92,38 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
9192 ) from exc
9293
9394
95+ async def grpc_call_with_timeout (
96+ call : Callable [..., Awaitable [Any ]],
97+ * args : Any ,
98+ timeout : timedelta | None = None ,
99+ ** kwargs : Any ,
100+ ) -> Any :
101+ """
102+ Call a gRPC function with a timeout (in seconds).
103+
104+ Args:
105+ call: The gRPC method to be called.
106+ *args: Positional arguments for the gRPC call.
107+ timeout: Timeout duration, defaults to None.
108+ **kwargs: Keyword arguments for the gRPC call.
109+
110+ Returns:
111+ The result of the gRPC call.
112+
113+ Raises:
114+ asyncio.TimeoutError: If the call exceeds the timeout.
115+ """
116+ if timeout is None :
117+ return await call (* args , ** kwargs )
118+ try :
119+ return await asyncio .wait_for (
120+ call (* args , ** kwargs ), timeout = timeout .total_seconds ()
121+ )
122+ except asyncio .TimeoutError :
123+ _logger .exception ("Timeout while calling %s" , call )
124+ raise
125+
126+
94127class Client (BaseApiClient [ElectricityTradingServiceStub ]):
95128 """Electricity trading client."""
96129
@@ -447,6 +480,7 @@ async def create_gridpool_order(
447480 valid_until : datetime | None = None ,
448481 payload : dict [str , struct_pb2 .Value ] | None = None ,
449482 tag : str | None = None ,
483+ timeout : timedelta | None = None ,
450484 ) -> OrderDetail :
451485 """
452486 Create a gridpool order.
@@ -466,6 +500,7 @@ async def create_gridpool_order(
466500 valid_until: Valid until of the order.
467501 payload: Payload of the order.
468502 tag: Tag of the order.
503+ timeout: Timeout duration, defaults to None.
469504
470505 Returns:
471506 The created order.
@@ -503,12 +538,13 @@ async def create_gridpool_order(
503538 try :
504539 response = await cast (
505540 Awaitable [electricity_trading_pb2 .CreateGridpoolOrderResponse ],
506- self .stub .CreateGridpoolOrder (
541+ grpc_call_with_timeout (
542+ self .stub .CreateGridpoolOrder ,
507543 electricity_trading_pb2 .CreateGridpoolOrderRequest (
508- gridpool_id = gridpool_id ,
509- order = order .to_pb (),
544+ gridpool_id = gridpool_id , order = order .to_pb ()
510545 ),
511546 metadata = self ._metadata ,
547+ timeout = timeout ,
512548 ),
513549 )
514550 except grpc .RpcError as e :
@@ -531,6 +567,7 @@ async def update_gridpool_order(
531567 valid_until : datetime | None | _Sentinel = NO_VALUE ,
532568 payload : dict [str , struct_pb2 .Value ] | None | _Sentinel = NO_VALUE ,
533569 tag : str | None | _Sentinel = NO_VALUE ,
570+ timeout : timedelta | None = None ,
534571 ) -> OrderDetail :
535572 """
536573 Update an existing order for a given Gridpool.
@@ -553,6 +590,7 @@ async def update_gridpool_order(
553590 payload: Updated user-defined payload individual to a specific order. This can be any
554591 data that the user wants to associate with the order.
555592 tag: Updated user-defined tag to group related orders.
593+ timeout: Timeout duration, defaults to None.
556594
557595 Returns:
558596 The updated order.
@@ -615,14 +653,16 @@ async def update_gridpool_order(
615653 try :
616654 response = await cast (
617655 Awaitable [electricity_trading_pb2 .UpdateGridpoolOrderResponse ],
618- self .stub .UpdateGridpoolOrder (
656+ grpc_call_with_timeout (
657+ self .stub .UpdateGridpoolOrder ,
619658 electricity_trading_pb2 .UpdateGridpoolOrderRequest (
620659 gridpool_id = gridpool_id ,
621660 order_id = order_id ,
622661 update_order_fields = update_order_fields .to_pb (),
623662 update_mask = update_mask ,
624663 ),
625664 metadata = self ._metadata ,
665+ timeout = timeout ,
626666 ),
627667 )
628668 return OrderDetail .from_pb (response .order_detail )
@@ -632,14 +672,15 @@ async def update_gridpool_order(
632672 raise
633673
634674 async def cancel_gridpool_order (
635- self , gridpool_id : int , order_id : int
675+ self , gridpool_id : int , order_id : int , timeout : timedelta | None = None
636676 ) -> OrderDetail :
637677 """
638678 Cancel a single order for a given Gridpool.
639679
640680 Args:
641681 gridpool_id: The Gridpool to cancel the order for.
642682 order_id: The order to cancel.
683+ timeout: Timeout duration, defaults to None.
643684
644685 Returns:
645686 The cancelled order.
@@ -650,24 +691,29 @@ async def cancel_gridpool_order(
650691 try :
651692 response = await cast (
652693 Awaitable [electricity_trading_pb2 .CancelGridpoolOrderResponse ],
653- self .stub .CancelGridpoolOrder (
694+ grpc_call_with_timeout (
695+ self .stub .CancelGridpoolOrder ,
654696 electricity_trading_pb2 .CancelGridpoolOrderRequest (
655697 gridpool_id = gridpool_id , order_id = order_id
656698 ),
657699 metadata = self ._metadata ,
700+ timeout = timeout ,
658701 ),
659702 )
660703 return OrderDetail .from_pb (response .order_detail )
661704 except grpc .RpcError as e :
662705 _logger .exception ("Error occurred while cancelling gridpool order: %s" , e )
663706 raise
664707
665- async def cancel_all_gridpool_orders (self , gridpool_id : int ) -> int :
708+ async def cancel_all_gridpool_orders (
709+ self , gridpool_id : int , timeout : timedelta | None = None
710+ ) -> int :
666711 """
667712 Cancel all orders for a specific Gridpool.
668713
669714 Args:
670715 gridpool_id: The Gridpool to cancel the orders for.
716+ timeout: Timeout duration, defaults to None.
671717
672718 Returns:
673719 The ID of the Gridpool for which the orders were cancelled.
@@ -678,11 +724,13 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
678724 try :
679725 response = await cast (
680726 Awaitable [electricity_trading_pb2 .CancelAllGridpoolOrdersResponse ],
681- self .stub .CancelAllGridpoolOrders (
727+ grpc_call_with_timeout (
728+ self .stub .CancelAllGridpoolOrders ,
682729 electricity_trading_pb2 .CancelAllGridpoolOrdersRequest (
683730 gridpool_id = gridpool_id
684731 ),
685732 metadata = self ._metadata ,
733+ timeout = timeout ,
686734 ),
687735 )
688736
@@ -693,13 +741,16 @@ async def cancel_all_gridpool_orders(self, gridpool_id: int) -> int:
693741 )
694742 raise
695743
696- async def get_gridpool_order (self , gridpool_id : int , order_id : int ) -> OrderDetail :
744+ async def get_gridpool_order (
745+ self , gridpool_id : int , order_id : int , timeout : timedelta | None = None
746+ ) -> OrderDetail :
697747 """
698748 Get a single order from a given gridpool.
699749
700750 Args:
701751 gridpool_id: The Gridpool to retrieve the order for.
702752 order_id: The order to retrieve.
753+ timeout: Timeout duration, defaults to None.
703754
704755 Returns:
705756 The order.
@@ -710,11 +761,13 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta
710761 try :
711762 response = await cast (
712763 Awaitable [electricity_trading_pb2 .GetGridpoolOrderResponse ],
713- self .stub .GetGridpoolOrder (
764+ grpc_call_with_timeout (
765+ self .stub .GetGridpoolOrder ,
714766 electricity_trading_pb2 .GetGridpoolOrderRequest (
715767 gridpool_id = gridpool_id , order_id = order_id
716768 ),
717769 metadata = self ._metadata ,
770+ timeout = timeout ,
718771 ),
719772 )
720773
@@ -724,7 +777,7 @@ async def get_gridpool_order(self, gridpool_id: int, order_id: int) -> OrderDeta
724777 raise
725778
726779 async def list_gridpool_orders (
727- # pylint: disable=too-many-arguments, too-many-positional-arguments
780+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
728781 self ,
729782 gridpool_id : int ,
730783 order_states : list [OrderState ] | None = None ,
@@ -734,6 +787,7 @@ async def list_gridpool_orders(
734787 tag : str | None = None ,
735788 max_nr_orders : int | None = None ,
736789 page_token : str | None = None ,
790+ timeout : timedelta | None = None ,
737791 ) -> list [OrderDetail ]:
738792 """
739793 List orders for a specific Gridpool with optional filters.
@@ -747,6 +801,7 @@ async def list_gridpool_orders(
747801 tag: The tag to filter by.
748802 max_nr_orders: The maximum number of orders to return.
749803 page_token: The page token to use for pagination.
804+ timeout: Timeout duration, defaults to None.
750805
751806 Returns:
752807 The list of orders for that gridpool.
@@ -770,13 +825,15 @@ async def list_gridpool_orders(
770825 try :
771826 response = await cast (
772827 Awaitable [electricity_trading_pb2 .ListGridpoolOrdersResponse ],
773- self .stub .ListGridpoolOrders (
828+ grpc_call_with_timeout (
829+ self .stub .ListGridpoolOrders ,
774830 electricity_trading_pb2 .ListGridpoolOrdersRequest (
775831 gridpool_id = gridpool_id ,
776832 filter = gridpool_order_filer .to_pb (),
777833 pagination_params = pagination_params .to_proto (),
778834 ),
779835 metadata = self ._metadata ,
836+ timeout = timeout ,
780837 ),
781838 )
782839
@@ -806,6 +863,7 @@ async def list_gridpool_trades(
806863 delivery_area : DeliveryArea | None = None ,
807864 max_nr_trades : int | None = None ,
808865 page_token : str | None = None ,
866+ timeout : timedelta | None = None ,
809867 ) -> list [Trade ]:
810868 """
811869 List trades for a specific Gridpool with optional filters.
@@ -819,6 +877,7 @@ async def list_gridpool_trades(
819877 delivery_area: The delivery area to filter by.
820878 max_nr_trades: The maximum number of trades to return.
821879 page_token: The page token to use for pagination.
880+ timeout: Timeout duration, defaults to None.
822881
823882 Returns:
824883 The list of trades for the given gridpool.
@@ -842,13 +901,15 @@ async def list_gridpool_trades(
842901 try :
843902 response = await cast (
844903 Awaitable [electricity_trading_pb2 .ListGridpoolTradesResponse ],
845- self .stub .ListGridpoolTrades (
904+ grpc_call_with_timeout (
905+ self .stub .ListGridpoolTrades ,
846906 electricity_trading_pb2 .ListGridpoolTradesRequest (
847907 gridpool_id = gridpool_id ,
848908 filter = gridpool_trade_filter .to_pb (),
849909 pagination_params = pagination_params .to_proto (),
850910 ),
851911 metadata = self ._metadata ,
912+ timeout = timeout ,
852913 ),
853914 )
854915
@@ -866,6 +927,7 @@ async def list_public_trades(
866927 sell_delivery_area : DeliveryArea | None = None ,
867928 max_nr_trades : int | None = None ,
868929 page_token : str | None = None ,
930+ timeout : timedelta | None = None ,
869931 ) -> list [PublicTrade ]:
870932 """
871933 List all executed public orders with optional filters.
@@ -877,6 +939,7 @@ async def list_public_trades(
877939 sell_delivery_area: The sell delivery area to filter by.
878940 max_nr_trades: The maximum number of trades to return.
879941 page_token: The page token to use for pagination.
942+ timeout: Timeout duration, defaults to None.
880943
881944 Returns:
882945 The list of public trades.
@@ -899,12 +962,14 @@ async def list_public_trades(
899962 try :
900963 response = await cast (
901964 Awaitable [electricity_trading_pb2 .ListPublicTradesResponse ],
902- self .stub .ListPublicTrades (
965+ grpc_call_with_timeout (
966+ self .stub .ListPublicTrades ,
903967 electricity_trading_pb2 .ListPublicTradesRequest (
904968 filter = public_trade_filter .to_pb (),
905969 pagination_params = pagination_params .to_proto (),
906970 ),
907971 metadata = self ._metadata ,
972+ timeout = timeout ,
908973 ),
909974 )
910975
0 commit comments