diff --git a/pyproject.toml b/pyproject.toml index b31d1509..9be222ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,13 @@ module = [ "momento.internal.aio._scs_data_client", "momento.internal.aio._scs_grpc_manager", "momento.internal.aio._utilities", + "momento.internal.synchronous._utilities", "momento.responses.control.signing_key.*", + "momento.internal.aio._middleware_interceptor", + "momento.internal.synchronous._middleware_interceptor", + "momento.config.middleware.models", + "momento.config.middleware.aio.middleware_metadata", + "momento.config.middleware.synchronous.middleware_metadata", ] disallow_any_expr = false diff --git a/src/momento/config/configuration.py b/src/momento/config/configuration.py index 9e81c23c..68f7dfa1 100644 --- a/src/momento/config/configuration.py +++ b/src/momento/config/configuration.py @@ -3,14 +3,16 @@ from abc import ABC, abstractmethod from datetime import timedelta from pathlib import Path +from typing import List, Optional +import momento.config.middleware.aio from momento.retry import RetryStrategy +from .middleware import Middleware from .transport.transport_strategy import TransportStrategy class ConfigurationBase(ABC): - # TODO: Middlewares @abstractmethod def get_retry_strategy(self) -> RetryStrategy: pass @@ -35,20 +37,39 @@ def with_client_timeout(self, client_timeout: timedelta) -> Configuration: def with_root_certificates_pem(self, root_certificate_path: Path) -> Configuration: pass + @abstractmethod + def with_middlewares(self, middlewares: List[Middleware]) -> Configuration: + pass + + @abstractmethod + def add_middleware(self, middleware: Middleware) -> Configuration: + pass + + @abstractmethod + def get_middlewares(self) -> List[Middleware]: + pass + class Configuration(ConfigurationBase): """Configuration options for Momento Simple Cache Client.""" - def __init__(self, transport_strategy: TransportStrategy, retry_strategy: RetryStrategy): + def __init__( + self, + transport_strategy: TransportStrategy, + retry_strategy: RetryStrategy, + middlewares: Optional[List[Middleware]] = None, + ): """Instantiate a Configuration. Args: transport_strategy (TransportStrategy): Configuration options for networking with the Momento service. retry_strategy (RetryStrategy): the strategy to use when determining whether to retry a grpc call. + middlewares: Middleware that can intercept Momento calls. May be aio or synchronous. """ self._transport_strategy = transport_strategy self._retry_strategy = retry_strategy + self._middlewares: List[Middleware] = list(middlewares or []) def get_retry_strategy(self) -> RetryStrategy: """Access the retry strategy. @@ -67,7 +88,7 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> Configuration: Returns: Configuration: the new Configuration with the specified RetryStrategy. """ - return Configuration(self._transport_strategy, retry_strategy) + return Configuration(self._transport_strategy, retry_strategy, self._middlewares) def get_transport_strategy(self) -> TransportStrategy: """Access the transport strategy. @@ -86,7 +107,7 @@ def with_transport_strategy(self, transport_strategy: TransportStrategy) -> Conf Returns: Configuration: the new Configuration with the specified TransportStrategy. """ - return Configuration(transport_strategy, self._retry_strategy) + return Configuration(transport_strategy, self._retry_strategy, self._middlewares) def with_client_timeout(self, client_timeout: timedelta) -> Configuration: """Copies the Configuration and sets the new client-side timeout in the copy's TransportStrategy. @@ -97,7 +118,11 @@ def with_client_timeout(self, client_timeout: timedelta) -> Configuration: Return: Configuration: the new Configuration. """ - return Configuration(self._transport_strategy.with_client_timeout(client_timeout), self._retry_strategy) + return Configuration( + self._transport_strategy.with_client_timeout(client_timeout), + self._retry_strategy, + self._middlewares, + ) def with_root_certificates_pem(self, root_certificates_pem_path: Path) -> Configuration: """Copies the Configuration and sets the new root certificates in the copy's TransportStrategy. @@ -106,10 +131,60 @@ def with_root_certificates_pem(self, root_certificates_pem_path: Path) -> Config root_certificates_pem_path (Path): the new root certificates. Returns: - ConfigurationBase: the new Configuration. + Configuration: the new Configuration. """ grpc_configuration = self._transport_strategy.get_grpc_configuration().with_root_certificates_pem( root_certificates_pem_path ) transport_strategy = self._transport_strategy.with_grpc_configuration(grpc_configuration) return self.with_transport_strategy(transport_strategy) + + def with_middlewares(self, middlewares: List[Middleware]) -> Configuration: + """Copies the Configuration and adds the new middlewares to the end of the list. + + Args: + middlewares: the middleware list to be appended to the Configuration's existing middleware. These can be + aio or synchronous middleware. + + Returns: + Configuration: the new Configuration. + """ + new_middlewares = self._middlewares.copy() + middlewares + return Configuration(self._transport_strategy, self._retry_strategy, new_middlewares) + + def add_middleware(self, middleware: Middleware) -> Configuration: + """Copies the Configuration and adds the new middleware to the end of the list. + + Args: + middleware: the middleware to be appended to the Configuration's existing middleware. This can be aio or + synchronous middleware. + + Returns: + Configuration: the new Configuration. + """ + new_middlewares = self._middlewares.copy() + [middleware] + return Configuration(self._transport_strategy, self._retry_strategy, new_middlewares) + + def get_middlewares(self) -> List[Middleware]: + """Access the middleware list. + + Returns: + the configuration's list of middleware. + """ + return self._middlewares.copy() + + def get_aio_middlewares(self) -> List[momento.config.middleware.aio.Middleware]: + """Access the aio middleware from the middleware list. + + Returns: + the configuration's list of aio middleware. + """ + return [m for m in self._middlewares if isinstance(m, momento.config.middleware.aio.Middleware)] + + def get_sync_middlewares(self) -> List[momento.config.middleware.synchronous.Middleware]: + """Access the synchronous middleware from the middleware list. + + Returns: + the configuration's list of synchronous middleware. + """ + return [m for m in self._middlewares if isinstance(m, momento.config.middleware.synchronous.Middleware)] diff --git a/src/momento/config/middleware/__init__.py b/src/momento/config/middleware/__init__.py new file mode 100644 index 00000000..59d39fd4 --- /dev/null +++ b/src/momento/config/middleware/__init__.py @@ -0,0 +1,18 @@ +from typing import Union + +from momento.config.middleware.aio import Middleware as AsyncMiddleware +from momento.config.middleware.models import ( + MiddlewareMessage, + MiddlewareRequestHandlerContext, + MiddlewareStatus, +) +from momento.config.middleware.synchronous import Middleware as SyncMiddleware + +Middleware = Union[SyncMiddleware, AsyncMiddleware] + +__all__ = [ + "Middleware", + "MiddlewareMessage", + "MiddlewareStatus", + "MiddlewareRequestHandlerContext", +] diff --git a/src/momento/config/middleware/aio/__init__.py b/src/momento/config/middleware/aio/__init__.py new file mode 100644 index 00000000..cf517c37 --- /dev/null +++ b/src/momento/config/middleware/aio/__init__.py @@ -0,0 +1,4 @@ +from momento.config.middleware.aio.middleware import Middleware, MiddlewareRequestHandler +from momento.config.middleware.aio.middleware_metadata import MiddlewareMetadata + +__all__ = ["Middleware", "MiddlewareMetadata", "MiddlewareRequestHandler"] diff --git a/src/momento/config/middleware/aio/middleware.py b/src/momento/config/middleware/aio/middleware.py new file mode 100644 index 00000000..57c5d8b7 --- /dev/null +++ b/src/momento/config/middleware/aio/middleware.py @@ -0,0 +1,36 @@ +import abc + +from momento.config.middleware.aio.middleware_metadata import MiddlewareMetadata +from momento.config.middleware.models import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus + + +class MiddlewareRequestHandler(abc.ABC): + @abc.abstractmethod + async def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + pass + + @abc.abstractmethod + async def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage: + pass + + @abc.abstractmethod + async def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + pass + + @abc.abstractmethod + async def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage: + pass + + @abc.abstractmethod + async def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus: + pass + + +class Middleware(abc.ABC): + @abc.abstractmethod + async def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler: + pass + + # noinspection PyMethodMayBeStatic + def close(self) -> None: + return None diff --git a/src/momento/config/middleware/aio/middleware_metadata.py b/src/momento/config/middleware/aio/middleware_metadata.py new file mode 100644 index 00000000..6082f2f4 --- /dev/null +++ b/src/momento/config/middleware/aio/middleware_metadata.py @@ -0,0 +1,14 @@ +from typing import Optional + +from grpc.aio import Metadata + + +class MiddlewareMetadata: + """Wrapper for gRPC metadata.""" + + def __init__(self, metadata: Optional[Metadata]): + self.grpc_metadata = metadata + + def get_grpc_metadata(self) -> Optional[Metadata]: + """Get the underlying gRPC metadata.""" + return self.grpc_metadata diff --git a/src/momento/config/middleware/models.py b/src/momento/config/middleware/models.py new file mode 100644 index 00000000..f8fe558e --- /dev/null +++ b/src/momento/config/middleware/models.py @@ -0,0 +1,46 @@ +from typing import Dict + +import grpc +from google.protobuf.message import Message + +CONNECTION_ID_KEY = "connectionID" + + +class MiddlewareMessage: + """Wrapper for a gRPC protobuf message.""" + + def __init__(self, message: Message): + self.grpc_message = message + + def get_message_length(self) -> int: + """Get the length of the message in bytes.""" + return len(self.grpc_message.SerializeToString()) + + def get_constructor_name(self) -> str: + """Get the class name of the message.""" + return str(self.grpc_message.__class__.__name__) + + def get_message(self) -> Message: + """Get the underlying gRPC message.""" + return self.grpc_message + + +class MiddlewareStatus: + """Wrapper for gRPC status.""" + + def __init__(self, status: grpc.StatusCode): + self.grpc_status = status + + def get_code(self) -> grpc.StatusCode: + """Get the status code.""" + return self.grpc_status + + +class MiddlewareRequestHandlerContext: + """Context for middleware request handlers.""" + + def __init__(self, context: Dict[str, str]): + self.context = context + + def get_context(self) -> Dict[str, str]: + return self.context diff --git a/src/momento/config/middleware/synchronous/__init__.py b/src/momento/config/middleware/synchronous/__init__.py new file mode 100644 index 00000000..ace86c40 --- /dev/null +++ b/src/momento/config/middleware/synchronous/__init__.py @@ -0,0 +1,4 @@ +from momento.config.middleware.synchronous.middleware import Middleware, MiddlewareRequestHandler +from momento.config.middleware.synchronous.middleware_metadata import MiddlewareMetadata + +__all__ = ["Middleware", "MiddlewareMetadata", "MiddlewareRequestHandler"] diff --git a/src/momento/config/middleware/synchronous/middleware.py b/src/momento/config/middleware/synchronous/middleware.py new file mode 100644 index 00000000..86e9c983 --- /dev/null +++ b/src/momento/config/middleware/synchronous/middleware.py @@ -0,0 +1,36 @@ +import abc + +from momento.config.middleware.models import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus +from momento.config.middleware.synchronous.middleware_metadata import MiddlewareMetadata + + +class MiddlewareRequestHandler(abc.ABC): + @abc.abstractmethod + def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + pass + + @abc.abstractmethod + def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage: + pass + + @abc.abstractmethod + def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + pass + + @abc.abstractmethod + def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage: + pass + + @abc.abstractmethod + def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus: + pass + + +class Middleware(abc.ABC): + @abc.abstractmethod + def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler: + pass + + # noinspection PyMethodMayBeStatic + def close(self) -> None: + return None diff --git a/src/momento/config/middleware/synchronous/middleware_metadata.py b/src/momento/config/middleware/synchronous/middleware_metadata.py new file mode 100644 index 00000000..8e2a502f --- /dev/null +++ b/src/momento/config/middleware/synchronous/middleware_metadata.py @@ -0,0 +1,14 @@ +from typing import Optional + +from grpc._typing import MetadataType + + +class MiddlewareMetadata: + """Wrapper for gRPC metadata.""" + + def __init__(self, metadata: Optional[MetadataType]): + self.grpc_metadata = metadata + + def get_grpc_metadata(self) -> Optional[MetadataType]: + """Get the underlying gRPC metadata.""" + return self.grpc_metadata diff --git a/src/momento/internal/aio/_add_header_client_interceptor.py b/src/momento/internal/aio/_add_header_client_interceptor.py index 1bb2edfe..0430ec56 100644 --- a/src/momento/internal/aio/_add_header_client_interceptor.py +++ b/src/momento/internal/aio/_add_header_client_interceptor.py @@ -3,10 +3,8 @@ from typing import Callable import grpc -from grpc.aio import ClientCallDetails, Metadata -from momento.errors import InvalidArgumentException -from momento.internal.services import Service +from momento.internal.aio._utilities import sanitize_client_call_details class Header: @@ -81,55 +79,3 @@ async def intercept_unary_unary( AddHeaderClientInterceptor.are_only_once_headers_sent = True return await continuation(new_client_call_details, request) - - -def sanitize_client_call_details(client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails: - """Defensive function meant to handle inbound gRPC client request objects. - - Args: - client_call_details: the original inbound client grpc request we are intercepting - - Returns: a new client_call_details object with metadata properly initialized to a `grpc.aio.Metadata` object - """ - # Makes sure we can handle properly when we inject our own metadata onto request object. - # This was mainly done as temporary fix after we observed ddtrace grpc client interceptor passing - # client_call_details.metadata as a list instead of a grpc.aio.Metadata object. - # See this ticket for follow-up actions to come back in and address this longer term: - # https://github.com/momentohq/client-sdk-python/issues/149 - new_client_call_details = None - # If no metadata set on passed in client call details then we are first to set, so we should just initialize - if client_call_details.metadata is None: - new_client_call_details = ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=Metadata(), - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - ) - - # This is block hit when ddtrace interceptor runs first and sets metadata as a list - elif isinstance(client_call_details.metadata, list): - existing_headers = client_call_details.metadata - metadata = Metadata() - # re-add all existing values to new metadata - for md_key, md_value in existing_headers: - metadata.add(md_key, md_value) - new_client_call_details = ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - ) - elif isinstance(client_call_details.metadata, grpc.aio.Metadata): - # If proper grpc `grpc.aio.Metadata()` object is passed just use original object passed and pass back - new_client_call_details = client_call_details - else: - # Else we raise exception for now since we don't know how to handle an unknown type - raise InvalidArgumentException( - "unexpected grpc client request metadata property passed to interceptor " - "type=" + str(type(client_call_details.metadata)), - Service.AUTH, - ) - - return new_client_call_details diff --git a/src/momento/internal/aio/_middleware_interceptor.py b/src/momento/internal/aio/_middleware_interceptor.py new file mode 100644 index 00000000..bbc06682 --- /dev/null +++ b/src/momento/internal/aio/_middleware_interceptor.py @@ -0,0 +1,162 @@ +import asyncio +from types import MethodType +from typing import Awaitable, Callable, List, Optional, TypeVar, Union, cast + +import grpc +from google.protobuf.message import Message +from grpc.aio import ClientCallDetails, Metadata +from grpc.aio._call import UnaryUnaryCall +from grpc.aio._typing import RequestType, ResponseType + +from momento import logs +from momento.config.middleware import ( + MiddlewareMessage, + MiddlewareRequestHandlerContext, + MiddlewareStatus, +) +from momento.config.middleware.aio import Middleware, MiddlewareMetadata, MiddlewareRequestHandler +from momento.internal.aio._utilities import create_client_call_details, sanitize_client_call_details + +T = TypeVar("T") + + +class _ProcessedResponseCall(UnaryUnaryCall): + # noinspection PyMissingConstructor + def __init__( + self, + call: UnaryUnaryCall, + status_code: grpc.StatusCode = None, + processed_response: Optional[T] = None, + initial_metadata: Optional[Metadata] = None, + error: Optional[grpc.RpcError] = None, + ) -> None: + self._call = call + self._initial_metadata = initial_metadata + self._status_code = status_code + self._error = error + + # Create a future for the processed response + self._response_future = asyncio.get_event_loop().create_future() + if error is not None: + self._response_future.set_exception(error) + elif processed_response is not None: + self._response_future.set_result(processed_response) + + async def initial_metadata(self) -> Metadata: + if self._initial_metadata is not None: + return self._initial_metadata + return await self._call.initial_metadata() + + async def trailing_metadata(self) -> Metadata: + return await self._call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return self._status_code + + async def details(self) -> str: + return await self._call.details() # type: ignore[no-any-return] + + def cancelled(self) -> bool: + return self._call.cancelled() # type: ignore[no-any-return] + + def done(self) -> bool: + return True + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() # type: ignore[no-any-return] + + def cancel(self) -> bool: + return False + + def add_done_callback(self, callback) -> None: # type: ignore[no-untyped-def] + callback(self) + + def __await__(self): # type: ignore[no-untyped-def] + return self._response_future.__await__() + + +class MiddlewareInterceptor(grpc.aio.UnaryUnaryClientInterceptor): + middlewares: List[Middleware] + context: MiddlewareRequestHandlerContext + + def __init__(self, middlewares: List[Middleware], context: MiddlewareRequestHandlerContext) -> None: + self._logger = logs.logger + self.middlewares = middlewares + self.context = context + + async def apply_handler_methods(self, methods: List[Callable[[T], Awaitable[T]]], original_input: T) -> T: + current_value = original_input + + for method in methods: + try: + current_value = await method(current_value) + except Exception as e: + bound_method = cast(MethodType, method) + handler_info = f"{bound_method.__self__.__class__.__name__}.{method.__name__}" + self._logger.exception(f"Error in middleware method {handler_info}: {str(e)}") + + return current_value + + async def intercept_unary_unary( + self, + continuation: Callable[[ClientCallDetails, RequestType], UnaryUnaryCall], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> Union[UnaryUnaryCall, ResponseType]: + client_call_details = sanitize_client_call_details(client_call_details) + + handlers: List[MiddlewareRequestHandler] = [] + for middleware in self.middlewares: + handler = await middleware.on_new_request(self.context) + handlers.append(handler) + reversed_handlers = handlers[::-1] + + metadata = await self.apply_handler_methods( + [handler.on_request_metadata for handler in handlers], MiddlewareMetadata(client_call_details.metadata) + ) + + new_client_call_details = create_client_call_details( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata.get_grpc_metadata(), + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + ) + + if isinstance(request, Message): + middleware_message = await self.apply_handler_methods( + [handler.on_request_body for handler in handlers], MiddlewareMessage(request) + ) + request = middleware_message.get_message() + + call = await continuation(new_client_call_details, request) + try: + initial_metadata = await call.initial_metadata() + response_metadata = await self.apply_handler_methods( + [handler.on_response_metadata for handler in reversed_handlers], MiddlewareMetadata(initial_metadata) + ) + initial_metadata = response_metadata.get_grpc_metadata() + + # if the call returns an error, awaiting it will raise an RpcError, which we handle below + original_response = await call + + if isinstance(original_response, Message): + middleware_response = await self.apply_handler_methods( + [handler.on_response_body for handler in reversed_handlers], MiddlewareMessage(original_response) + ) + response = middleware_response.get_message() + else: + response = original_response + + status_code = await call.code() + middleware_status = await self.apply_handler_methods( + [handler.on_response_status for handler in reversed_handlers], MiddlewareStatus(status_code) + ) + status_code = middleware_status.grpc_status + + return _ProcessedResponseCall(call, status_code, response, initial_metadata) + except grpc.RpcError as e: + status = MiddlewareStatus(e.code()) + await self.apply_handler_methods([handler.on_response_status for handler in reversed_handlers], status) + + return _ProcessedResponseCall(call, e.code(), error=e) diff --git a/src/momento/internal/aio/_scs_grpc_manager.py b/src/momento/internal/aio/_scs_grpc_manager.py index 8a73ec61..082c36bb 100644 --- a/src/momento/internal/aio/_scs_grpc_manager.py +++ b/src/momento/internal/aio/_scs_grpc_manager.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio -from typing import Optional +import uuid +from typing import List, Optional import grpc from momento_wire_types import cacheclient_pb2_grpc as cache_client @@ -23,14 +24,18 @@ grpc_topic_channel_options_from_grpc_config, ) from momento.internal.services import Service -from momento.retry import RetryStrategy from ... import logs +from ...config.middleware import MiddlewareRequestHandlerContext +from ...config.middleware.aio import Middleware +from ...config.middleware.models import CONNECTION_ID_KEY +from ...retry import RetryStrategy from ._add_header_client_interceptor import ( AddHeaderClientInterceptor, AddHeaderStreamingClientInterceptor, Header, ) +from ._middleware_interceptor import MiddlewareInterceptor from ._retry_interceptor import RetryInterceptor @@ -43,7 +48,10 @@ def __init__(self, configuration: Configuration, credential_provider: Credential target=credential_provider.control_endpoint, credentials=channel_credentials_from_root_certs_or_default(configuration), interceptors=_interceptors( - credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy() + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_aio_middlewares(), + configuration.get_retry_strategy(), ), options=grpc_control_channel_options_from_grpc_config( grpc_config=configuration.get_transport_strategy().get_grpc_configuration(), @@ -53,7 +61,10 @@ def __init__(self, configuration: Configuration, credential_provider: Credential self._channel = grpc.aio.insecure_channel( target=f"{credential_provider.control_endpoint}:{credential_provider.port}", interceptors=_interceptors( - credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy() + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_aio_middlewares(), + configuration.get_retry_strategy(), ), options=grpc_control_channel_options_from_grpc_config( grpc_config=configuration.get_transport_strategy().get_grpc_configuration(), @@ -77,7 +88,10 @@ def __init__(self, configuration: Configuration, credential_provider: Credential target=credential_provider.cache_endpoint, credentials=channel_credentials_from_root_certs_or_default(configuration), interceptors=_interceptors( - credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy() + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_aio_middlewares(), + configuration.get_retry_strategy(), ), # Here is where you would pass override configuration to the underlying C gRPC layer. # However, I have tried several different tuning options here and did not see any @@ -101,7 +115,10 @@ def __init__(self, configuration: Configuration, credential_provider: Credential self._channel = grpc.aio.insecure_channel( target=f"{credential_provider.cache_endpoint}:{credential_provider.port}", interceptors=_interceptors( - credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy() + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_aio_middlewares(), + configuration.get_retry_strategy(), ), options=grpc_data_channel_options_from_grpc_config( configuration.get_transport_strategy().get_grpc_configuration() @@ -115,7 +132,7 @@ async def eagerly_connect(self, timeout_seconds: float) -> None: try: await asyncio.wait_for(self.wait_for_ready(), timeout_seconds) except Exception as error: - self._channel.close() + await self._channel.close() self._logger.debug(f"Failed to connect to the server within the given timeout. {error}") raise ConnectionException( message=f"Failed to connect to Momento's server within given eager connection timeout: {error}", @@ -161,7 +178,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede self._channel = grpc.aio.secure_channel( target=credential_provider.cache_endpoint, credentials=grpc.ssl_channel_credentials(), - interceptors=_interceptors(credential_provider.auth_token, ClientType.TOPIC, None), + interceptors=_interceptors(credential_provider.auth_token, ClientType.TOPIC, [], None), options=grpc_topic_channel_options_from_grpc_config( configuration.get_transport_strategy().get_grpc_configuration() ), @@ -169,7 +186,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede else: self._channel = grpc.aio.insecure_channel( target=f"{credential_provider.cache_endpoint}:{credential_provider.port}", - interceptors=_interceptors(credential_provider.auth_token, ClientType.TOPIC, None), + interceptors=_interceptors(credential_provider.auth_token, ClientType.TOPIC, [], None), options=grpc_topic_channel_options_from_grpc_config( configuration.get_transport_strategy().get_grpc_configuration() ), @@ -220,7 +237,7 @@ def __init__(self, configuration: AuthConfiguration, credential_provider: Creden target=credential_provider.token_endpoint, credentials=grpc.ssl_channel_credentials(), interceptors=_interceptors( - credential_provider.auth_token, ClientType.TOKEN, configuration.get_retry_strategy() + credential_provider.auth_token, ClientType.TOKEN, [], configuration.get_retry_strategy() ), options=grpc_control_channel_options_from_grpc_config( grpc_config=configuration.get_transport_strategy().get_grpc_configuration(), @@ -230,7 +247,7 @@ def __init__(self, configuration: AuthConfiguration, credential_provider: Creden self._channel = grpc.aio.insecure_channel( target=f"{credential_provider.token_endpoint}:{credential_provider.port}", interceptors=_interceptors( - credential_provider.auth_token, ClientType.TOKEN, configuration.get_retry_strategy() + credential_provider.auth_token, ClientType.TOKEN, [], configuration.get_retry_strategy() ), options=grpc_control_channel_options_from_grpc_config( grpc_config=configuration.get_transport_strategy().get_grpc_configuration(), @@ -245,10 +262,15 @@ def async_stub(self) -> token_client.TokenStub: def _interceptors( - auth_token: str, client_type: ClientType, retry_strategy: Optional[RetryStrategy] = None + auth_token: str, + client_type: ClientType, + middleware: List[Middleware], + retry_strategy: Optional[RetryStrategy] = None, ) -> list[grpc.aio.ClientInterceptor]: from momento import __version__ as momento_version + context = MiddlewareRequestHandlerContext({CONNECTION_ID_KEY: str(uuid.uuid4())}) + headers = [ Header("authorization", auth_token), Header("agent", f"python:{client_type.value}:{momento_version}"), @@ -260,6 +282,7 @@ def _interceptors( [ AddHeaderClientInterceptor(headers), RetryInterceptor(retry_strategy) if retry_strategy else None, + MiddlewareInterceptor(middleware, context) if middleware else None, ], ) ) diff --git a/src/momento/internal/aio/_utilities.py b/src/momento/internal/aio/_utilities.py index 8367f39c..f38c817c 100644 --- a/src/momento/internal/aio/_utilities.py +++ b/src/momento/internal/aio/_utilities.py @@ -1,5 +1,80 @@ -from grpc.aio import Metadata +from typing import Optional + +import grpc +from grpc.aio import ClientCallDetails, Metadata + +from momento.errors import InvalidArgumentException +from momento.internal.services import Service def make_metadata(cache_name: str) -> Metadata: return Metadata(("cache", cache_name)) + + +def sanitize_client_call_details(client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails: + """Defensive function meant to handle inbound gRPC client request objects. + + Args: + client_call_details: the original inbound client grpc request we are intercepting + + Returns: a new client_call_details object with metadata properly initialized to a `grpc.aio.Metadata` object + """ + # Makes sure we can handle properly when we inject our own metadata onto request object. + # This was mainly done as temporary fix after we observed ddtrace grpc client interceptor passing + # client_call_details.metadata as a list instead of a grpc.aio.Metadata object. + # See this ticket for follow-up actions to come back in and address this longer term: + # https://github.com/momentohq/client-sdk-python/issues/149 + new_client_call_details = None + # If no metadata set on passed in client call details then we are first to set, so we should just initialize + if client_call_details.metadata is None: + new_client_call_details = create_client_call_details( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=Metadata(), + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + ) + + # This is block hit when ddtrace interceptor runs first and sets metadata as a list + elif isinstance(client_call_details.metadata, list): + existing_headers = client_call_details.metadata + metadata = Metadata() + # re-add all existing values to new metadata + for md_key, md_value in existing_headers: + metadata.add(md_key, md_value) + new_client_call_details = create_client_call_details( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + ) + elif isinstance(client_call_details.metadata, grpc.aio.Metadata): + # If proper grpc `grpc.aio.Metadata()` object is passed just use original object passed and pass back + new_client_call_details = client_call_details + else: + # Else we raise exception for now since we don't know how to handle an unknown type + raise InvalidArgumentException( + "unexpected grpc client request metadata property passed to interceptor " + "type=" + str(type(client_call_details.metadata)), + Service.AUTH, + ) + + return new_client_call_details + + +# noinspection PyArgumentList +def create_client_call_details( + method: str, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], +) -> ClientCallDetails: + return ClientCallDetails( + method=method, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + ) diff --git a/src/momento/internal/synchronous/_add_header_client_interceptor.py b/src/momento/internal/synchronous/_add_header_client_interceptor.py index 241da918..cafd7437 100644 --- a/src/momento/internal/synchronous/_add_header_client_interceptor.py +++ b/src/momento/internal/synchronous/_add_header_client_interceptor.py @@ -1,12 +1,10 @@ from __future__ import annotations -import collections from typing import Callable, TypeVar import grpc -from momento.errors import InvalidArgumentException -from momento.internal.services import Service +from momento.internal.synchronous._utilities import sanitize_client_call_details RequestType = TypeVar("RequestType") ResponseType = TypeVar("ResponseType") @@ -20,13 +18,6 @@ def __init__(self, name: str, value: str): self.value = value -class _ClientCallDetails( - collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")), - grpc.ClientCallDetails, -): - pass - - class AddHeaderStreamingClientInterceptor(grpc.UnaryStreamClientInterceptor): are_only_once_headers_sent = False @@ -92,37 +83,3 @@ def intercept_unary_unary( AddHeaderClientInterceptor.are_only_once_headers_sent = True return continuation(new_client_call_details, request) - - -def sanitize_client_call_details(client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: - """Defensive function meant to handle inbound gRPC client request objects. - - Args: - client_call_details: the original inbound client grpc request we are intercepting - - Returns: a new client_call_details object with metadata properly initialized to a `grpc.aio.Metadata` object - """ - # Makes sure we can handle properly when we inject our own metadata onto request object. - # This was mainly done as temporary fix after we observed ddtrace grpc client interceptor passing - # client_call_details.metadata as a list instead of a grpc.aio.Metadata object. - # See this ticket for follow-up actions to come back in and address this longer term: - # https://github.com/momentohq/client-sdk-python/issues/149 - # If no metadata set on passed in client call details then we are first to set, so we should just initialize - if client_call_details.metadata is None: - return _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=[], - credentials=client_call_details.credentials, - ) - - # This is block hit when ddtrace interceptor runs first and sets metadata as a list - elif isinstance(client_call_details.metadata, list): - return client_call_details - else: - # Else we raise exception for now since we don't know how to handle an unknown type - raise InvalidArgumentException( - "unexpected grpc client request metadata property passed to interceptor " - "type=" + str(type(client_call_details.metadata)), - Service.AUTH, - ) diff --git a/src/momento/internal/synchronous/_middleware_interceptor.py b/src/momento/internal/synchronous/_middleware_interceptor.py new file mode 100644 index 00000000..ef68bf3a --- /dev/null +++ b/src/momento/internal/synchronous/_middleware_interceptor.py @@ -0,0 +1,142 @@ +from types import MethodType +from typing import Callable, List, Optional, TypeVar, Union, cast + +import grpc +from google.protobuf.message import Message +from grpc import StatusCode +from grpc._interceptor import _UnaryOutcome +from grpc._typing import MetadataType + +from momento import logs +from momento.config.middleware import ( + MiddlewareMessage, + MiddlewareRequestHandlerContext, + MiddlewareStatus, +) +from momento.config.middleware.synchronous import Middleware, MiddlewareMetadata, MiddlewareRequestHandler +from momento.internal.synchronous._utilities import _ClientCallDetails, sanitize_client_call_details + +RequestType = TypeVar("RequestType") +T = TypeVar("T") + + +class _UpdatedMetadataCall(grpc.Call): + _call: grpc.Call + _initial_metadata: Optional[MetadataType] + _code: StatusCode + + def __init__(self, call: grpc.Call, initial_metadata: Optional[MetadataType], status_code: StatusCode) -> None: + self._call = call + self._initial_metadata = initial_metadata + self._code = status_code + + def initial_metadata(self) -> Optional[MetadataType]: + return self._initial_metadata + + def trailing_metadata(self) -> Optional[MetadataType]: + return self._call.trailing_metadata() + + def code(self) -> Optional[grpc.StatusCode]: + return self._code + + def details(self) -> Optional[str]: + return self._call.details() # type: ignore[no-any-return] + + def is_active(self) -> bool: + return self._call.is_active() # type: ignore[no-any-return] + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() # type: ignore[no-any-return] + + def cancel(self) -> bool: + return self._call.cancel() # type: ignore[no-any-return] + + def add_callback(self, callback) -> bool: # type: ignore[no-untyped-def] + return self._call.add_callback(callback) # type: ignore[no-any-return] + + +class MiddlewareInterceptor(grpc.UnaryUnaryClientInterceptor): + middlewares: List[Middleware] + context: MiddlewareRequestHandlerContext + + def __init__(self, middlewares: List[Middleware], context: MiddlewareRequestHandlerContext) -> None: + self._logger = logs.logger + self.middlewares = middlewares + self.context = context + + def apply_handler_methods(self, methods: List[Callable[[T], T]], original_input: T) -> T: + current_value = original_input + + for method in methods: + try: + current_value = method(current_value) + except Exception as e: + bound_method = cast(MethodType, method) + handler_info = f"{bound_method.__self__.__class__.__name__}.{method.__name__}" + self._logger.exception(f"Error in middleware method {handler_info}: {str(e)}") + + return current_value + + def intercept_unary_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, RequestType], Union[grpc.Call, grpc.Future]], + client_call_details: grpc.ClientCallDetails, + request: RequestType, + ) -> Union[grpc.Call, grpc.Future]: + client_call_details = sanitize_client_call_details(client_call_details) + + handlers: List[MiddlewareRequestHandler] = [] + for middleware in self.middlewares: + handler = middleware.on_new_request(self.context) + handlers.append(handler) + reversed_handlers = handlers[::-1] + + metadata = self.apply_handler_methods( + [handler.on_request_metadata for handler in handlers], MiddlewareMetadata(client_call_details.metadata) + ) + + new_client_call_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata.get_grpc_metadata(), + credentials=client_call_details.credentials, + ) + + if isinstance(request, Message): + middleware_message = self.apply_handler_methods( + [handler.on_request_body for handler in handlers], MiddlewareMessage(request) + ) + request = middleware_message.get_message() + + try: + call = continuation(new_client_call_details, request) + + initial_metadata = call.initial_metadata() + response_metadata = self.apply_handler_methods( + [handler.on_response_metadata for handler in reversed_handlers], MiddlewareMetadata(initial_metadata) + ) + initial_metadata = response_metadata.get_grpc_metadata() + + # if the call returns an error, call.result() will raise an RpcError, which we handle below + response_body = call.result() + if isinstance(response_body, Message): + middleware_message = self.apply_handler_methods( + [handler.on_response_body for handler in reversed_handlers], MiddlewareMessage(response_body) + ) + response_body = middleware_message.get_message() + + status_code = call.code() + middleware_status = self.apply_handler_methods( + [handler.on_response_status for handler in reversed_handlers], MiddlewareStatus(status_code) + ) + status_code = middleware_status.grpc_status + + updated_call = _UpdatedMetadataCall(call, initial_metadata, status_code) + updated_outcome = _UnaryOutcome(response_body, updated_call) + + return updated_outcome + except grpc.RpcError as e: + status = MiddlewareStatus(e.code()) + self.apply_handler_methods([handler.on_response_status for handler in reversed_handlers], status) + + raise diff --git a/src/momento/internal/synchronous/_scs_grpc_manager.py b/src/momento/internal/synchronous/_scs_grpc_manager.py index 2686c5c0..ae3f2120 100644 --- a/src/momento/internal/synchronous/_scs_grpc_manager.py +++ b/src/momento/internal/synchronous/_scs_grpc_manager.py @@ -1,7 +1,8 @@ from __future__ import annotations +import uuid from threading import Event -from typing import Optional +from typing import List, Optional import grpc from momento_wire_types import cacheclient_pb2_grpc as cache_client @@ -13,6 +14,9 @@ from momento.auth import CredentialProvider from momento.config import Configuration, TopicConfiguration from momento.config.auth_configuration import AuthConfiguration +from momento.config.middleware import MiddlewareRequestHandlerContext +from momento.config.middleware.models import CONNECTION_ID_KEY +from momento.config.middleware.synchronous import Middleware from momento.errors.exceptions import ConnectionException from momento.internal._utilities import PYTHON_RUNTIME_VERSION, ClientType from momento.internal._utilities._channel_credentials import ( @@ -29,6 +33,7 @@ AddHeaderStreamingClientInterceptor, Header, ) +from momento.internal.synchronous._middleware_interceptor import MiddlewareInterceptor from momento.internal.synchronous._retry_interceptor import RetryInterceptor from momento.retry import RetryStrategy @@ -54,7 +59,12 @@ def __init__(self, configuration: Configuration, credential_provider: Credential ) intercept_channel = grpc.intercept_channel( self._channel, - *_interceptors(credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy()), + *_interceptors( + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_sync_middlewares(), + configuration.get_retry_strategy(), + ), ) self._stub = control_client.ScsControlStub(intercept_channel) # type: ignore[no-untyped-call] @@ -88,7 +98,12 @@ def __init__(self, configuration: Configuration, credential_provider: Credential intercept_channel = grpc.intercept_channel( self._channel, - *_interceptors(credential_provider.auth_token, ClientType.CACHE, configuration.get_retry_strategy()), + *_interceptors( + credential_provider.auth_token, + ClientType.CACHE, + configuration.get_sync_middlewares(), + configuration.get_retry_strategy(), + ), ) self._stub = cache_client.ScsStub(intercept_channel) # type: ignore[no-untyped-call] @@ -184,7 +199,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede ), ) intercept_channel = grpc.intercept_channel( - self._channel, *_interceptors(credential_provider.auth_token, ClientType.TOPIC, None) + self._channel, *_interceptors(credential_provider.auth_token, ClientType.TOPIC, [], None) ) self._stub = pubsub_client.PubsubStub(intercept_channel) # type: ignore[no-untyped-call] @@ -247,7 +262,7 @@ def __init__(self, configuration: AuthConfiguration, credential_provider: Creden ) intercept_channel = grpc.intercept_channel( self._channel, - *_interceptors(credential_provider.auth_token, ClientType.TOKEN, configuration.get_retry_strategy()), + *_interceptors(credential_provider.auth_token, ClientType.TOKEN, [], configuration.get_retry_strategy()), ) self._stub = token_client.TokenStub(intercept_channel) # type: ignore[no-untyped-call] @@ -259,10 +274,15 @@ def stub(self) -> token_client.TokenStub: def _interceptors( - auth_token: str, client_type: ClientType, retry_strategy: Optional[RetryStrategy] = None + auth_token: str, + client_type: ClientType, + middleware: List[Middleware], + retry_strategy: Optional[RetryStrategy] = None, ) -> list[grpc.UnaryUnaryClientInterceptor]: from momento import __version__ as momento_version + context = MiddlewareRequestHandlerContext({CONNECTION_ID_KEY: str(uuid.uuid4())}) + headers = [ Header("authorization", auth_token), Header("agent", f"python:{client_type.value}:{momento_version}"), @@ -270,7 +290,12 @@ def _interceptors( ] return list( filter( - None, [AddHeaderClientInterceptor(headers), RetryInterceptor(retry_strategy) if retry_strategy else None] + None, + [ + AddHeaderClientInterceptor(headers), + RetryInterceptor(retry_strategy) if retry_strategy else None, + MiddlewareInterceptor(middleware, context) if middleware else None, + ], ) ) diff --git a/src/momento/internal/synchronous/_utilities.py b/src/momento/internal/synchronous/_utilities.py index 17ed5838..14332ebf 100644 --- a/src/momento/internal/synchronous/_utilities.py +++ b/src/momento/internal/synchronous/_utilities.py @@ -1,7 +1,59 @@ from __future__ import annotations -from typing import Tuple +import collections +from typing import Optional, Tuple + +import grpc +from grpc import CallCredentials +from grpc._typing import MetadataType + +from momento.errors import InvalidArgumentException +from momento.internal.services import Service def make_metadata(cache_name: str) -> list[Tuple[str, str]]: return [("cache", cache_name)] + + +class _ClientCallDetails( + collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")), + grpc.ClientCallDetails, +): + def __new__( + cls, method: str, timeout: Optional[float], metadata: MetadataType, credentials: Optional[CallCredentials] + ) -> _ClientCallDetails: + return super().__new__(cls, method, timeout, metadata, credentials) + + +def sanitize_client_call_details(client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: + """Defensive function meant to handle inbound gRPC client request objects. + + Args: + client_call_details: the original inbound client grpc request we are intercepting + + Returns: a new client_call_details object with metadata properly initialized to a `grpc.aio.Metadata` object + """ + # Makes sure we can handle properly when we inject our own metadata onto request object. + # This was mainly done as temporary fix after we observed ddtrace grpc client interceptor passing + # client_call_details.metadata as a list instead of a grpc.aio.Metadata object. + # See this ticket for follow-up actions to come back in and address this longer term: + # https://github.com/momentohq/client-sdk-python/issues/149 + # If no metadata set on passed in client call details then we are first to set, so we should just initialize + if client_call_details.metadata is None: + return _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=[], + credentials=client_call_details.credentials, + ) + + # This is block hit when ddtrace interceptor runs first and sets metadata as a list + elif isinstance(client_call_details.metadata, list): + return client_call_details + else: + # Else we raise exception for now since we don't know how to handle an unknown type + raise InvalidArgumentException( + "unexpected grpc client request metadata property passed to interceptor " + "type=" + str(type(client_call_details.metadata)), + Service.AUTH, + )