diff --git a/openfeature/client.py b/openfeature/client.py index 450d646e..3d039538 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -2,6 +2,7 @@ import typing from collections.abc import Awaitable, Sequence from dataclasses import dataclass +from itertools import chain from openfeature import _event_support from openfeature.evaluation_context import EvaluationContext, get_evaluation_context @@ -420,10 +421,10 @@ def _establish_hooks_and_provider( flag_evaluation_options: typing.Optional[FlagEvaluationOptions], ) -> tuple[ FeatureProvider, - HookContext, HookHints, - list[Hook], - list[Hook], + list[tuple[Hook, HookContext]], + list[tuple[Hook, HookContext]], + EvaluationContext, ]: if evaluation_context is None: evaluation_context = EvaluationContext() @@ -444,25 +445,43 @@ def _establish_hooks_and_provider( .merge(evaluation_context) ) - hook_context = HookContext( - flag_key=flag_key, - flag_type=flag_type, - default_value=default_value, - evaluation_context=merged_eval_context, - client_metadata=self.get_metadata(), - provider_metadata=provider.get_metadata(), - ) + client_metadata = self.get_metadata() + provider_metadata = provider.get_metadata() + # Hooks need to be handled in different orders at different stages # in the flag evaluation # before: API, Client, Invocation, Provider - merged_hooks = ( - get_hooks() + self.hooks + evaluation_hooks + provider.get_provider_hooks() - ) + merged_hooks_and_context = [ + ( + hook, + HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=merged_eval_context, + client_metadata=client_metadata, + provider_metadata=provider_metadata, + hook_data={}, + ), + ) + for hook in chain( + get_hooks(), + self.hooks, + evaluation_hooks, + provider.get_provider_hooks(), + ) + ] # after, error, finally: Provider, Invocation, Client, API - reversed_merged_hooks = merged_hooks[:] - reversed_merged_hooks.reverse() - - return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks + reversed_merged_hooks_and_context = merged_hooks_and_context[:] + reversed_merged_hooks_and_context.reverse() + + return ( + provider, + hook_hints, + merged_hooks_and_context, + reversed_merged_hooks_and_context, + merged_eval_context, + ) def _assert_provider_status( self, @@ -477,24 +496,21 @@ def _assert_provider_status( def _run_before_hooks_and_update_context( self, flag_type: FlagType, - hook_context: HookContext, - merged_hooks: list[Hook], + merged_hooks_and_context: list[tuple[Hook, HookContext]], hook_hints: HookHints, - evaluation_context: typing.Optional[EvaluationContext], + evaluation_context: EvaluationContext, ) -> EvaluationContext: # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md # Any resulting evaluation context from a before hook will overwrite # duplicate fields defined globally, on the client, or in the invocation. # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context before_hooks_context = before_hooks( - flag_type, hook_context, merged_hooks, hook_hints + flag_type, merged_hooks_and_context, hook_hints ) # The hook_context.evaluation_context already contains the merged context from # _establish_hooks_and_provider, so we just need to merge with the before hooks result - merged_context = hook_context.evaluation_context.merge(before_hooks_context) - - return merged_context + return evaluation_context.merge(before_hooks_context) @typing.overload async def evaluate_flag_details_async( @@ -575,23 +591,26 @@ async def evaluate_flag_details_async( :return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a provider """ - provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( - self._establish_hooks_and_provider( - flag_type, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) + ( + provider, + hook_hints, + merged_hooks_and_context, + reversed_merged_hooks_and_context, + merged_eval_context, + ) = self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, ) try: if provider_err := self._assert_provider_status(): error_hooks( flag_type, - hook_context, provider_err, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) flag_evaluation = FlagEvaluationDetails( @@ -605,10 +624,9 @@ async def evaluate_flag_details_async( merged_context = self._run_before_hooks_and_update_context( flag_type, - hook_context, - merged_hooks, + merged_hooks_and_context, hook_hints, - evaluation_context, + merged_eval_context, ) flag_evaluation = await self._create_provider_evaluation_async( @@ -620,22 +638,21 @@ async def evaluate_flag_details_async( ) if err := flag_evaluation.get_exception(): error_hooks( - flag_type, hook_context, err, reversed_merged_hooks, hook_hints + flag_type, err, reversed_merged_hooks_and_context, hook_hints ) return flag_evaluation after_hooks( flag_type, - hook_context, flag_evaluation, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) return flag_evaluation except OpenFeatureError as err: - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints) flag_evaluation = FlagEvaluationDetails( flag_key=flag_key, value=default_value, @@ -651,7 +668,7 @@ async def evaluate_flag_details_async( "Unable to correctly evaluate flag with key: '%s'", flag_key ) - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints) error_message = getattr(err, "error_message", str(err)) flag_evaluation = FlagEvaluationDetails( @@ -666,9 +683,8 @@ async def evaluate_flag_details_async( finally: after_all_hooks( flag_type, - hook_context, flag_evaluation, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) @@ -751,23 +767,26 @@ def evaluate_flag_details( :return: a FlagEvaluationDetails object with the fully evaluated flag from a provider """ - provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( - self._establish_hooks_and_provider( - flag_type, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) + ( + provider, + hook_hints, + merged_hooks_and_context, + reversed_merged_hooks_and_context, + merged_eval_context, + ) = self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, ) try: if provider_err := self._assert_provider_status(): error_hooks( flag_type, - hook_context, provider_err, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) flag_evaluation = FlagEvaluationDetails( @@ -781,10 +800,9 @@ def evaluate_flag_details( merged_context = self._run_before_hooks_and_update_context( flag_type, - hook_context, - merged_hooks, + merged_hooks_and_context, hook_hints, - evaluation_context, + merged_eval_context, ) flag_evaluation = self._create_provider_evaluation( @@ -796,23 +814,22 @@ def evaluate_flag_details( ) if err := flag_evaluation.get_exception(): error_hooks( - flag_type, hook_context, err, reversed_merged_hooks, hook_hints + flag_type, err, reversed_merged_hooks_and_context, hook_hints ) flag_evaluation.value = default_value return flag_evaluation after_hooks( flag_type, - hook_context, flag_evaluation, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) return flag_evaluation except OpenFeatureError as err: - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints) flag_evaluation = FlagEvaluationDetails( flag_key=flag_key, @@ -829,7 +846,7 @@ def evaluate_flag_details( "Unable to correctly evaluate flag with key: '%s'", flag_key ) - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints) error_message = getattr(err, "error_message", str(err)) flag_evaluation = FlagEvaluationDetails( @@ -844,9 +861,8 @@ def evaluate_flag_details( finally: after_all_hooks( flag_type, - hook_context, flag_evaluation, - reversed_merged_hooks, + reversed_merged_hooks_and_context, hook_hints, ) diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 9e1e11ac..e21a263f 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from collections.abc import Sequence +from collections.abc import MutableMapping, Sequence from datetime import datetime from enum import Enum from typing import TYPE_CHECKING @@ -16,6 +16,7 @@ __all__ = [ "Hook", "HookContext", + "HookData", "HookHints", "HookType", "add_hooks", @@ -26,6 +27,10 @@ _hooks: list[Hook] = [] +# https://openfeature.dev/specification/sections/hooks/#requirement-461 +HookData = MutableMapping[str, typing.Any] + + class HookType(Enum): BEFORE = "before" AFTER = "after" @@ -34,7 +39,7 @@ class HookType(Enum): class HookContext: - def __init__( + def __init__( # noqa: PLR0913 self, flag_key: str, flag_type: FlagType, @@ -42,6 +47,7 @@ def __init__( evaluation_context: EvaluationContext, client_metadata: typing.Optional[ClientMetadata] = None, provider_metadata: typing.Optional[Metadata] = None, + hook_data: typing.Optional[HookData] = None, ): self.flag_key = flag_key self.flag_type = flag_type @@ -49,6 +55,7 @@ def __init__( self.evaluation_context = evaluation_context self.client_metadata = client_metadata self.provider_metadata = provider_metadata + self.hook_data = hook_data or {} def __setattr__(self, key: str, value: typing.Any) -> None: if hasattr(self, key) and key in ( diff --git a/openfeature/hook/_hook_support.py b/openfeature/hook/_hook_support.py index 37b7e5b4..0a3565cd 100644 --- a/openfeature/hook/_hook_support.py +++ b/openfeature/hook/_hook_support.py @@ -11,52 +11,60 @@ def error_hooks( flag_type: FlagType, - hook_context: HookContext, exception: Exception, - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hints: typing.Optional[HookHints] = None, ) -> None: - kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints} + kwargs = {"exception": exception, "hints": hints} _execute_hooks( - flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs + flag_type=flag_type, + hooks_and_context=hooks_and_context, + hook_method=HookType.ERROR, + **kwargs, ) def after_all_hooks( flag_type: FlagType, - hook_context: HookContext, details: FlagEvaluationDetails[typing.Any], - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hints: typing.Optional[HookHints] = None, ) -> None: - kwargs = {"hook_context": hook_context, "details": details, "hints": hints} + kwargs = {"details": details, "hints": hints} _execute_hooks( - flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs + flag_type=flag_type, + hooks_and_context=hooks_and_context, + hook_method=HookType.FINALLY_AFTER, + **kwargs, ) def after_hooks( flag_type: FlagType, - hook_context: HookContext, details: FlagEvaluationDetails[typing.Any], - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hints: typing.Optional[HookHints] = None, ) -> None: - kwargs = {"hook_context": hook_context, "details": details, "hints": hints} + kwargs = {"details": details, "hints": hints} _execute_hooks_unchecked( - flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs + flag_type=flag_type, + hooks_and_context=hooks_and_context, + hook_method=HookType.AFTER, + **kwargs, ) def before_hooks( flag_type: FlagType, - hook_context: HookContext, - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hints: typing.Optional[HookHints] = None, ) -> EvaluationContext: - kwargs = {"hook_context": hook_context, "hints": hints} + kwargs = {"hints": hints} executed_hooks = _execute_hooks_unchecked( - flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs + flag_type=flag_type, + hooks_and_context=hooks_and_context, + hook_method=HookType.BEFORE, + **kwargs, ) filtered_hooks = [result for result in executed_hooks if result is not None] @@ -68,30 +76,30 @@ def before_hooks( def _execute_hooks( flag_type: FlagType, - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hook_method: HookType, **kwargs: typing.Any, -) -> list: +) -> list[typing.Optional[EvaluationContext]]: """ Run multiple hooks of any hook type. All of these hooks will be run through an exception check. :param flag_type: particular type of flag - :param hooks: a list of hooks + :param hooks_and_context: a list of hooks and their context :param hook_method: the type of hook that is being run :param kwargs: arguments that need to be provided to the hook method :return: a list of results from the applied hook methods """ return [ - _execute_hook_checked(hook, hook_method, **kwargs) - for hook in hooks + _execute_hook_checked(hook, hook_method, hook_context=hook_context, **kwargs) + for (hook, hook_context) in hooks_and_context if hook.supports_flag_value_type(flag_type) ] def _execute_hooks_unchecked( flag_type: FlagType, - hooks: list[Hook], + hooks_and_context: list[tuple[Hook, HookContext]], hook_method: HookType, **kwargs: typing.Any, ) -> list[typing.Optional[EvaluationContext]]: @@ -101,14 +109,14 @@ def _execute_hooks_unchecked( client. :param flag_type: particular type of flag - :param hooks: a list of hooks + :param hooks_and_context: a list of hooks and their context :param hook_method: the type of hook that is being run :param kwargs: arguments that need to be provided to the hook method :return: a list of results from the applied hook methods """ return [ - getattr(hook, hook_method.value)(**kwargs) - for hook in hooks + getattr(hook, hook_method.value)(hook_context=hook_context, **kwargs) + for (hook, hook_context) in hooks_and_context if hook.supports_flag_value_type(flag_type) ] diff --git a/tests/hook/test_hook_data.py b/tests/hook/test_hook_data.py new file mode 100644 index 00000000..33d12591 --- /dev/null +++ b/tests/hook/test_hook_data.py @@ -0,0 +1,61 @@ +import typing + +from openfeature.api import set_provider +from openfeature.client import OpenFeatureClient +from openfeature.evaluation_context import EvaluationContext +from openfeature.flag_evaluation import FlagEvaluationDetails, FlagValueType +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.provider.no_op_provider import NoOpProvider + + +class Example: + def __init__(self): + self.value = "example" + + +class HookWithData(Hook): + def __init__(self, data: dict[str, typing.Any]): + self.data_before = data + self.data_after = None + + def before( + self, hook_context: HookContext, hints: HookHints + ) -> typing.Optional[EvaluationContext]: + hook_context.hook_data = hook_context.hook_data | self.data_before + return None + + def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[FlagValueType], + hints: HookHints, + ) -> None: + self.data_after = hook_context.hook_data + + +def test_hook_data_is_not_shared_between_hooks(): + """Requirement + + 4.3.2 - "Hook data" MUST must be created before the first "stage" invoked in a hook for a specific evaluation + and propagated between each "stage" of the hook. The hook data is not shared between different hooks. + 4.6.1 - "Hook data" MUST be a structure supporting the definition of arbitrary properties, with keys of type string, + and values of any type. + """ + + # given + provider = NoOpProvider() + set_provider(provider) + + client = OpenFeatureClient(domain=None, version=None) + + hook_1 = HookWithData({"key": "value"}) + hook_2 = HookWithData({"key": Example()}) + client.add_hooks([hook_1, hook_2]) + + # when + client.get_boolean_value(flag_key="test-flag", default_value=False) + + # then + assert hook_1.data_after["key"] == "value" + assert isinstance(hook_2.data_after["key"], Example) + assert hook_2.data_after["key"].value == "example" diff --git a/tests/hook/test_hook_support.py b/tests/hook/test_hook_support.py index 19a86ec0..a72b0178 100644 --- a/tests/hook/test_hook_support.py +++ b/tests/hook/test_hook_support.py @@ -19,7 +19,7 @@ def test_hook_context_has_required_and_optional_fields(): """Requirement - 4.1.1 - Hook context MUST provide: the "flag key", "flag value type", "evaluation context", and the "default value". + 4.1.1 - Hook context MUST provide: the "flag key", "flag value type", "evaluation context", "default value" and "hook data". 4.1.2 - The "hook context" SHOULD provide: access to the "client metadata" and the "provider metadata" fields. """ @@ -33,12 +33,14 @@ def test_hook_context_has_required_and_optional_fields(): assert hasattr(hook_context, "evaluation_context") assert hasattr(hook_context, "client_metadata") assert hasattr(hook_context, "provider_metadata") + assert hasattr(hook_context, "hook_data") def test_hook_context_has_immutable_and_mutable_fields(): """Requirement 4.1.3 - The "flag key", "flag type", and "default value" properties MUST be immutable. + 4.1.5 - The "hook data" property MUST be mutable. 4.1.4.1 - The evaluation context MUST be mutable only within the before hook. 4.2.2.2 - The client "metadata" field in the "hook context" MUST be immutable. 4.2.2.3 - The provider "metadata" field in the "hook context" MUST be immutable. @@ -62,6 +64,7 @@ def test_hook_context_has_immutable_and_mutable_fields(): hook_context.provider_metadata = Metadata("name") hook_context.evaluation_context = EvaluationContext("targeting_key") + hook_context.hook_data["key"] = "value" # Then assert hook_context.flag_key == "flag_key" @@ -70,6 +73,7 @@ def test_hook_context_has_immutable_and_mutable_fields(): assert hook_context.evaluation_context.targeting_key == "targeting_key" assert hook_context.client_metadata.name == "name" assert hook_context.provider_metadata is None + assert hook_context.hook_data == {"key": "value"} def test_error_hooks_run_error_method(mock_hook): @@ -77,7 +81,7 @@ def test_error_hooks_run_error_method(mock_hook): hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") hook_hints = MappingProxyType({}) # When - error_hooks(FlagType.BOOLEAN, hook_context, Exception, [mock_hook], hook_hints) + error_hooks(FlagType.BOOLEAN, Exception, [(mock_hook, hook_context)], hook_hints) # Then mock_hook.supports_flag_value_type.assert_called_once() mock_hook.error.assert_called_once() @@ -91,7 +95,7 @@ def test_before_hooks_run_before_method(mock_hook): hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") hook_hints = MappingProxyType({}) # When - before_hooks(FlagType.BOOLEAN, hook_context, [mock_hook], hook_hints) + before_hooks(FlagType.BOOLEAN, [(mock_hook, hook_context)], hook_hints) # Then mock_hook.supports_flag_value_type.assert_called_once() mock_hook.before.assert_called_once() @@ -109,7 +113,10 @@ def test_before_hooks_merges_evaluation_contexts(): hook_3.before.return_value = None # When - context = before_hooks(FlagType.BOOLEAN, hook_context, [hook_1, hook_2, hook_3]) + context = before_hooks( + FlagType.BOOLEAN, + [(hook_1, hook_context), (hook_2, hook_context), (hook_3, hook_context)], + ) # Then assert context == EvaluationContext("bar", {"key_1": "val_1", "key_2": "val_2"}) @@ -124,7 +131,10 @@ def test_after_hooks_run_after_method(mock_hook): hook_hints = MappingProxyType({}) # When after_hooks( - FlagType.BOOLEAN, hook_context, flag_evaluation_details, [mock_hook], hook_hints + FlagType.BOOLEAN, + flag_evaluation_details, + [(mock_hook, hook_context)], + hook_hints, ) # Then mock_hook.supports_flag_value_type.assert_called_once() @@ -143,7 +153,10 @@ def test_finally_after_hooks_run_finally_after_method(mock_hook): hook_hints = MappingProxyType({}) # When after_all_hooks( - FlagType.BOOLEAN, hook_context, flag_evaluation_details, [mock_hook], hook_hints + FlagType.BOOLEAN, + flag_evaluation_details, + [(mock_hook, hook_context)], + hook_hints, ) # Then mock_hook.supports_flag_value_type.assert_called_once() diff --git a/uv.lock b/uv.lock index 5107ba6d..0a4c1682 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" [[package]] @@ -202,7 +202,7 @@ wheels = [ [[package]] name = "openfeature-sdk" -version = "0.8.1" +version = "0.8.2" source = { editable = "." } [package.dev-dependencies]