diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py index ac6134a7..59b86e2a 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py @@ -3,6 +3,8 @@ import typing from enum import Enum +import grpc + class ResolverType(Enum): RPC = "rpc" @@ -45,9 +47,11 @@ class CacheType(Enum): ENV_VAR_RETRY_BACKOFF_MAX_MS = "FLAGD_RETRY_BACKOFF_MAX_MS" ENV_VAR_RETRY_GRACE_PERIOD_SECONDS = "FLAGD_RETRY_GRACE_PERIOD" ENV_VAR_SELECTOR = "FLAGD_SOURCE_SELECTOR" +ENV_VAR_PROVIDER_ID = "FLAGD_SOURCE_PROVIDER_ID" ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS" ENV_VAR_TLS = "FLAGD_TLS" ENV_VAR_TLS_CERT = "FLAGD_SERVER_CERT_PATH" +ENV_VAR_DEFAULT_AUTHORITY = "FLAGD_DEFAULT_AUTHORITY" T = typing.TypeVar("T") @@ -81,6 +85,7 @@ def __init__( # noqa: PLR0913 port: typing.Optional[int] = None, tls: typing.Optional[bool] = None, selector: typing.Optional[str] = None, + provider_id: typing.Optional[str] = None, resolver: typing.Optional[ResolverType] = None, offline_flag_source_path: typing.Optional[str] = None, offline_poll_interval_ms: typing.Optional[int] = None, @@ -93,6 +98,8 @@ def __init__( # noqa: PLR0913 cache: typing.Optional[CacheType] = None, max_cache_size: typing.Optional[int] = None, cert_path: typing.Optional[str] = None, + default_authority: typing.Optional[str] = None, + channel_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host @@ -227,3 +234,17 @@ def __init__( # noqa: PLR0913 self.selector = ( env_or_default(ENV_VAR_SELECTOR, None) if selector is None else selector ) + + self.provider_id = ( + env_or_default(ENV_VAR_PROVIDER_ID, None) + if provider_id is None + else provider_id + ) + + self.default_authority = ( + env_or_default(ENV_VAR_DEFAULT_AUTHORITY, None) + if default_authority is None + else default_authority + ) + + self.channel_credentials = channel_credentials diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py index ae3cf323..59bad1c3 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py @@ -24,6 +24,8 @@ import typing import warnings +import grpc + from openfeature.evaluation_context import EvaluationContext from openfeature.event import ProviderEventDetails from openfeature.flag_evaluation import FlagResolutionDetails @@ -50,6 +52,7 @@ def __init__( # noqa: PLR0913 timeout: typing.Optional[int] = None, retry_backoff_ms: typing.Optional[int] = None, selector: typing.Optional[str] = None, + provider_id: typing.Optional[str] = None, resolver_type: typing.Optional[ResolverType] = None, offline_flag_source_path: typing.Optional[str] = None, stream_deadline_ms: typing.Optional[int] = None, @@ -59,6 +62,8 @@ def __init__( # noqa: PLR0913 retry_backoff_max_ms: typing.Optional[int] = None, retry_grace_period: typing.Optional[int] = None, cert_path: typing.Optional[str] = None, + default_authority: typing.Optional[str] = None, + channel_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): """ Create an instance of the FlagdProvider @@ -91,6 +96,7 @@ def __init__( # noqa: PLR0913 retry_backoff_max_ms=retry_backoff_max_ms, retry_grace_period=retry_grace_period, selector=selector, + provider_id=provider_id, resolver=resolver_type, offline_flag_source_path=offline_flag_source_path, stream_deadline_ms=stream_deadline_ms, @@ -98,6 +104,8 @@ def __init__( # noqa: PLR0913 cache=cache, max_cache_size=max_cache_size, cert_path=cert_path, + default_authority=default_authority, + channel_credentials=channel_credentials, ) self.enriched_context: dict = {} diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py index f5aeba22..7b611dd4 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py @@ -42,6 +42,7 @@ def __init__( self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001 self.deadline = config.deadline_ms * 0.001 self.selector = config.selector + self.provider_id = config.provider_id self.emit_provider_ready = emit_provider_ready self.emit_provider_error = emit_provider_error self.emit_provider_stale = emit_provider_stale @@ -55,13 +56,23 @@ def __init__( def _generate_channel(self, config: Config) -> grpc.Channel: target = f"{config.host}:{config.port}" # Create the channel with the service config - options = [ + options: list[tuple[str, typing.Any]] = [ ("grpc.keepalive_time_ms", config.keep_alive_time), ("grpc.initial_reconnect_backoff_ms", config.retry_backoff_ms), ("grpc.max_reconnect_backoff_ms", config.retry_backoff_max_ms), ("grpc.min_reconnect_backoff_ms", config.stream_deadline_ms), ] - if config.tls: + if config.default_authority is not None: + options.append(("grpc.default_authority", config.default_authority)) + + if config.channel_credentials is not None: + channel_args = { + "options": options, + "credentials": config.channel_credentials, + } + channel = grpc.secure_channel(target, **channel_args) + + elif config.tls: channel_args = { "options": options, "credentials": grpc.ssl_channel_credentials(), @@ -157,7 +168,11 @@ def listen(self) -> None: if self.streamline_deadline_seconds > 0 else {} ) - request_args = {"selector": self.selector} if self.selector is not None else {} + request_args = {} + if self.selector is not None: + request_args["selector"] = self.selector + if self.provider_id is not None: + request_args["provider_id"] = self.provider_id while self.active: try: