Skip to content

Commit 6f7ac50

Browse files
committed
add hook data
Signed-off-by: gruebel <[email protected]>
1 parent 32fdec1 commit 6f7ac50

File tree

6 files changed

+216
-98
lines changed

6 files changed

+216
-98
lines changed

openfeature/client.py

Lines changed: 92 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import typing
33
from collections.abc import Awaitable, Sequence
44
from dataclasses import dataclass
5+
from functools import reduce
6+
from itertools import chain
57

68
from openfeature import _event_support
79
from openfeature.evaluation_context import EvaluationContext, get_evaluation_context
@@ -420,10 +422,10 @@ def _establish_hooks_and_provider(
420422
flag_evaluation_options: typing.Optional[FlagEvaluationOptions],
421423
) -> tuple[
422424
FeatureProvider,
423-
HookContext,
424425
HookHints,
425-
list[Hook],
426-
list[Hook],
426+
list[tuple[Hook, HookContext]],
427+
list[tuple[Hook, HookContext]],
428+
EvaluationContext,
427429
]:
428430
if evaluation_context is None:
429431
evaluation_context = EvaluationContext()
@@ -444,25 +446,43 @@ def _establish_hooks_and_provider(
444446
.merge(evaluation_context)
445447
)
446448

447-
hook_context = HookContext(
448-
flag_key=flag_key,
449-
flag_type=flag_type,
450-
default_value=default_value,
451-
evaluation_context=merged_eval_context,
452-
client_metadata=self.get_metadata(),
453-
provider_metadata=provider.get_metadata(),
454-
)
449+
client_metadata = self.get_metadata()
450+
provider_metadata = provider.get_metadata()
451+
455452
# Hooks need to be handled in different orders at different stages
456453
# in the flag evaluation
457454
# before: API, Client, Invocation, Provider
458-
merged_hooks = (
459-
get_hooks() + self.hooks + evaluation_hooks + provider.get_provider_hooks()
460-
)
455+
merged_hooks_and_context = [
456+
(
457+
hook,
458+
HookContext(
459+
flag_key=flag_key,
460+
flag_type=flag_type,
461+
default_value=default_value,
462+
evaluation_context=merged_eval_context,
463+
client_metadata=client_metadata,
464+
provider_metadata=provider_metadata,
465+
hook_data={},
466+
),
467+
)
468+
for hook in chain(
469+
get_hooks(),
470+
self.hooks,
471+
evaluation_hooks,
472+
provider.get_provider_hooks(),
473+
)
474+
]
461475
# after, error, finally: Provider, Invocation, Client, API
462-
reversed_merged_hooks = merged_hooks[:]
463-
reversed_merged_hooks.reverse()
464-
465-
return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks
476+
reversed_merged_hooks_and_context = merged_hooks_and_context[:]
477+
reversed_merged_hooks_and_context.reverse()
478+
479+
return (
480+
provider,
481+
hook_hints,
482+
merged_hooks_and_context,
483+
reversed_merged_hooks_and_context,
484+
merged_eval_context,
485+
)
466486

