Skip to content

Commit c7f81b6

Browse files
committed
Merge branch 'feat/caching' into feat/grpc-sync-addition
2 parents af0df41 + f01d6e5 commit c7f81b6

File tree

14 files changed

+7130
-187
lines changed

14 files changed

+7130
-187
lines changed

providers/openfeature-provider-flagd/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ dependencies = [
2424
"panzi-json-logic>=1.0.1",
2525
"semver>=3,<4",
2626
"pyyaml>=6.0.1",
27+
"cachebox"
2728
]
2829
requires-python = ">=3.8"
2930

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
import typing
33
from enum import Enum
44

5+
ENV_VAR_MAX_CACHE_SIZE = "FLAGD_MAX_CACHE_SIZE"
6+
ENV_VAR_CACHE_TYPE = "FLAGD_CACHE_TYPE"
7+
ENV_VAR_OFFLINE_POLL_INTERVAL_SECONDS = "FLAGD_OFFLINE_POLL_INTERVAL_SECONDS"
8+
ENV_VAR_OFFLINE_FLAG_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
9+
ENV_VAR_PORT = "FLAGD_PORT"
10+
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER_TYPE"
11+
ENV_VAR_TLS = "FLAGD_TLS"
12+
ENV_VAR_HOST = "FLAGD_HOST"
13+
514
T = typing.TypeVar("T")
615

716

@@ -23,6 +32,11 @@ class ResolverType(Enum):
2332
IN_PROCESS = "in-process"
2433

2534

