66import abc
77import inspect
88from collections .abc import Awaitable , Callable
9+ from types import EllipsisType
910from typing import Any , Generic , Self , TypeVar , overload
1011
1112from grpc .aio import (
1213 AioRpcError ,
1314 Channel ,
15+ ClientInterceptor ,
1416)
1517
18+ from .authentication import (
19+ AuthenticationInterceptorUnaryStream ,
20+ AuthenticationInterceptorUnaryUnary ,
21+ )
1622from .channel import ChannelOptions , parse_grpc_uri
1723from .exception import ApiClientError , ClientNotConnected
24+ from .signing import (
25+ SigningInterceptorUnaryStream ,
26+ SigningInterceptorUnaryUnary ,
27+ )
1828
1929StubT = TypeVar ("StubT" )
2030"""The type of the gRPC stub."""
@@ -153,13 +163,15 @@ async def main():
153163 instances.
154164 """
155165
156- def __init__ (
166+ def __init__ ( # pylint: disable=too-many-arguments
157167 self ,
158168 server_url : str ,
159169 create_stub : Callable [[Channel ], StubT ],
160170 * ,
161171 connect : bool = True ,
162172 channel_defaults : ChannelOptions = ChannelOptions (),
173+ auth_key : str | None = None ,
174+ sign_secret : str | None = None ,
163175 ) -> None :
164176 """Create an instance and connect to the server.
165177
@@ -172,14 +184,21 @@ def __init__(
172184 called.
173185 channel_defaults: The default options for the gRPC channel to create using
174186 the server URL.
187+ auth_key: The API key to use when connecting to the service.
188+ sign_secret: The secret to use when creating message HMAC.
189+
175190 """
176191 self ._server_url : str = server_url
177192 self ._create_stub : Callable [[Channel ], StubT ] = create_stub
178193 self ._channel_defaults : ChannelOptions = channel_defaults
194+ self ._auth_key = auth_key
195+ self ._sign_secret = sign_secret
179196 self ._channel : Channel | None = None
180197 self ._stub : StubT | None = None
181198 if connect :
182- self .connect (server_url )
199+ self .connect (
200+ server_url = self ._server_url , auth_key = auth_key , sign_secret = sign_secret
201+ )
183202
184203 @property
185204 def server_url (self ) -> str :
@@ -212,7 +231,13 @@ def is_connected(self) -> bool:
212231 """Whether the client is connected to the server."""
213232 return self ._channel is not None
214233
215- def connect (self , server_url : str | None = None ) -> None :
234+ def connect (
235+ self ,
236+ server_url : str | None = None ,
237+ * ,
238+ auth_key : str | None | EllipsisType = ...,
239+ sign_secret : str | None | EllipsisType = ...,
240+ ) -> None :
216241 """Connect to the server, possibly using a new URL.
217242
218243 If the client is already connected and the URL is the same as the previous URL,
@@ -222,12 +247,41 @@ def connect(self, server_url: str | None = None) -> None:
222247 Args:
223248 server_url: The URL of the server to connect to. If not provided, the
224249 previously used URL is used.
250+ auth_key: The API key to use when connecting to the service. If an Ellipsis
251+ is provided, the previously used auth_key is used.
252+ sign_secret: The secret to use when creating message HMAC. If an Ellipsis is
253+ provided,
225254 """
255+ reconnect = False
226256 if server_url is not None and server_url != self ._server_url : # URL changed
227257 self ._server_url = server_url
228- elif self .is_connected :
258+ reconnect = True
259+ if auth_key is not ... and auth_key != self ._auth_key :
260+ self ._auth_key = auth_key
261+ reconnect = True
262+ if sign_secret is not ... and sign_secret != self ._sign_secret :
263+ self ._sign_secret = sign_secret
264+ reconnect = True
265+ if self .is_connected and not reconnect : # Desired connection already exists
229266 return
230- self ._channel = parse_grpc_uri (self ._server_url , self ._channel_defaults )
267+
268+ interceptors : list [ClientInterceptor ] = []
269+ if self ._auth_key is not None :
270+ interceptors += [
271+ AuthenticationInterceptorUnaryUnary (self ._auth_key ), # type: ignore [list-item]
272+ AuthenticationInterceptorUnaryStream (self ._auth_key ), # type: ignore [list-item]
273+ ]
274+ if self ._sign_secret is not None :
275+ interceptors += [
276+ SigningInterceptorUnaryUnary (self ._sign_secret ), # type: ignore [list-item]
277+ SigningInterceptorUnaryStream (self ._sign_secret ), # type: ignore [list-item]
278+ ]
279+
280+ self ._channel = parse_grpc_uri (
281+ self ._server_url ,
282+ interceptors ,
283+ defaults = self ._channel_defaults ,
284+ )
231285 self ._stub = self ._create_stub (self ._channel )
232286
233287 async def disconnect (self ) -> None :
0 commit comments