467487
def _assert_provider_status(
468488
self,
@@ -477,22 +497,31 @@ def _assert_provider_status(
477497
def _run_before_hooks_and_update_context(
478498
self,
479499
flag_type: FlagType,
480-
hook_context: HookContext,
481-
merged_hooks: list[Hook],
500+
merged_hooks_and_context: list[tuple[Hook, HookContext]],
482501
hook_hints: HookHints,
483-
evaluation_context: typing.Optional[EvaluationContext],
502+
evaluation_context: EvaluationContext,
484503
) -> EvaluationContext:
485504
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
486505
# Any resulting evaluation context from a before hook will overwrite
487506
# duplicate fields defined globally, on the client, or in the invocation.
488507
# Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context
489508
before_hooks_context = before_hooks(
490-
flag_type, hook_context, merged_hooks, hook_hints
509+
flag_type, merged_hooks_and_context, hook_hints
491510
)
492511

512+
if not merged_hooks_and_context:
513+
return evaluation_context.merge(before_hooks_context)
514+
493515
# The hook_context.evaluation_context already contains the merged context from
494516
# _establish_hooks_and_provider, so we just need to merge with the before hooks result
495-
merged_context = hook_context.evaluation_context.merge(before_hooks_context)
517+
merged_context = reduce(
518+
lambda a, b: a.merge(b),
519+
[
520+
hook_context.evaluation_context
521+
for (_, hook_context) in merged_hooks_and_context
522+
],
523+
)
524+
merged_context = merged_context.merge(before_hooks_context)
496525

497526
return merged_context
498527

@@ -575,23 +604,26 @@ async def evaluate_flag_details_async(
575604
:return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a
576605
provider
577606
"""
578-
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
579-
self._establish_hooks_and_provider(
580-
flag_type,
581-
flag_key,
582-
default_value,
583-
evaluation_context,
584-
flag_evaluation_options,
585-
)
607+
(
608+
provider,
609+
hook_hints,
610+
merged_hooks_and_context,
611+
reversed_merged_hooks_and_context,
612+
merged_eval_context,
613+
) = self._establish_hooks_and_provider(
614+
flag_type,
615+
flag_key,
616+
default_value,
617+
evaluation_context,
618+
flag_evaluation_options,
586619
)
587620

588621
try:
589622
if provider_err := self._assert_provider_status():
590623
error_hooks(
591624
flag_type,
592-
hook_context,
593625
provider_err,
594-
reversed_merged_hooks,
626+
reversed_merged_hooks_and_context,
595627
hook_hints,
596628
)
597629
flag_evaluation = FlagEvaluationDetails(
@@ -605,10 +637,9 @@ async def evaluate_flag_details_async(
605637

606638
merged_context = self._run_before_hooks_and_update_context(
607639
flag_type,
608-
hook_context,
609-
merged_hooks,
640+
merged_hooks_and_context,
610641
hook_hints,
611-
evaluation_context,
642+
merged_eval_context,
612643
)
613644

614645
flag_evaluation = await self._create_provider_evaluation_async(
@@ -620,22 +651,21 @@ async def evaluate_flag_details_async(
620651
)
621652
if err := flag_evaluation.get_exception():
622653
error_hooks(
623-
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
654+
flag_type, err, reversed_merged_hooks_and_context, hook_hints
624655
)
625656
return flag_evaluation
626657

627658
after_hooks(
628659
flag_type,
629-
hook_context,
630660
flag_evaluation,
631-
reversed_merged_hooks,
661+
reversed_merged_hooks_and_context,
632662
hook_hints,
633663
)
634664

635665
return flag_evaluation
636666

637667
except OpenFeatureError as err:
638-
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
668+
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)
639669
flag_evaluation = FlagEvaluationDetails(
640670
flag_key=flag_key,
641671
value=default_value,
@@ -651,7 +681,7 @@ async def evaluate_flag_details_async(
651681
"Unable to correctly evaluate flag with key: '%s'", flag_key
652682
)
653683

654-
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
684+
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)
655685

656686
error_message = getattr(err, "error_message", str(err))
657687
flag_evaluation = FlagEvaluationDetails(
@@ -666,9 +696,8 @@ async def evaluate_flag_details_async(
666696
finally:
667697
after_all_hooks(
668698
flag_type,
669-
hook_context,
670699
flag_evaluation,
671-
reversed_merged_hooks,
700+
reversed_merged_hooks_and_context,
672701
hook_hints,
673702
)
674703

@@ -751,23 +780,26 @@ def evaluate_flag_details(
751780
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
752781
provider
753782
"""
754-
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
755-
self._establish_hooks_and_provider(
756-
flag_type,
757-
flag_key,
758-
default_value,
759-
evaluation_context,
760-
flag_evaluation_options,
761-
)
783+
(
784+
provider,
785+
hook_hints,
786+
merged_hooks_and_context,
787+
reversed_merged_hooks_and_context,
788+
merged_eval_context,
789+
) = self._establish_hooks_and_provider(
790+
flag_type,
791+
flag_key,
792+
default_value,
793+
evaluation_context,
794+
flag_evaluation_options,
762795
)
763796

764797
try:
765798
if provider_err := self._assert_provider_status():
766799
error_hooks(
767800
flag_type,
768-
hook_context,
769801
provider_err,
770-
reversed_merged_hooks,
802+
reversed_merged_hooks_and_context,
771803
hook_hints,
772804
)
773805
flag_evaluation = FlagEvaluationDetails(
@@ -781,10 +813,9 @@ def evaluate_flag_details(
781813

782814
merged_context = self._run_before_hooks_and_update_context(
783815
flag_type,
784-
hook_context,
785-
merged_hooks,
816+
merged_hooks_and_context,
786817
hook_hints,
787-
evaluation_context,
818+
merged_eval_context,
788819
)
789820

790821
flag_evaluation = self._create_provider_evaluation(
@@ -796,23 +827,22 @@ def evaluate_flag_details(
796827
)
797828
if err := flag_evaluation.get_exception():
798829
error_hooks(
799-
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
830+
flag_type, err, reversed_merged_hooks_and_context, hook_hints
800831
)
801832
flag_evaluation.value = default_value
802833
return flag_evaluation
803834

804835
after_hooks(
805836
flag_type,
806-
hook_context,
807837
flag_evaluation,
808-
reversed_merged_hooks,
838+
reversed_merged_hooks_and_context,
809839
hook_hints,
810840
)
811841

812842
return flag_evaluation
813843

814844
except OpenFeatureError as err:
815-
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
845+
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)
816846

817847
flag_evaluation = FlagEvaluationDetails(
818848
flag_key=flag_key,
@@ -829,7 +859,7 @@ def evaluate_flag_details(
829859
"Unable to correctly evaluate flag with key: '%s'", flag_key
830860
)
831861

832-
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
862+
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)
833863

834864
error_message = getattr(err, "error_message", str(err))
835865
flag_evaluation = FlagEvaluationDetails(
@@ -844,9 +874,8 @@ def evaluate_flag_details(
844874
finally:
845875
after_all_hooks(
846876
flag_type,
847-
hook_context,
848877
flag_evaluation,
849-
reversed_merged_hooks,
878+
reversed_merged_hooks_and_context,
850879
hook_hints,
851880
)
852881

openfeature/hook/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import typing
4-
from collections.abc import Sequence
4+
from collections.abc import MutableMapping, Sequence
55
from datetime import datetime
66
from enum import Enum
77
from typing import TYPE_CHECKING
@@ -16,6 +16,7 @@
1616
__all__ = [
1717
"Hook",
1818
"HookContext",
19+
"HookData",
1920
"HookHints",
2021
"HookType",
2122
"add_hooks",
@@ -26,6 +27,10 @@
2627
_hooks: list[Hook] = []
2728

2829

30+
# https://openfeature.dev/specification/sections/hooks/#requirement-461
31+
HookData = MutableMapping[str, typing.Any]
32+
33+
2934
class HookType(Enum):
3035
BEFORE = "before"
3136
AFTER = "after"
@@ -34,21 +39,23 @@ class HookType(Enum):
3439

3540

3641
class HookContext:
37-
def __init__(
42+
def __init__( # noqa: PLR0913
3843
self,
3944
flag_key: str,
4045
flag_type: FlagType,
4146
default_value: FlagValueType,
4247
evaluation_context: EvaluationContext,
4348
client_metadata: typing.Optional[ClientMetadata] = None,
4449
provider_metadata: typing.Optional[Metadata] = None,
50+
hook_data: typing.Optional[HookData] = None,
4551
):
4652
self.flag_key = flag_key
4753
self.flag_type = flag_type
4854
self.default_value = default_value
4955
self.evaluation_context = evaluation_context
5056
self.client_metadata = client_metadata
5157
self.provider_metadata = provider_metadata
58+
self.hook_data = hook_data or {}
5259

5360
def __setattr__(self, key: str, value: typing.Any) -> None:
5461
if hasattr(self, key) and key in (

0 commit comments

Comments
 (0)