35+
class CacheType(Enum):
36+
LRU = "lru"
37+
DISABLED = "disabled"
38+
39+
2640
class Config:
2741
def __init__( # noqa: PLR0913
2842
self,
@@ -35,13 +49,12 @@ def __init__( # noqa: PLR0913
3549
resolver_type: typing.Optional[ResolverType] = None,
3650
offline_flag_source_path: typing.Optional[str] = None,
3751
offline_poll_interval_seconds: typing.Optional[float] = None,
52+
cache_type: typing.Optional[CacheType] = None,
53+
max_cache_size: typing.Optional[int] = None,
3854
):
39-
self.host = env_or_default("FLAGD_HOST", "localhost") if host is None else host
40-
self.port = (
41-
env_or_default("FLAGD_PORT", 8013, cast=int) if port is None else port
42-
)
55+
self.host = env_or_default(ENV_VAR_HOST, "localhost") if host is None else host
4356
self.tls = (
44-
env_or_default("FLAGD_TLS", False, cast=str_to_bool) if tls is None else tls
57+
env_or_default(ENV_VAR_TLS, False, cast=str_to_bool) if tls is None else tls
4558
)
4659
self.timeout = 5 if timeout is None else timeout
4760
self.retry_backoff_seconds: float = (
@@ -53,17 +66,36 @@ def __init__( # noqa: PLR0913
5366
env_or_default("FLAGD_SELECTOR", None) if selector is None else selector
5467
)
5568
self.resolver_type = (
56-
ResolverType(env_or_default("FLAGD_RESOLVER_TYPE", "grpc"))
69+
ResolverType(env_or_default(ENV_VAR_RESOLVER_TYPE, "grpc"))
5770
if resolver_type is None
5871
else resolver_type
5972
)
73+
74+
default_port = 8013 if self.resolver_type is ResolverType.GRPC else 8015
75+
self.port = (
76+
env_or_default(ENV_VAR_PORT, default_port, cast=int)
77+
if port is None
78+
else port
79+
)
6080
self.offline_flag_source_path = (
61-
env_or_default("FLAGD_OFFLINE_FLAG_SOURCE_PATH", None)
81+
env_or_default(ENV_VAR_OFFLINE_FLAG_SOURCE_PATH, None)
6282
if offline_flag_source_path is None
6383
else offline_flag_source_path
6484
)
6585
self.offline_poll_interval_seconds = (
66-
float(env_or_default("FLAGD_OFFLINE_POLL_INTERVAL_SECONDS", 1.0))
86+
float(env_or_default(ENV_VAR_OFFLINE_POLL_INTERVAL_SECONDS, 1.0))
6787
if offline_poll_interval_seconds is None
6888
else offline_poll_interval_seconds
6989
)
90+
91+
self.cache_type = (
92+
CacheType(env_or_default(ENV_VAR_CACHE_TYPE, CacheType.DISABLED))
93+
if cache_type is None
94+
else cache_type
95+
)
96+
97+
self.max_cache_size = (
98+
env_or_default(ENV_VAR_MAX_CACHE_SIZE, 16, cast=int)
99+
if max_cache_size is None
100+
else max_cache_size
101+
)

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openfeature.provider.metadata import Metadata
3030
from openfeature.provider.provider import AbstractProvider
3131

32-
from .config import Config, ResolverType
32+
from .config import CacheType, Config, ResolverType
3333
from .resolvers import AbstractResolver, GrpcResolver, InProcessResolver
3434

3535
T = typing.TypeVar("T")
@@ -49,6 +49,8 @@ def __init__( # noqa: PLR0913
4949
resolver_type: typing.Optional[ResolverType] = None,
5050
offline_flag_source_path: typing.Optional[str] = None,
5151
offline_poll_interval_seconds: typing.Optional[float] = None,
52+
cache_type: typing.Optional[CacheType] = None,
53+
max_cache_size: typing.Optional[int] = None,
5254
):
5355
"""
5456
Create an instance of the FlagdProvider
@@ -68,13 +70,20 @@ def __init__( # noqa: PLR0913
6870
resolver_type=resolver_type,
6971
offline_flag_source_path=offline_flag_source_path,
7072
offline_poll_interval_seconds=offline_poll_interval_seconds,
73+
cache_type=cache_type,
74+
max_cache_size=max_cache_size,
7175
)
7276

7377
self.resolver = self.setup_resolver()
7478

7579
def setup_resolver(self) -> AbstractResolver:
7680
if self.config.resolver_type == ResolverType.GRPC:
77-
return GrpcResolver(self.config)
81+
return GrpcResolver(
82+
self.config,
83+
self.emit_provider_ready,
84+
self.emit_provider_error,
85+
self.emit_provider_configuration_changed,
86+
)
7887
elif self.config.resolver_type == ResolverType.IN_PROCESS:
7988
return InProcessResolver(
8089
self.config,

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import logging
2+
import threading
3+
import time
14
import typing
25

36
import grpc
7+
from cachebox import LRUCache # type:ignore[import-not-found]
48
from google.protobuf.json_format import MessageToDict
59
from google.protobuf.struct_pb2 import Struct
610
from schemas.protobuf.flagd.evaluation.v1 import ( # type:ignore[import-not-found]
@@ -9,33 +13,123 @@
913
)
1014

1115
from openfeature.evaluation_context import EvaluationContext
16+
from openfeature.event import ProviderEventDetails
1217
from openfeature.exception import (
18+
ErrorCode,
1319
FlagNotFoundError,
1420
GeneralError,
1521
InvalidContextError,
1622
ParseError,
1723
TypeMismatchError,
1824
)
19-
from openfeature.flag_evaluation import FlagResolutionDetails
25+
from openfeature.flag_evaluation import FlagResolutionDetails, Reason
2026

21-
from ..config import Config
27+
from ..config import CacheType, Config
2228
from ..flag_type import FlagType
2329
from .protocol import AbstractResolver
2430

2531
T = typing.TypeVar("T")
2632

33+
logger = logging.getLogger("openfeature.contrib")
34+
2735

2836
class GrpcResolver(AbstractResolver):
29-
def __init__(self, config: Config):
37+
MAX_BACK_OFF = 120
38+
39+
def __init__(
40+
self,
41+
config: Config,
42+
emit_provider_ready: typing.Callable[[ProviderEventDetails], None],
43+
emit_provider_error: typing.Callable[[ProviderEventDetails], None],
44+
emit_provider_configuration_changed: typing.Callable[
45+
[ProviderEventDetails], None
46+
],
47+
):
3048
self.config = config
49+
self.emit_provider_ready = emit_provider_ready
50+
self.emit_provider_error = emit_provider_error
51+
self.emit_provider_configuration_changed = emit_provider_configuration_changed
3152
channel_factory = (
3253
grpc.secure_channel if self.config.tls else grpc.insecure_channel
3354
)
3455
self.channel = channel_factory(f"{self.config.host}:{self.config.port}")
3556
self.stub = evaluation_pb2_grpc.ServiceStub(self.channel)
57+
self.retry_backoff_seconds = 0.1
58+
self.connected = False
59+
60+
self._cache = (
61+
LRUCache(maxsize=self.config.max_cache_size)
62+
if self.config.cache_type == CacheType.LRU
63+
else None
64+
)
65+
66+
def initialize(self, evaluation_context: EvaluationContext) -> None:
67+
self.connect()
3668

3769
def shutdown(self) -> None:
70+
self.active = False
3871
self.channel.close()
72+
if self._cache:
73+
self._cache.clear()
74+
75+
def connect(self) -> None:
76+
self.active = True
77+
self.thread = threading.Thread(
78+
target=self.listen, daemon=True, name="FlagdGrpcServiceWorkerThread"
79+
)
80+
self.thread.start()
81+
82+
def listen(self) -> None:
83+
retry_delay = self.retry_backoff_seconds
84+
while self.active:
85+
request = evaluation_pb2.EventStreamRequest()
86+
try:
87+
logger.debug("Setting up gRPC sync flags connection")
88+
for message in self.stub.EventStream(request):
89+
if message.type == "provider_ready":
90+
if not self.connected:
91+
self.emit_provider_ready(
92+
ProviderEventDetails(
93+
message="gRPC sync connection established"
94+
)
95+
)
96+
self.connected = True
97+
# reset retry delay after successsful read
98+
retry_delay = self.retry_backoff_seconds
99+
100+
elif message.type == "configuration_change":
101+
data = MessageToDict(message)["data"]
102+
self.handle_changed_flags(data)
103+
104+
if not self.active:
105+
logger.info("Terminating gRPC sync thread")
106+
return
107+
except grpc.RpcError as e:
108+
logger.error(f"SyncFlags stream error, {e.code()=} {e.details()=}")
109+
except ParseError:
110+
logger.exception(
111+
f"Could not parse flag data using flagd syntax: {message=}"
112+
)
113+
114+
self.connected = False
115+
self.emit_provider_error(
116+
ProviderEventDetails(
117+
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
118+
error_code=ErrorCode.GENERAL,
119+
)
120+
)
121+
logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s")
122+
time.sleep(retry_delay)
123+
retry_delay = min(2 * retry_delay, self.MAX_BACK_OFF)
124+
125+
def handle_changed_flags(self, data: typing.Any) -> None:
126+
changed_flags = list(data["flags"].keys())
127+
128+
if self._cache:
129+
for flag in changed_flags:
130+
self._cache.pop(flag)
131+
132+
self.emit_provider_configuration_changed(ProviderEventDetails(changed_flags))
39133

40134
def resolve_boolean_details(
41135
self,
@@ -77,13 +171,18 @@ def resolve_object_details(
77171
) -> FlagResolutionDetails[typing.Union[dict, list]]:
78172
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)
79173

80-
def _resolve( # noqa: PLR0915
174+
def _resolve( # noqa: PLR0915 C901
81175
self,
82176
flag_key: str,
83177
flag_type: FlagType,
84178
default_value: T,
85179
evaluation_context: typing.Optional[EvaluationContext],
86180
) -> FlagResolutionDetails[T]:
181+
if self._cache is not None and flag_key in self._cache:
182+
cached_flag: FlagResolutionDetails[T] = self._cache[flag_key]
183+
cached_flag.reason = Reason.CACHED
184+
return cached_flag
185+
87186
context = self._convert_context(evaluation_context)
88187
call_args = {"timeout": self.config.timeout}
89188
try:
@@ -135,12 +234,17 @@ def _resolve( # noqa: PLR0915
135234
raise GeneralError(message) from e
136235

137236
# Got a valid flag and valid type. Return it.
138-
return FlagResolutionDetails(
237+
result = FlagResolutionDetails(
139238
value=value,
140239
reason=response.reason,
141240
variant=response.variant,
142241
)
143242

243+
if response.reason == Reason.STATIC and self._cache is not None:
244+
self._cache.insert(flag_key, result)
245+
246+
return result
247+
144248
def _convert_context(
145249
self, evaluation_context: typing.Optional[EvaluationContext]
146250
) -> Struct:

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/in_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from .process.connector import FlagStateConnector
1010
from .process.connector.file_watcher import FileWatcher
1111
from .process.connector.grpc_watcher import GrpcWatcher
12-
from .process.targeting import targeting
1312
from .process.flags import FlagStore
13+
from .process.targeting import targeting
1414

1515
T = typing.TypeVar("T")
1616

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/custom_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def sem_ver(data: dict, *args: JsonLogicArg) -> typing.Optional[bool]: # noqa:
130130
arg1, op, arg2 = args
131131

132132
try:
133-
v1 = semver.Version.parse(str(arg1))
134-
v2 = semver.Version.parse(str(arg2))
133+
v1 = parse_version(arg1)
134+
v2 = parse_version(arg2)
135135
except ValueError as e:
136136
logger.exception(e)
137137
return None
@@ -155,3 +155,11 @@ def sem_ver(data: dict, *args: JsonLogicArg) -> typing.Optional[bool]: # noqa:
155155
else:
156156
logger.error(f"Op not supported by sem_ver: {op}")
157157
return None
158+
159+
160+
def parse_version(arg: typing.Any) -> semver.Version:
161+
version = str(arg)
162+
if version.startswith(("v", "V")):
163+
version = version[1:]
164+
165+
return semver.Version.parse(version)

0 commit comments

Comments
 (0)