Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 81 additions & 65 deletions openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import typing
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from itertools import chain

from openfeature import _event_support
from openfeature.evaluation_context import EvaluationContext, get_evaluation_context
Expand Down Expand Up @@ -420,10 +421,10 @@ def _establish_hooks_and_provider(
flag_evaluation_options: typing.Optional[FlagEvaluationOptions],
) -> tuple[
FeatureProvider,
HookContext,
HookHints,
list[Hook],
list[Hook],
list[tuple[Hook, HookContext]],
list[tuple[Hook, HookContext]],
EvaluationContext,
]:
if evaluation_context is None:
evaluation_context = EvaluationContext()
Expand All @@ -444,25 +445,43 @@ def _establish_hooks_and_provider(
.merge(evaluation_context)
)

hook_context = HookContext(
flag_key=flag_key,
flag_type=flag_type,
default_value=default_value,
evaluation_context=merged_eval_context,
client_metadata=self.get_metadata(),
provider_metadata=provider.get_metadata(),
)
client_metadata = self.get_metadata()
provider_metadata = provider.get_metadata()

# Hooks need to be handled in different orders at different stages
# in the flag evaluation
# before: API, Client, Invocation, Provider
merged_hooks = (
get_hooks() + self.hooks + evaluation_hooks + provider.get_provider_hooks()
)
merged_hooks_and_context = [
(
hook,
HookContext(
flag_key=flag_key,
flag_type=flag_type,
default_value=default_value,
evaluation_context=merged_eval_context,
client_metadata=client_metadata,
provider_metadata=provider_metadata,
hook_data={},
),
)
for hook in chain(
get_hooks(),
self.hooks,
evaluation_hooks,
provider.get_provider_hooks(),
)
]
# after, error, finally: Provider, Invocation, Client, API
reversed_merged_hooks = merged_hooks[:]
reversed_merged_hooks.reverse()

return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks
reversed_merged_hooks_and_context = merged_hooks_and_context[:]
reversed_merged_hooks_and_context.reverse()

return (
provider,
hook_hints,
merged_hooks_and_context,
reversed_merged_hooks_and_context,
merged_eval_context,
)

def _assert_provider_status(
self,
Expand All @@ -477,24 +496,21 @@ def _assert_provider_status(
def _run_before_hooks_and_update_context(
self,
flag_type: FlagType,
hook_context: HookContext,
merged_hooks: list[Hook],
merged_hooks_and_context: list[tuple[Hook, HookContext]],
hook_hints: HookHints,
evaluation_context: typing.Optional[EvaluationContext],
evaluation_context: EvaluationContext,
) -> EvaluationContext:
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
# Any resulting evaluation context from a before hook will overwrite
# duplicate fields defined globally, on the client, or in the invocation.
# Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context
before_hooks_context = before_hooks(
flag_type, hook_context, merged_hooks, hook_hints
flag_type, merged_hooks_and_context, hook_hints
)

# The hook_context.evaluation_context already contains the merged context from
# _establish_hooks_and_provider, so we just need to merge with the before hooks result
merged_context = hook_context.evaluation_context.merge(before_hooks_context)

return merged_context
return evaluation_context.merge(before_hooks_context)

@typing.overload
async def evaluate_flag_details_async(
Expand Down Expand Up @@ -575,23 +591,26 @@ async def evaluate_flag_details_async(
:return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
(
provider,
hook_hints,
merged_hooks_and_context,
reversed_merged_hooks_and_context,
merged_eval_context,
) = self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)

try:
if provider_err := self._assert_provider_status():
error_hooks(
flag_type,
hook_context,
provider_err,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
Expand All @@ -605,10 +624,9 @@ async def evaluate_flag_details_async(

merged_context = self._run_before_hooks_and_update_context(
flag_type,
hook_context,
merged_hooks,
merged_hooks_and_context,
hook_hints,
evaluation_context,
merged_eval_context,
)

flag_evaluation = await self._create_provider_evaluation_async(
Expand All @@ -620,22 +638,21 @@ async def evaluate_flag_details_async(
)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
flag_type, err, reversed_merged_hooks_and_context, hook_hints
)
return flag_evaluation

after_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)

return flag_evaluation

except OpenFeatureError as err:
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
Expand All @@ -651,7 +668,7 @@ async def evaluate_flag_details_async(
"Unable to correctly evaluate flag with key: '%s'", flag_key
)

error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)

error_message = getattr(err, "error_message", str(err))
flag_evaluation = FlagEvaluationDetails(
Expand All @@ -666,9 +683,8 @@ async def evaluate_flag_details_async(
finally:
after_all_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)

Expand Down Expand Up @@ -751,23 +767,26 @@ def evaluate_flag_details(
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
(
provider,
hook_hints,
merged_hooks_and_context,
reversed_merged_hooks_and_context,
merged_eval_context,
) = self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)

try:
if provider_err := self._assert_provider_status():
error_hooks(
flag_type,
hook_context,
provider_err,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
Expand All @@ -781,10 +800,9 @@ def evaluate_flag_details(

merged_context = self._run_before_hooks_and_update_context(
flag_type,
hook_context,
merged_hooks,
merged_hooks_and_context,
hook_hints,
evaluation_context,
merged_eval_context,
)

flag_evaluation = self._create_provider_evaluation(
Expand All @@ -796,23 +814,22 @@ def evaluate_flag_details(
)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
flag_type, err, reversed_merged_hooks_and_context, hook_hints
)
flag_evaluation.value = default_value
return flag_evaluation

after_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)

return flag_evaluation

except OpenFeatureError as err:
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)

flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
Expand All @@ -829,7 +846,7 @@ def evaluate_flag_details(
"Unable to correctly evaluate flag with key: '%s'", flag_key
)

error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_hooks(flag_type, err, reversed_merged_hooks_and_context, hook_hints)

error_message = getattr(err, "error_message", str(err))
flag_evaluation = FlagEvaluationDetails(
Expand All @@ -844,9 +861,8 @@ def evaluate_flag_details(
finally:
after_all_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
reversed_merged_hooks_and_context,
hook_hints,
)

Expand Down
11 changes: 9 additions & 2 deletions openfeature/hook/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import typing
from collections.abc import Sequence
from collections.abc import MutableMapping, Sequence
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
Expand All @@ -16,6 +16,7 @@
__all__ = [
"Hook",
"HookContext",
"HookData",
"HookHints",
"HookType",
"add_hooks",
Expand All @@ -26,6 +27,10 @@
_hooks: list[Hook] = []


# https://openfeature.dev/specification/sections/hooks/#requirement-461
HookData = MutableMapping[str, typing.Any]


class HookType(Enum):
BEFORE = "before"
AFTER = "after"
Expand All @@ -34,21 +39,23 @@ class HookType(Enum):


class HookContext:
def __init__(
def __init__( # noqa: PLR0913
self,
flag_key: str,
flag_type: FlagType,
default_value: FlagValueType,
evaluation_context: EvaluationContext,
client_metadata: typing.Optional[ClientMetadata] = None,
provider_metadata: typing.Optional[Metadata] = None,
hook_data: typing.Optional[HookData] = None,
):
self.flag_key = flag_key
self.flag_type = flag_type
self.default_value = default_value
self.evaluation_context = evaluation_context
self.client_metadata = client_metadata
self.provider_metadata = provider_metadata
self.hook_data = hook_data or {}

def __setattr__(self, key: str, value: typing.Any) -> None:
if hasattr(self, key) and key in (
Expand Down
Loading