33
44"""Module to define the client class."""
55
6+ from __future__ import annotations
7+
68import logging
79from datetime import datetime , timezone
810from decimal import Decimal , InvalidOperation
9- from typing import Awaitable , cast
11+ from typing import TYPE_CHECKING , Any , Awaitable , cast
1012
1113import grpc
1214
1719)
1820from frequenz .channels import Receiver
1921from frequenz .client .base .client import BaseApiClient
22+ from frequenz .client .base .exception import ClientNotConnected
2023from frequenz .client .base .streaming import GrpcStreamBroadcaster
2124from frequenz .client .common .pagination import Params
2225from google .protobuf import field_mask_pb2 , struct_pb2
26+ from typing_extensions import override
2327
2428from ._types import (
2529 DeliveryArea ,
4145 UpdateOrder ,
4246)
4347
48+ if TYPE_CHECKING :
49+ from frequenz .api .electricity_trading .v1 .electricity_trading_pb2_grpc import (
50+ ElectricityTradingServiceAsyncStub ,
51+ )
52+
53+
4454_logger = logging .getLogger (__name__ )
4555
4656
@@ -81,7 +91,7 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
8191 ) from exc
8292
8393
84- class Client (BaseApiClient [ ElectricityTradingServiceStub ] ):
94+ class Client (BaseApiClient ):
8595 """Electricity trading client."""
8696
8797 _instances : dict [tuple [str , str | None ], "Client" ] = {}
@@ -123,7 +133,10 @@ def __init__(
123133 if not hasattr (
124134 self , "_initialized"
125135 ): # Prevent re-initialization of existing instances
126- super ().__init__ (server_url , ElectricityTradingServiceStub , connect = connect )
136+ super ().__init__ (server_url , connect = connect )
137+ self ._stub : ElectricityTradingServiceAsyncStub | None = None
138+ if connect :
139+ self ._create_stub ()
127140 self ._initialized = True
128141
129142 self ._gridpool_orders_streams : dict [
@@ -149,6 +162,41 @@ def __init__(
149162
150163 self ._metadata = (("key" , auth_key ),) if auth_key else ()
151164
165+ def _create_stub (self ) -> None :
166+ """Create a new gRPC stub for the Electricity Trading service."""
167+ stub : Any = ElectricityTradingServiceStub (self .channel )
168+ self ._stub = stub
169+
170+ @override
171+ def connect (self , server_url : str | None = None ) -> None :
172+ """Connect to the server, possibly using a new URL.
173+
174+ If the client is already connected and the URL is the same as the previous URL,
175+ this method does nothing. If you want to force a reconnection, you can call
176+ [disconnect()][frequenz.client.base.client.BaseApiClient.disconnect] first.
177+
178+ Args:
179+ server_url: The URL of the server to connect to. If not provided, the
180+ previously used URL is used.
181+ """
182+ super ().connect (server_url )
183+ self ._create_stub ()
184+
185+ @property
186+ def stub (self ) -> ElectricityTradingServiceAsyncStub :
187+ """
188+ Get the gRPC stub for the Electricity Trading service.
189+
190+ Returns:
191+ The gRPC stub.
192+
193+ Raises:
194+ ClientNotConnected: If the client is not connected to the server.
195+ """
196+ if self ._stub is None :
197+ raise ClientNotConnected (server_url = self .server_url , operation = "stub" )
198+ return self ._stub
199+
152200 async def stream_gridpool_orders (
153201 # pylint: disable=too-many-arguments, too-many-positional-arguments
154202 self ,
@@ -192,7 +240,7 @@ async def stream_gridpool_orders(
192240 try :
193241 self ._gridpool_orders_streams [stream_key ] = GrpcStreamBroadcaster (
194242 f"electricity-trading-{ stream_key } " ,
195- lambda : self .stub .ReceiveGridpoolOrdersStream ( # type: ignore
243+ lambda : self .stub .ReceiveGridpoolOrdersStream (
196244 electricity_trading_pb2 .ReceiveGridpoolOrdersStreamRequest (
197245 gridpool_id = gridpool_id ,
198246 filter = gridpool_order_filter .to_pb (),
@@ -251,7 +299,7 @@ async def stream_gridpool_trades(
251299 try :
252300 self ._gridpool_trades_streams [stream_key ] = GrpcStreamBroadcaster (
253301 f"electricity-trading-{ stream_key } " ,
254- lambda : self .stub .ReceiveGridpoolTradesStream ( # type: ignore
302+ lambda : self .stub .ReceiveGridpoolTradesStream (
255303 electricity_trading_pb2 .ReceiveGridpoolTradesStreamRequest (
256304 gridpool_id = gridpool_id ,
257305 filter = gridpool_trade_filter .to_pb (),
@@ -303,7 +351,7 @@ async def stream_public_trades(
303351 self ._public_trades_streams [public_trade_filter ] = (
304352 GrpcStreamBroadcaster (
305353 f"electricity-trading-{ public_trade_filter } " ,
306- lambda : self .stub .ReceivePublicTradesStream ( # type: ignore
354+ lambda : self .stub .ReceivePublicTradesStream (
307355 electricity_trading_pb2 .ReceivePublicTradesStreamRequest (
308356 filter = public_trade_filter .to_pb (),
309357 ),
0 commit comments