diff --git a/openfeature/client.py b/openfeature/client.py index d73a3800..29e3a32a 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -1,5 +1,6 @@ import logging import typing +from collections.abc import Awaitable from dataclasses import dataclass from openfeature import _event_support, api @@ -37,46 +38,6 @@ logger = logging.getLogger("openfeature") -GetDetailCallable = typing.Union[ - typing.Callable[ - [str, bool, typing.Optional[EvaluationContext]], FlagResolutionDetails[bool] - ], - typing.Callable[ - [str, int, typing.Optional[EvaluationContext]], FlagResolutionDetails[int] - ], - typing.Callable[ - [str, float, typing.Optional[EvaluationContext]], FlagResolutionDetails[float] - ], - typing.Callable[ - [str, str, typing.Optional[EvaluationContext]], FlagResolutionDetails[str] - ], - typing.Callable[ - [str, typing.Union[dict, list], typing.Optional[EvaluationContext]], - FlagResolutionDetails[typing.Union[dict, list]], - ], -] -GetDetailCallableAsync = typing.Union[ - typing.Callable[ - [str, bool, typing.Optional[EvaluationContext]], - typing.Awaitable[FlagResolutionDetails[bool]], - ], - typing.Callable[ - [str, int, typing.Optional[EvaluationContext]], - typing.Awaitable[FlagResolutionDetails[int]], - ], - typing.Callable[ - [str, float, typing.Optional[EvaluationContext]], - typing.Awaitable[FlagResolutionDetails[float]], - ], - typing.Callable[ - [str, str, typing.Optional[EvaluationContext]], - typing.Awaitable[FlagResolutionDetails[str]], - ], - typing.Callable[ - [str, typing.Union[dict, list], typing.Optional[EvaluationContext]], - typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]], - ], -] TypeMap = dict[ FlagType, typing.Union[ @@ -88,6 +49,26 @@ ], ] +T = typing.TypeVar("T", bool, int, float, str, typing.Union[dict, list]) + + +class ResolveDetailsCallable(typing.Protocol[T]): + def __call__( + self, + flag_key: str, + default_value: T, + evaluation_context: typing.Optional[EvaluationContext], + ) -> FlagResolutionDetails[T]: ... + + +class ResolveDetailsCallableAsync(typing.Protocol[T]): + def __call__( + self, + flag_key: str, + default_value: T, + evaluation_context: typing.Optional[EvaluationContext], + ) -> Awaitable[FlagResolutionDetails[T]]: ... + @dataclass class ClientMetadata: @@ -702,7 +683,7 @@ async def _create_provider_evaluation_async( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagEvaluationDetails[typing.Any]: get_details_callables_async: typing.Mapping[ - FlagType, GetDetailCallableAsync + FlagType, ResolveDetailsCallableAsync ] = { FlagType.BOOLEAN: provider.resolve_boolean_details_async, FlagType.INTEGER: provider.resolve_integer_details_async, @@ -714,7 +695,7 @@ async def _create_provider_evaluation_async( if not get_details_callable: raise GeneralError(error_message="Unknown flag type") - resolution = await get_details_callable( # type: ignore[call-arg] + resolution = await get_details_callable( flag_key=flag_key, default_value=default_value, evaluation_context=evaluation_context, @@ -752,7 +733,7 @@ def _create_provider_evaluation( :return: a FlagEvaluationDetails object with the fully evaluated flag from a provider """ - get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = { + get_details_callables: typing.Mapping[FlagType, ResolveDetailsCallable] = { FlagType.BOOLEAN: provider.resolve_boolean_details, FlagType.INTEGER: provider.resolve_integer_details, FlagType.FLOAT: provider.resolve_float_details, @@ -764,7 +745,7 @@ def _create_provider_evaluation( if not get_details_callable: raise GeneralError(error_message="Unknown flag type") - resolution = get_details_callable( # type: ignore[call-arg] + resolution = get_details_callable( flag_key=flag_key, default_value=default_value, evaluation_context=evaluation_context,