Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions src/momento/config/auth_configuration.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import timedelta

from momento.config.transport.transport_strategy import TransportStrategy
from momento.retry.retry_strategy import RetryStrategy


class AuthConfigurationBase(ABC):
@abstractmethod
def get_retry_strategy(self) -> RetryStrategy:
pass

@abstractmethod
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> AuthConfiguration:
pass

@abstractmethod
def get_transport_strategy(self) -> TransportStrategy:
pass

@abstractmethod
def with_transport_strategy(self, transport_strategy: TransportStrategy) -> AuthConfiguration:
pass

@abstractmethod
def with_client_timeout(self, client_timeout: timedelta) -> AuthConfiguration:
pass


class AuthConfiguration(AuthConfigurationBase):
"""AuthConfiguration options for Momento Simple Cache Client."""
class AuthConfiguration:
"""AuthConfiguration options for Momento Auth Client."""

def __init__(self, transport_strategy: TransportStrategy, retry_strategy: RetryStrategy):
"""Instantiate a AuthConfiguration.
Expand Down
61 changes: 10 additions & 51 deletions src/momento/config/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import timedelta
from pathlib import Path
from typing import List, Optional
Expand All @@ -12,46 +11,8 @@
from .transport.transport_strategy import TransportStrategy


class ConfigurationBase(ABC):
@abstractmethod
def get_retry_strategy(self) -> RetryStrategy:
pass

@abstractmethod
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> Configuration:
pass

@abstractmethod
def get_transport_strategy(self) -> TransportStrategy:
pass

@abstractmethod
def with_transport_strategy(self, transport_strategy: TransportStrategy) -> Configuration:
pass

@abstractmethod
def with_client_timeout(self, client_timeout: timedelta) -> Configuration:
pass

@abstractmethod
def with_root_certificates_pem(self, root_certificate_path: Path) -> Configuration:
pass

@abstractmethod
def with_middlewares(self, middlewares: List[Middleware]) -> Configuration:
pass

@abstractmethod
def add_middleware(self, middleware: Middleware) -> Configuration:
pass

@abstractmethod
def get_middlewares(self) -> List[Middleware]:
pass


class Configuration(ConfigurationBase):
"""Configuration options for Momento Simple Cache Client."""
class Configuration:
"""Configuration options for Momento Cache Client."""

def __init__(
self,
Expand Down Expand Up @@ -140,24 +101,22 @@ def with_root_certificates_pem(self, root_certificates_pem_path: Path) -> Config
return self.with_transport_strategy(transport_strategy)

def with_middlewares(self, middlewares: List[Middleware]) -> Configuration:
"""Copies the Configuration and adds the new middlewares to the end of the list.
"""Copies the Configuration and replaces the middleware with the given middleware list.

Args:
middlewares: the middleware list to be appended to the Configuration's existing middleware. These can be
aio or synchronous middleware.
middlewares: the new middleware list. It can contain async or synchronous middleware.

Returns:
Configuration: the new Configuration.
"""
new_middlewares = self._middlewares.copy() + middlewares
return Configuration(self._transport_strategy, self._retry_strategy, new_middlewares)
return Configuration(self._transport_strategy, self._retry_strategy, middlewares)

def add_middleware(self, middleware: Middleware) -> Configuration:
"""Copies the Configuration and adds the new middleware to the end of the list.

Args:
middleware: the middleware to be appended to the Configuration's existing middleware. This can be aio or
synchronous middleware.
middleware: the middleware to be appended to the Configuration's existing middleware. This can be an async
or synchronous middleware.

Returns:
Configuration: the new Configuration.
Expand All @@ -173,11 +132,11 @@ def get_middlewares(self) -> List[Middleware]:
"""
return self._middlewares.copy()

def get_aio_middlewares(self) -> List[momento.config.middleware.aio.Middleware]:
"""Access the aio middleware from the middleware list.
def get_async_middlewares(self) -> List[momento.config.middleware.aio.Middleware]:
"""Access the async middleware from the middleware list.

