|  | 
|  | 1 | +# License: MIT | 
|  | 2 | +# Copyright © 2025 Frequenz Energy-as-a-Service GmbH | 
|  | 3 | + | 
|  | 4 | +"""An Interceptor that adds HMAC signature of the metadata fields to a gRPC call.""" | 
|  | 5 | + | 
|  | 6 | +import dataclasses | 
|  | 7 | +import hmac | 
|  | 8 | +import logging | 
|  | 9 | +import secrets | 
|  | 10 | +import time | 
|  | 11 | +from base64 import urlsafe_b64encode | 
|  | 12 | +from typing import Any, Callable | 
|  | 13 | + | 
|  | 14 | +from grpc.aio import ( | 
|  | 15 | +    ClientCallDetails, | 
|  | 16 | +    UnaryUnaryCall, | 
|  | 17 | +    UnaryUnaryClientInterceptor, | 
|  | 18 | +) | 
|  | 19 | + | 
|  | 20 | +_logger = logging.getLogger(__name__) | 
|  | 21 | + | 
|  | 22 | + | 
|  | 23 | +@dataclasses.dataclass(frozen=True) | 
|  | 24 | +class SignOptions: | 
|  | 25 | +    """Options for message signing of messages.""" | 
|  | 26 | + | 
|  | 27 | +    secret: str | 
|  | 28 | +    """The secret to sign the message with.""" | 
|  | 29 | + | 
|  | 30 | + | 
|  | 31 | +class SignInterceptor(UnaryUnaryClientInterceptor):  # type: ignore[type-arg] | 
|  | 32 | +    """An Interceptor that adds HMAC authentication of the metadata fields to a gRPC call.""" | 
|  | 33 | + | 
|  | 34 | +    def __init__(self, *, sign_options: SignOptions): | 
|  | 35 | +        """Create an instance of the interceptor. | 
|  | 36 | +
 | 
|  | 37 | +        Args: | 
|  | 38 | +            sign_options: The options for signing the message. | 
|  | 39 | +        """ | 
|  | 40 | +        self._secret = sign_options.secret.encode() | 
|  | 41 | + | 
|  | 42 | +    async def intercept_unary_unary( | 
|  | 43 | +        self, | 
|  | 44 | +        continuation: Callable[ | 
|  | 45 | +            [ClientCallDetails, object], UnaryUnaryCall[object, object] | 
|  | 46 | +        ], | 
|  | 47 | +        client_call_details: ClientCallDetails, | 
|  | 48 | +        request: object, | 
|  | 49 | +    ) -> object: | 
|  | 50 | +        """Intercept the call to add HMAC authentication to the metadata fields. | 
|  | 51 | +
 | 
|  | 52 | +        This is a known method from the base class that is overridden. | 
|  | 53 | +
 | 
|  | 54 | +        Args: | 
|  | 55 | +            continuation: The next interceptor in the chain. | 
|  | 56 | +            client_call_details: The call details. | 
|  | 57 | +            request: The request object. | 
|  | 58 | +
 | 
|  | 59 | +        Returns: | 
|  | 60 | +            The response object (this implementation does not modify the response). | 
|  | 61 | +        """ | 
|  | 62 | +        self.add_hmac( | 
|  | 63 | +            client_call_details, | 
|  | 64 | +            int(time.time()).to_bytes(8, "big"), | 
|  | 65 | +            secrets.token_bytes(16), | 
|  | 66 | +        ) | 
|  | 67 | +        return await continuation(client_call_details, request) | 
|  | 68 | + | 
|  | 69 | +    def add_hmac( | 
|  | 70 | +        self, client_call_details: ClientCallDetails, ts: bytes, nonce: bytes | 
|  | 71 | +    ) -> None: | 
|  | 72 | +        """Add the HMAC authentication to the metadata fields of the call details. | 
|  | 73 | +
 | 
|  | 74 | +        The extra headers are directly added to the client_call details. | 
|  | 75 | +
 | 
|  | 76 | +        Args: | 
|  | 77 | +            client_call_details: The call details. | 
|  | 78 | +            ts: The timestamp to use for the HMAC. | 
|  | 79 | +            nonce: The nonce to use for the HMAC. | 
|  | 80 | +        """ | 
|  | 81 | +        if client_call_details.metadata is None: | 
|  | 82 | +            _logger.error( | 
|  | 83 | +                "No metadata found, cannot extract an api key. Therefore, cannot sign the request." | 
|  | 84 | +            ) | 
|  | 85 | +            return | 
|  | 86 | + | 
|  | 87 | +        key: Any = client_call_details.metadata.get("x-key") | 
|  | 88 | +        if key is None: | 
|  | 89 | +            _logger.error("No key found in metadata, cannot sign the request.") | 
|  | 90 | +            return | 
|  | 91 | +        hmac_obj = hmac.new(self._secret, digestmod="sha256") | 
|  | 92 | +        hmac_obj.update(key.encode()) | 
|  | 93 | +        hmac_obj.update(ts) | 
|  | 94 | +        hmac_obj.update(nonce) | 
|  | 95 | + | 
|  | 96 | +        hmac_obj.update(client_call_details.method.encode()) | 
|  | 97 | + | 
|  | 98 | +        client_call_details.metadata["x-ts"] = ts | 
|  | 99 | +        client_call_details.metadata["x-nonce"] = nonce | 
|  | 100 | +        client_call_details.metadata["x-hmac"] = urlsafe_b64encode(hmac_obj.digest()) | 
0 commit comments