diff --git a/openfeature/api.py b/openfeature/api.py index 4460b695..19736733 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -1,7 +1,7 @@ import typing from openfeature import _event_support -from openfeature.client import OpenFeatureClient +from openfeature.client import AsyncOpenFeatureClient, OpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.event import ( EventHandler, @@ -49,6 +49,12 @@ def get_client( return OpenFeatureClient(domain=domain, version=version) +def get_async_client( + domain: typing.Optional[str] = None, version: typing.Optional[str] = None +) -> AsyncOpenFeatureClient: + return AsyncOpenFeatureClient(domain=domain, version=version) + + def set_provider( provider: FeatureProvider, domain: typing.Optional[str] = None ) -> None: diff --git a/openfeature/client.py b/openfeature/client.py index 1edfca63..f415d38f 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -33,6 +33,7 @@ __all__ = [ "ClientMetadata", "OpenFeatureClient", + "AsyncOpenFeatureClient", ] logger = logging.getLogger("openfeature") @@ -464,3 +465,353 @@ def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: raise GeneralError(error_message="Unknown flag type") if not isinstance(value, _type): raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") + + +class AsyncOpenFeatureClient(OpenFeatureClient): + async def get_boolean_value( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + details = await self.get_boolean_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + + async def get_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_string_value( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + details = await self.get_string_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + + async def get_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_integer_value( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + details = await self.get_integer_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + + async def get_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_float_value( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + details = await self.get_float_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + + async def get_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_object_value( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + details = await self.get_object_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + + async def get_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def evaluate_flag_details( # noqa: PLR0915 + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + if evaluation_context is None: + evaluation_context = EvaluationContext() + + if flag_evaluation_options is None: + flag_evaluation_options = FlagEvaluationOptions() + + provider = self.provider # call this once to maintain a consistent reference + evaluation_hooks = flag_evaluation_options.hooks + hook_hints = flag_evaluation_options.hook_hints + + hook_context = HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=evaluation_context, + client_metadata=self.get_metadata(), + provider_metadata=provider.get_metadata(), + ) + # Hooks need to be handled in different orders at different stages + # in the flag evaluation + # before: API, Client, Invocation, Provider + merged_hooks = ( + api.get_hooks() + + self.hooks + + evaluation_hooks + + provider.get_provider_hooks() + ) + # after, error, finally: Provider, Invocation, Client, API + reversed_merged_hooks = merged_hooks[:] + reversed_merged_hooks.reverse() + + status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: + error_hooks( + flag_type, + hook_context, + ProviderNotReadyError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_NOT_READY, + ) + if status == ProviderStatus.FATAL: + error_hooks( + flag_type, + hook_context, + ProviderFatalError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_FATAL, + ) + + try: + # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md + # Any resulting evaluation context from a before hook will overwrite + # duplicate fields defined globally, on the client, or in the invocation. + # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context + invocation_context = before_hooks( + flag_type, hook_context, merged_hooks, hook_hints + ) + + invocation_context = invocation_context.merge(ctx2=evaluation_context) + # Requirement 3.2.2 merge: API.context->client.context->invocation.context + merged_context = ( + api.get_evaluation_context() + .merge(self.context) + .merge(invocation_context) + ) + + flag_evaluation = await self._create_provider_evaluation( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + + after_hooks( + flag_type, + hook_context, + flag_evaluation, + reversed_merged_hooks, + hook_hints, + ) + + return flag_evaluation + + except OpenFeatureError as err: + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=err.error_code, + error_message=err.error_message, + ) + # Catch any type of exception here since the user can provide any exception + # in the error hooks + except Exception as err: # pragma: no cover + logger.exception( + "Unable to correctly evaluate flag with key: '%s'", flag_key + ) + + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + error_message = getattr(err, "error_message", str(err)) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=error_message, + ) + + finally: + after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + + async def _create_provider_evaluation( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Asynchronous encapsulated method to create a FlagEvaluationDetail from a specific provider. + + :param flag_type: the type of the flag being returned + :param key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + args = ( + flag_key, + default_value, + evaluation_context, + ) + + get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = { + FlagType.BOOLEAN: provider.resolve_boolean_details, + FlagType.INTEGER: provider.resolve_integer_details, + FlagType.FLOAT: provider.resolve_float_details, + FlagType.OBJECT: provider.resolve_object_details, + FlagType.STRING: provider.resolve_string_details, + } + + get_details_callable = get_details_callables.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = await get_details_callable(*args) # type: ignore[misc] + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index b390f928..c2870797 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -164,3 +164,50 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): self._on_emit(self, event, details) + + +class AsyncAbstractProvider(AbstractProvider): + @abstractmethod + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + raise NotImplementedError("Method not implemented") + + @abstractmethod + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + raise NotImplementedError("Method not implemented") + + @abstractmethod + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + raise NotImplementedError("Method not implemented") + + @abstractmethod + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + raise NotImplementedError("Method not implemented") + + @abstractmethod + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + raise NotImplementedError("Method not implemented") diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py index 322f4ed6..39bd2a27 100644 --- a/openfeature/provider/in_memory_provider.py +++ b/openfeature/provider/in_memory_provider.py @@ -117,3 +117,56 @@ def _resolve( if flag is None: raise FlagNotFoundError(f"Flag '{flag_key}' not found") return flag.resolve(evaluation_context) + + +class AsyncInMemoryProvider(InMemoryProvider): + _flags: FlagStorage + + def __init__(self, flags: FlagStorage) -> None: + self._flags = flags.copy() + + def get_metadata(self) -> Metadata: + return InMemoryMetadata() + + def get_provider_hooks(self) -> typing.List[Hook]: + return [] + + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return self._resolve(flag_key, evaluation_context) diff --git a/openfeature/provider/no_op_provider.py b/openfeature/provider/no_op_provider.py index 070945c9..a2bb5003 100644 --- a/openfeature/provider/no_op_provider.py +++ b/openfeature/provider/no_op_provider.py @@ -3,7 +3,7 @@ from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagResolutionDetails, Reason from openfeature.hook import Hook -from openfeature.provider import AbstractProvider, Metadata +from openfeature.provider import AbstractProvider, AsyncAbstractProvider, Metadata from openfeature.provider.no_op_metadata import NoOpMetadata PASSED_IN_DEFAULT = "Passed in default" @@ -75,3 +75,68 @@ def resolve_object_details( reason=Reason.DEFAULT, variant=PASSED_IN_DEFAULT, ) + + +class AsyncNoOpProvider(AsyncAbstractProvider): + def get_metadata(self) -> Metadata: + return NoOpMetadata() + + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) diff --git a/pyproject.toml b/pyproject.toml index 394947a7..0e74f28f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ local_partial_types = true # will become the new default from version 2 pretty = true strict = true disallow_any_generics = false +disable_error_code = ["override"] [tool.ruff] exclude = [ diff --git a/tests/conftest.py b/tests/conftest.py index 1f0a7982..d4755514 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from openfeature import api -from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.provider.no_op_provider import AsyncNoOpProvider, NoOpProvider @pytest.fixture(autouse=True) @@ -13,7 +13,22 @@ def clear_providers(): api.clear_providers() +@pytest.fixture(autouse=True) +def clear_hooks_fixture(): + """ + For tests that use add_hooks(), we need to clear the hooks to avoid issues + in other tests. + """ + api.clear_hooks() + + @pytest.fixture() def no_op_provider_client(): api.set_provider(NoOpProvider()) return api.get_client() + + +@pytest.fixture() +def no_op_provider_client_async(): + api.set_provider(AsyncNoOpProvider()) + return api.get_async_client("my-async-client") diff --git a/tests/provider/test_no_op_provider.py b/tests/provider/test_no_op_provider.py index 3876091a..26ef4448 100644 --- a/tests/provider/test_no_op_provider.py +++ b/tests/provider/test_no_op_provider.py @@ -1,5 +1,11 @@ +import typing from numbers import Number +import pytest + +from openfeature.evaluation_context import EvaluationContext +from openfeature.flag_evaluation import FlagResolutionDetails +from openfeature.provider import AsyncAbstractProvider from openfeature.provider.no_op_provider import NoOpProvider @@ -80,3 +86,66 @@ def test_should_resolve_object_flag_from_no_op(): assert flag is not None assert flag.value == return_value assert isinstance(flag.value, dict) + + +class ConcreteAsyncProvider(AsyncAbstractProvider): + def get_metadata(self): + return super().get_metadata() + + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return await super().resolve_boolean_details(flag_key, default_value) + + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return await super().resolve_string_details(flag_key, default_value) + + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return await super().resolve_integer_details(flag_key, default_value) + + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return await super().resolve_float_details(flag_key, default_value) + + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return await super().resolve_object_details(flag_key, default_value) + + +@pytest.mark.parametrize( + "get_method, default", + ( + ("resolve_boolean_details", True), + ("resolve_string_details", "default"), + ("resolve_integer_details", 42), + ("resolve_float_details", 3.14), + ("resolve_object_details", {"key": "value"}), + ), +) +@pytest.mark.asyncio +async def test_abstract_provider_throws_not_implemented(get_method, default): + with pytest.raises(NotImplementedError) as exception: + provider = ConcreteAsyncProvider() + await getattr(provider, get_method)("test_flag", default) + assert str(exception.value) == "Method not implemented" diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 00000000..25b16e86 --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,192 @@ +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock + +import pytest + +from openfeature.api import add_hooks, get_async_client, set_provider +from openfeature.client import AsyncOpenFeatureClient +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, OpenFeatureError +from openfeature.flag_evaluation import Reason +from openfeature.hook import Hook +from openfeature.provider import ProviderStatus +from openfeature.provider.in_memory_provider import AsyncInMemoryProvider, InMemoryFlag +from openfeature.provider.no_op_provider import NoOpProvider + + +@pytest.mark.parametrize( + "default, variants, get_method, expected_value", + ( + ("true", {"true": True, "false": False}, "get_boolean", True), + ("String", {"String": "Variant"}, "get_string", "Variant"), + ("Number", {"Number": 100}, "get_integer", 100), + ("Float", {"Float": 10.23}, "get_float", 10.23), + ( + "Object", + {"Object": {"some": "object"}}, + "get_object", + {"some": "object"}, + ), + ), +) +@pytest.mark.asyncio +async def test_flag_resolution_to_evaluation_details_async( + default, variants, get_method, expected_value, clear_hooks_fixture +): + # Given + api_hook = MagicMock(spec=Hook) + add_hooks([api_hook]) + provider = AsyncInMemoryProvider( + { + "Key": InMemoryFlag( + default, + variants, + flag_metadata={"foo": "bar"}, + ) + } + ) + set_provider(provider, "my-async-client") + client = AsyncOpenFeatureClient("my-async-client", None) + client.add_hooks([api_hook]) + # When + details = await getattr(client, f"{get_method}_details")( + flag_key="Key", default_value=None + ) + value = await getattr(client, f"{get_method}_value")( + flag_key="Key", default_value=None + ) + # Then + assert details is not None + assert details.flag_metadata == {"foo": "bar"} + assert details.value == expected_value + assert details.value == value + + +@pytest.mark.asyncio +async def test_should_return_client_metadata_with_domain_async( + no_op_provider_client_async, +): + # Given + # When + metadata = no_op_provider_client_async.get_metadata() + # Then + assert metadata is not None + assert metadata.domain == "my-async-client" + + +def test_add_remove_event_handler_async(): + # Given + provider = NoOpProvider() + set_provider(provider) + spy = MagicMock() + + client = get_async_client() + client.add_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed + ) + client.remove_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed + ) + + provider_details = ProviderEventDetails(message="message") + + # When + provider.emit_provider_configuration_changed(provider_details) + + # Then + spy.provider_configuration_changed.assert_not_called() + + +def test_client_handlers_thread_safety_async(): + provider = NoOpProvider() + set_provider(provider) + + def add_handlers_task(): + def handler(*args, **kwargs): + time.sleep(0.005) + + for _ in range(10): + time.sleep(0.01) + client = get_async_client(str(uuid.uuid4())) + client.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler) + + def emit_events_task(): + for _ in range(10): + time.sleep(0.01) + provider.emit_provider_configuration_changed(ProviderEventDetails()) + + with ThreadPoolExecutor(max_workers=2) as executor: + f1 = executor.submit(add_handlers_task) + f2 = executor.submit(emit_events_task) + f1.result() + f2.result() + + +@pytest.mark.parametrize( + "provider_status, error_code", + ( + (ProviderStatus.NOT_READY, ErrorCode.PROVIDER_NOT_READY), + (ProviderStatus.FATAL, ErrorCode.PROVIDER_FATAL), + ), +) +@pytest.mark.asyncio +async def test_should_shortcircuit_if_provider_is_not_ready( + no_op_provider_client_async, monkeypatch, provider_status, error_code +): + # Given + monkeypatch.setattr( + no_op_provider_client_async, + "get_provider_status", + lambda: provider_status, + ) + spy_hook = MagicMock(spec=Hook) + spy_hook.before.return_value = None + no_op_provider_client_async.add_hooks([spy_hook]) + # When + flag_details = await no_op_provider_client_async.get_boolean_details( + flag_key="Key", default_value=True + ) + # Then + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == error_code + spy_hook.error.assert_called_once() + + +@pytest.mark.parametrize( + "expected_type, get_method, default_value", + ( + (bool, "get_boolean_details", True), + (str, "get_string_details", "default"), + (int, "get_integer_details", 100), + (float, "get_float_details", 10.23), + (dict, "get_object_details", {"some": "object"}), + ), +) +@pytest.mark.asyncio +async def test_handle_an_open_feature_exception_thrown_by_a_provider_async( + expected_type, + get_method, + default_value, + no_op_provider_client_async, +): + # Given + exception_hook = MagicMock(spec=Hook) + exception_hook.after.side_effect = OpenFeatureError( + ErrorCode.GENERAL, "error_message" + ) + no_op_provider_client_async.add_hooks([exception_hook]) + + # When + flag_details = await getattr(no_op_provider_client_async, get_method)( + flag_key="Key", default_value=default_value + ) + # Then + assert flag_details is not None + assert flag_details.value + assert isinstance(flag_details.value, expected_type) + assert flag_details.reason == Reason.ERROR + assert flag_details.error_message == "error_message"