Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import re
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta, timezone
from email.utils import parsedate_to_datetime
from json import JSONDecodeError
from typing import Any, Callable, NoReturn, Optional, Union
from urllib.parse import urljoin

import requests
from requests.exceptions import JSONDecodeError
import httpx

from openfeature.evaluation_context import EvaluationContext
from openfeature.exception import (
Expand Down Expand Up @@ -55,7 +56,23 @@ def __init__(
self.headers_factory = headers_factory
self.timeout = timeout
self.retry_after: Optional[datetime] = None
self.session = requests.Session()

self.client = httpx.Client()
self.client_async = httpx.AsyncClient()
self._client_async_is_entered = False

def initialize(self, evaluation_context: EvaluationContext) -> None:
self.client.__enter__()
Comment on lines +64 to +65
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why enter the client here? when you initialize it in the constructor then it should be already entered or not?


def shutdown(self) -> None:
self.client.__exit__(None, None, None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just call .close() on it


try:
# TODO(someday): support non asyncio runtimes here
asyncio.get_running_loop().create_task(self.client_async.__aexit__(None, None, None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here aclose()

self._client_async_is_entered = False
except Exception:
pass

def get_metadata(self) -> Metadata:
return Metadata(name="OpenFeature Remote Evaluation Protocol Provider")
Expand All @@ -73,6 +90,16 @@ def resolve_boolean_details(
FlagType.BOOLEAN, flag_key, default_value, evaluation_context
)

async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return await self._resolve_async(
FlagType.BOOLEAN, flag_key, default_value, evaluation_context
)

def resolve_string_details(
self,
flag_key: str,
Expand All @@ -83,6 +110,16 @@ def resolve_string_details(
FlagType.STRING, flag_key, default_value, evaluation_context
)

async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return await self._resolve_async(
FlagType.STRING, flag_key, default_value, evaluation_context
)

def resolve_integer_details(
self,
flag_key: str,
Expand All @@ -93,6 +130,16 @@ def resolve_integer_details(
FlagType.INTEGER, flag_key, default_value, evaluation_context
)

async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return await self._resolve_async(
FlagType.INTEGER, flag_key, default_value, evaluation_context
)

def resolve_float_details(
self,
flag_key: str,
Expand All @@ -103,6 +150,17 @@ def resolve_float_details(
FlagType.FLOAT, flag_key, default_value, evaluation_context
)


async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return await self._resolve_async(
FlagType.FLOAT, flag_key, default_value, evaluation_context
)

def resolve_object_details(
self,
flag_key: str,
Expand All @@ -115,6 +173,16 @@ def resolve_object_details(
FlagType.OBJECT, flag_key, default_value, evaluation_context
)

async def resolve_object_details_async(
self,
flag_key: str,
default_value: Union[Sequence[FlagValueType], Mapping[str, FlagValueType]],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]:
return await self._resolve_async(
FlagType.OBJECT, flag_key, default_value, evaluation_context
)

def _get_ofrep_api_url(self, api_version: str = "v1") -> str:
ofrep_base_url = (
self.base_url if self.base_url.endswith("/") else f"{self.base_url}/"
Expand Down Expand Up @@ -146,15 +214,15 @@ def _resolve(
self.retry_after = None

try:
response = self.session.post(
response = self.client.post(
urljoin(self._get_ofrep_api_url(), f"evaluate/flags/{flag_key}"),
json=_build_request_data(evaluation_context),
timeout=self.timeout,
headers=self.headers_factory() if self.headers_factory else None,
)
response.raise_for_status()

except requests.RequestException as e:
except httpx.HTTPError as e:
self._handle_error(e)

try:
Expand All @@ -171,11 +239,66 @@ def _resolve(
flag_metadata=data.get("metadata", {}),
)

def _handle_error(self, exception: requests.RequestException) -> NoReturn:
response = exception.response
if response is None:
async def _resolve_async(
self,
flag_type: FlagType,
flag_key: str,
default_value: Union[
bool,
str,
int,
float,
dict,
list,
Sequence[FlagValueType],
Mapping[str, FlagValueType],
],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Any]:
if not self._client_async_is_entered:
await self.client_async.__aenter__()
self._client_async_is_entered = True
Comment on lines +258 to +260
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as far as I can tell from the official docs this is not needed


now = datetime.now(timezone.utc)
if self.retry_after and now <= self.retry_after:
raise GeneralError(
f"OFREP evaluation paused due to TooManyRequests until {self.retry_after}"
)
elif self.retry_after:
self.retry_after = None
Comment on lines +262 to +268
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicated code, please create a small private method for it


try:
response = await self.client_async.post(
urljoin(self._get_ofrep_api_url(), f"evaluate/flags/{flag_key}"),
json=_build_request_data(evaluation_context),
timeout=self.timeout,
headers=self.headers_factory() if self.headers_factory else None,
)
response.raise_for_status()

except httpx.HTTPError as e:
self._handle_error(e)

try:
data = response.json()
except JSONDecodeError as e:
raise ParseError(str(e)) from e

_typecheck_flag_value(data["value"], flag_type)

return FlagResolutionDetails(
value=data["value"],
reason=Reason[data["reason"]],
variant=data["variant"],
flag_metadata=data.get("metadata", {}),
)
Comment on lines +282 to +294
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this


def _handle_error(self, exception: httpx.HTTPError) -> NoReturn:
if not isinstance(exception, httpx.HTTPStatusError):
raise GeneralError(str(exception)) from exception

response = exception.response

if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
self.retry_after = _parse_retry_after(retry_after)
Expand Down Expand Up @@ -205,6 +328,10 @@ def _handle_error(self, exception: requests.RequestException) -> NoReturn:

raise OpenFeatureError(error_code, error_details) from exception

def __del__(self):
# Ensure clients get cleaned up
self.shutdown()
Comment on lines +331 to +333
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, if this is really needed, but it is better to use weakref for this.



def _build_request_data(
evaluation_context: Optional[EvaluationContext],
Expand Down
Loading