99import secrets
1010import time
1111from base64 import urlsafe_b64encode
12- from typing import Any , Callable
12+ from typing import Any , AsyncIterable , Callable
1313
1414from grpc .aio import (
1515 ClientCallDetails ,
16+ UnaryStreamCall ,
17+ UnaryStreamClientInterceptor ,
1618 UnaryUnaryCall ,
1719 UnaryUnaryClientInterceptor ,
1820)
1921
2022_logger = logging .getLogger (__name__ )
2123
2224
25+ def _add_hmac (
26+ secret : bytes , client_call_details : ClientCallDetails , ts : int , nonce : bytes
27+ ) -> None :
28+ """Add the HMAC authentication to the metadata fields of the call details.
29+
30+ The extra headers are directly added to the client_call details.
31+
32+ Args:
33+ secret: The symmetric secret shared with the service.
34+ client_call_details: The call details.
35+ ts: The timestamp to use for the HMAC.
36+ nonce: The nonce to use for the HMAC.
37+ """
38+ if client_call_details .metadata is None :
39+ _logger .error (
40+ "No metadata found, cannot extract an api key. Therefore, cannot sign the request."
41+ )
42+ return
43+
44+ key : Any = client_call_details .metadata .get ("key" )
45+ if key is None :
46+ _logger .error ("No key found in metadata, cannot sign the request." )
47+ return
48+
49+ # Make into a base10 integer string and then encode to bytes
50+ # We can not use a raw bytes timestamp as the underlying network library
51+ # really hates zero bytes in the metadata values
52+ ts_bytes = str (ts ).encode ()
53+ nonce_bytes = urlsafe_b64encode (nonce )
54+
55+ hmac_obj = hmac .new (secret , digestmod = "sha256" )
56+ hmac_obj .update (key .encode ())
57+ hmac_obj .update (ts_bytes )
58+ hmac_obj .update (nonce_bytes )
59+
60+ # Once again, gRPC is mistyped
61+ hmac_obj .update (client_call_details .method .split (b"/" )[- 1 ]) # type: ignore[arg-type]
62+
63+ client_call_details .metadata ["ts" ] = ts_bytes
64+ client_call_details .metadata ["nonce" ] = nonce_bytes
65+ # By definition the signature is base64 encoded _without_ the padding, so we strip that
66+ client_call_details .metadata ["sig" ] = urlsafe_b64encode (hmac_obj .digest ()).strip (
67+ b"="
68+ )
69+
70+
2371@dataclasses .dataclass (frozen = True )
2472class SigningOptions :
2573 """Options for message signing of messages."""
@@ -28,8 +76,8 @@ class SigningOptions:
2876 """The secret to sign the message with."""
2977
3078
31- # There is an issue in gRPC that causes the type to be unspecifieable correctly here.
32- class SigningInterceptor (UnaryUnaryClientInterceptor ): # type: ignore[type-arg]
79+ # There is an issue in gRPC which means the type can not be specified correctly here.
80+ class SigningInterceptorUnaryUnary (UnaryUnaryClientInterceptor ): # type: ignore[type-arg]
3381 """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
3482
3583 def __init__ (self , options : SigningOptions ):
@@ -60,42 +108,51 @@ async def intercept_unary_unary(
60108 Returns:
61109 The response object (this implementation does not modify the response).
62110 """
63- self .add_hmac (
111+ _add_hmac (
112+ self ._secret ,
64113 client_call_details ,
65- int (time .time ()). to_bytes ( 8 , "big" ) ,
114+ int (time .time ()),
66115 secrets .token_bytes (16 ),
67116 )
68117 return await continuation (client_call_details , request )
69118
70- def add_hmac (
71- self , client_call_details : ClientCallDetails , ts : bytes , nonce : bytes
72- ) -> None :
73- """Add the HMAC authentication to the metadata fields of the call details.
74119
75- The extra headers are directly added to the client_call details.
120+ # There is an issue in gRPC which means the type can not be specified correctly here.
121+ class SigningInterceptorUnaryStream (UnaryStreamClientInterceptor ): # type: ignore[type-arg]
122+ """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call."""
123+
124+ def __init__ (self , options : SigningOptions ):
125+ """Create an instance of the interceptor.
76126
77127 Args:
128+ options: The options for signing the message.
129+ """
130+ self ._secret = options .secret .encode ()
131+
132+ async def intercept_unary_stream (
133+ self ,
134+ continuation : Callable [
135+ [ClientCallDetails , Any ], UnaryStreamCall [object , object ]
136+ ],
137+ client_call_details : ClientCallDetails ,
138+ request : Any ,
139+ ) -> AsyncIterable [object ] | UnaryStreamCall [object , object ]:
140+ """Intercept the call to add HMAC authentication to the metadata fields.
141+
142+ This is a known method from the base class that is overridden.
143+
144+ Args:
145+ continuation: The next interceptor in the chain.
78146 client_call_details: The call details.
79- ts: The timestamp to use for the HMAC.
80- nonce: The nonce to use for the HMAC.
147+ request: The request object.
148+
149+ Returns:
150+ The response object (this implementation does not modify the response).
81151 """
82- if client_call_details .metadata is None :
83- _logger .error (
84- "No metadata found, cannot extract an api key. Therefore, cannot sign the request."
85- )
86- return
87-
88- key : Any = client_call_details .metadata .get ("key" )
89- if key is None :
90- _logger .error ("No key found in metadata, cannot sign the request." )
91- return
92- hmac_obj = hmac .new (self ._secret , digestmod = "sha256" )
93- hmac_obj .update (key .encode ())
94- hmac_obj .update (ts )
95- hmac_obj .update (nonce )
96-
97- hmac_obj .update (client_call_details .method .encode ())
98-
99- client_call_details .metadata ["ts" ] = ts
100- client_call_details .metadata ["nonce" ] = nonce
101- client_call_details .metadata ["sig" ] = urlsafe_b64encode (hmac_obj .digest ())
152+ _add_hmac (
153+ self ._secret ,
154+ client_call_details ,
155+ int (time .time ()),
156+ secrets .token_bytes (16 ),
157+ )
158+ return await continuation (client_call_details , request ) # type: ignore
0 commit comments