Returns:
the configuration's list of aio middleware.
the configuration's list of async middleware.
"""
return [m for m in self._middlewares if isinstance(m, momento.config.middleware.aio.Middleware)]

Expand Down
9 changes: 3 additions & 6 deletions src/momento/config/middleware/aio/middleware_metadata.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from dataclasses import dataclass
from typing import Optional

from grpc.aio import Metadata


@dataclass
class MiddlewareMetadata:
"""Wrapper for gRPC metadata."""

def __init__(self, metadata: Optional[Metadata]):
self.grpc_metadata = metadata

def get_grpc_metadata(self) -> Optional[Metadata]:
"""Get the underlying gRPC metadata."""
return self.grpc_metadata
grpc_metadata: Optional[Metadata]
34 changes: 13 additions & 21 deletions src/momento/config/middleware/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Dict

import grpc
Expand All @@ -6,41 +7,32 @@
CONNECTION_ID_KEY = "connectionID"


@dataclass
class MiddlewareMessage:
"""Wrapper for a gRPC protobuf message."""

def __init__(self, message: Message):
self.grpc_message = message
grpc_message: Message

def get_message_length(self) -> int:
"""Get the length of the message in bytes."""
@property
def message_length(self) -> int:
"""Length of the message in bytes."""
return len(self.grpc_message.SerializeToString())

def get_constructor_name(self) -> str:
"""Get the class name of the message."""
@property
def constructor_name(self) -> str:
"""The class name of the message."""
return str(self.grpc_message.__class__.__name__)

def get_message(self) -> Message:
"""Get the underlying gRPC message."""
return self.grpc_message


@dataclass
class MiddlewareStatus:
"""Wrapper for gRPC status."""

def __init__(self, status: grpc.StatusCode):
self.grpc_status = status

def get_code(self) -> grpc.StatusCode:
"""Get the status code."""
return self.grpc_status
grpc_status: grpc.StatusCode


@dataclass
class MiddlewareRequestHandlerContext:
"""Context for middleware request handlers."""

def __init__(self, context: Dict[str, str]):
self.context = context

def get_context(self) -> Dict[str, str]:
return self.context
context: Dict[str, str]
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from dataclasses import dataclass
from typing import Optional

from grpc._typing import MetadataType


@dataclass
class MiddlewareMetadata:
"""Wrapper for gRPC metadata."""

def __init__(self, metadata: Optional[MetadataType]):
self.grpc_metadata = metadata

def get_grpc_metadata(self) -> Optional[MetadataType]:
"""Get the underlying gRPC metadata."""
return self.grpc_metadata
grpc_metadata: Optional[MetadataType]
25 changes: 1 addition & 24 deletions src/momento/config/topic_configuration.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import timedelta

from momento.config.transport.topic_transport_strategy import TopicTransportStrategy


class TopicConfigurationBase(ABC):
@abstractmethod
def get_max_subscriptions(self) -> int:
pass

@abstractmethod
def with_max_subscriptions(self, max_subscriptions: int) -> TopicConfiguration:
pass

@abstractmethod
def get_transport_strategy(self) -> TopicTransportStrategy:
pass

@abstractmethod
def with_transport_strategy(self, transport_strategy: TopicTransportStrategy) -> TopicConfiguration:
pass

@abstractmethod
def with_client_timeout(self, client_timeout: timedelta) -> TopicConfiguration:
pass


class TopicConfiguration(TopicConfigurationBase):
class TopicConfiguration:
"""Configuration options for Momento topic client."""

def __init__(self, transport_strategy: TopicTransportStrategy, max_subscriptions: int = 0):
Expand Down
8 changes: 4 additions & 4 deletions src/momento/internal/aio/_middleware_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def intercept_unary_unary(
new_client_call_details = create_client_call_details(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=metadata.get_grpc_metadata(),
metadata=metadata.grpc_metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
Expand All @@ -127,15 +127,15 @@ async def intercept_unary_unary(
middleware_message = await self.apply_handler_methods(
[handler.on_request_body for handler in handlers], MiddlewareMessage(request)
)
request = middleware_message.get_message()
request = middleware_message.grpc_message

call = await continuation(new_client_call_details, request)
try:
initial_metadata = await call.initial_metadata()
response_metadata = await self.apply_handler_methods(
[handler.on_response_metadata for handler in reversed_handlers], MiddlewareMetadata(initial_metadata)
)
initial_metadata = response_metadata.get_grpc_metadata()
initial_metadata = response_metadata.grpc_metadata

# if the call returns an error, awaiting it will raise an RpcError, which we handle below
original_response = await call
Expand All @@ -144,7 +144,7 @@ async def intercept_unary_unary(
middleware_response = await self.apply_handler_methods(
[handler.on_response_body for handler in reversed_handlers], MiddlewareMessage(original_response)
)
response = middleware_response.get_message()
response = middleware_response.grpc_message
else:
response = original_response

Expand Down
8 changes: 4 additions & 4 deletions src/momento/internal/aio/_scs_grpc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, configuration: Configuration, credential_provider: Credential
interceptors=_interceptors(
credential_provider.auth_token,
ClientType.CACHE,
configuration.get_aio_middlewares(),
configuration.get_async_middlewares(),
configuration.get_retry_strategy(),
),
options=grpc_control_channel_options_from_grpc_config(
Expand All @@ -63,7 +63,7 @@ def __init__(self, configuration: Configuration, credential_provider: Credential
interceptors=_interceptors(
credential_provider.auth_token,
ClientType.CACHE,
configuration.get_aio_middlewares(),
configuration.get_async_middlewares(),
configuration.get_retry_strategy(),
),
options=grpc_control_channel_options_from_grpc_config(
Expand All @@ -90,7 +90,7 @@ def __init__(self, configuration: Configuration, credential_provider: Credential
interceptors=_interceptors(
credential_provider.auth_token,
ClientType.CACHE,
configuration.get_aio_middlewares(),
configuration.get_async_middlewares(),
configuration.get_retry_strategy(),
),
# Here is where you would pass override configuration to the underlying C gRPC layer.
Expand All @@ -117,7 +117,7 @@ def __init__(self, configuration: Configuration, credential_provider: Credential
interceptors=_interceptors(
credential_provider.auth_token,
ClientType.CACHE,
configuration.get_aio_middlewares(),
configuration.get_async_middlewares(),
configuration.get_retry_strategy(),
),
options=grpc_data_channel_options_from_grpc_config(
Expand Down
8 changes: 4 additions & 4 deletions src/momento/internal/synchronous/_middleware_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ def intercept_unary_unary(
new_client_call_details = _ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=metadata.get_grpc_metadata(),
metadata=metadata.grpc_metadata,
credentials=client_call_details.credentials,
)

if isinstance(request, Message):
middleware_message = self.apply_handler_methods(
[handler.on_request_body for handler in handlers], MiddlewareMessage(request)
)
request = middleware_message.get_message()
request = middleware_message.grpc_message

try:
call = continuation(new_client_call_details, request)
Expand All @@ -115,15 +115,15 @@ def intercept_unary_unary(
response_metadata = self.apply_handler_methods(
[handler.on_response_metadata for handler in reversed_handlers], MiddlewareMetadata(initial_metadata)
)
initial_metadata = response_metadata.get_grpc_metadata()
initial_metadata = response_metadata.grpc_metadata

# if the call returns an error, call.result() will raise an RpcError, which we handle below
response_body = call.result()
if isinstance(response_body, Message):
middleware_message = self.apply_handler_methods(
[handler.on_response_body for handler in reversed_handlers], MiddlewareMessage(response_body)
)
response_body = middleware_message.get_message()
response_body = middleware_message.grpc_message

status_code = call.code()
middleware_status = self.apply_handler_methods(
Expand Down
Loading