88import logging
99from datetime import datetime , timezone
1010from decimal import Decimal , InvalidOperation
11- from typing import TYPE_CHECKING , Any , Awaitable , cast
11+ from typing import Awaitable , cast
1212
1313import grpc
1414
1515# pylint: disable=no-member
16- from frequenz .api .electricity_trading .v1 import electricity_trading_pb2
16+ from frequenz .api .electricity_trading .v1 import (
17+ electricity_trading_pb2 ,
18+ electricity_trading_pb2_grpc ,
19+ )
1720from frequenz .api .electricity_trading .v1 .electricity_trading_pb2_grpc import (
1821 ElectricityTradingServiceStub ,
1922)
2326from frequenz .client .base .streaming import GrpcStreamBroadcaster
2427from frequenz .client .common .pagination import Params
2528from google .protobuf import field_mask_pb2 , struct_pb2
26- from typing_extensions import override
2729
2830from ._types import (
2931 DeliveryArea ,
4547 UpdateOrder ,
4648)
4749
48- if TYPE_CHECKING :
49- from frequenz .api .electricity_trading .v1 .electricity_trading_pb2_grpc import (
50- ElectricityTradingServiceAsyncStub ,
51- )
52-
53-
5450_logger = logging .getLogger (__name__ )
5551
5652
@@ -91,7 +87,7 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
9187 ) from exc
9288
9389
94- class Client (BaseApiClient ):
90+ class Client (BaseApiClient [ ElectricityTradingServiceStub ] ):
9591 """Electricity trading client."""
9692
9793 _instances : dict [tuple [str , str | None ], "Client" ] = {}
@@ -133,10 +129,11 @@ def __init__(
133129 if not hasattr (
134130 self , "_initialized"
135131 ): # Prevent re-initialization of existing instances
136- super ().__init__ (server_url , connect = connect )
137- self ._stub : ElectricityTradingServiceAsyncStub | None = None
138- if connect :
139- self ._create_stub ()
132+ super ().__init__ (
133+ server_url ,
134+ connect = connect ,
135+ create_stub = ElectricityTradingServiceStub ,
136+ )
140137 self ._initialized = True
141138
142139 self ._gridpool_orders_streams : dict [
@@ -162,28 +159,8 @@ def __init__(
162159
163160 self ._metadata = (("key" , auth_key ),) if auth_key else ()
164161
165- def _create_stub (self ) -> None :
166- """Create a new gRPC stub for the Electricity Trading service."""
167- stub : Any = ElectricityTradingServiceStub (self .channel )
168- self ._stub = stub
169-
170- @override
171- def connect (self , server_url : str | None = None ) -> None :
172- """Connect to the server, possibly using a new URL.
173-
174- If the client is already connected and the URL is the same as the previous URL,
175- this method does nothing. If you want to force a reconnection, you can call
176- [disconnect()][frequenz.client.base.client.BaseApiClient.disconnect] first.
177-
178- Args:
179- server_url: The URL of the server to connect to. If not provided, the
180- previously used URL is used.
181- """
182- super ().connect (server_url )
183- self ._create_stub ()
184-
185162 @property
186- def stub (self ) -> ElectricityTradingServiceAsyncStub :
163+ def stub (self ) -> electricity_trading_pb2_grpc . ElectricityTradingServiceAsyncStub :
187164 """
188165 Get the gRPC stub for the Electricity Trading service.
189166
@@ -195,7 +172,11 @@ def stub(self) -> ElectricityTradingServiceAsyncStub:
195172 """
196173 if self ._stub is None :
197174 raise ClientNotConnected (server_url = self .server_url , operation = "stub" )
198- return self ._stub
175+ # This type: ignore is needed because we need to cast the sync stub to
176+ # the async stub, but we can't use cast because the async stub doesn't
177+ # actually exists to the eyes of the interpreter, it only exists for the
178+ # type-checker, so it can only be used for type hints.
179+ return self ._stub # type: ignore
199180
200181 async def stream_gridpool_orders (
201182 # pylint: disable=too-many-arguments, too-many-positional-arguments
0 commit comments