Skip to content

Commit 266d2f1

Browse files
committed
feat(flagd-rpc): add caching with tests
Signed-off-by: Simon Schrottner <[email protected]>
1 parent b62d3d1 commit 266d2f1

File tree

13 files changed

+477
-23
lines changed

13 files changed

+477
-23
lines changed

providers/openfeature-provider-flagd/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ dependencies = [
2424
"panzi-json-logic>=1.0.1",
2525
"semver>=3,<4",
2626
"pyyaml>=6.0.1",
27+
"cachebox"
2728
]
2829
requires-python = ">=3.8"
2930

@@ -59,6 +60,7 @@ cov = [
5960
"cov-report",
6061
]
6162

63+
6264
[tool.hatch.envs.mypy]
6365
dependencies = [
6466
"mypy[faster-cache]>=1.13.0",
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[pytest]
2+
markers =
3+
rpc: tests for rpc mode.
4+
in-process: tests for in-process mode.
5+
customCert: Supports custom certs.
6+
unixsocket: Supports unixsockets.
7+
events: Supports events.
8+
sync: Supports sync.
9+
caching: Supports caching.
10+
offline: Supports offline.

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ class ResolverType(Enum):
88
IN_PROCESS = "in-process"
99

1010

11+
class CacheType(Enum):
12+
LRU = "lru"
13+
DISABLED = "disabled"
14+
15+
16+
DEFAULT_CACHE = CacheType.LRU
17+
DEFAULT_CACHE_SIZE = 1000
1118
DEFAULT_DEADLINE = 500
1219
DEFAULT_HOST = "localhost"
1320
DEFAULT_KEEP_ALIVE = 0
@@ -19,12 +26,14 @@ class ResolverType(Enum):
1926
DEFAULT_STREAM_DEADLINE = 600000
2027
DEFAULT_TLS = False
2128

29+
ENV_VAR_CACHE_SIZE = "FLAGD_MAX_CACHE_SIZE"
30+
ENV_VAR_CACHE_TYPE = "FLAGD_CACHE"
2231
ENV_VAR_DEADLINE_MS = "FLAGD_DEADLINE_MS"
2332
ENV_VAR_HOST = "FLAGD_HOST"
2433
ENV_VAR_KEEP_ALIVE_TIME_MS = "FLAGD_KEEP_ALIVE_TIME_MS"
2534
ENV_VAR_OFFLINE_FLAG_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
2635
ENV_VAR_PORT = "FLAGD_PORT"
27-
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER_TYPE"
36+
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER"
2837
ENV_VAR_RETRY_BACKOFF_MS = "FLAGD_RETRY_BACKOFF_MS"
2938
ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS"
3039
ENV_VAR_TLS = "FLAGD_TLS"
@@ -36,6 +45,14 @@ def str_to_bool(val: str) -> bool:
3645
return val.lower() == "true"
3746

3847

48+
def convert_resolver_type(val: typing.Union[str, ResolverType]) -> ResolverType:
49+
if isinstance(val, str):
50+
v = val.lower()
51+
return ResolverType(v)
52+
else:
53+
return ResolverType(val)
54+
55+
3956
def env_or_default(
4057
env_var: str, default: T, cast: typing.Optional[typing.Callable[[str], T]] = None
4158
) -> typing.Union[str, T]:
@@ -56,7 +73,9 @@ def __init__( # noqa: PLR0913
5673
retry_backoff_ms: typing.Optional[int] = None,
5774
deadline: typing.Optional[int] = None,
5875
stream_deadline_ms: typing.Optional[int] = None,
59-
keep_alive_time: typing.Optional[int] = None,
76+
keep_alive: typing.Optional[int] = None,
77+
cache_type: typing.Optional[CacheType] = None,
78+
max_cache_size: typing.Optional[int] = None,
6079
):
6180
self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host
6281

@@ -77,7 +96,9 @@ def __init__( # noqa: PLR0913
7796
)
7897

7998
self.resolver_type = (
80-
ResolverType(env_or_default(ENV_VAR_RESOLVER_TYPE, DEFAULT_RESOLVER_TYPE))
99+
env_or_default(
100+
ENV_VAR_RESOLVER_TYPE, DEFAULT_RESOLVER_TYPE, cast=convert_resolver_type
101+
)
81102
if resolver_type is None
82103
else resolver_type
83104
)
@@ -118,10 +139,22 @@ def __init__( # noqa: PLR0913
118139
else stream_deadline_ms
119140
)
120141

121-
self.keep_alive_time: int = (
142+
self.keep_alive: int = (
122143
int(
123144
env_or_default(ENV_VAR_KEEP_ALIVE_TIME_MS, DEFAULT_KEEP_ALIVE, cast=int)
124145
)
125-
if keep_alive_time is None
126-
else keep_alive_time
146+
if keep_alive is None
147+
else keep_alive
148+
)
149+
150+
self.cache_type = (
151+
CacheType(env_or_default(ENV_VAR_CACHE_TYPE, DEFAULT_CACHE))
152+
if cache_type is None
153+
else cache_type
154+
)
155+
156+
self.max_cache_size: int = (
157+
int(env_or_default(ENV_VAR_CACHE_SIZE, DEFAULT_CACHE_SIZE, cast=int))
158+
if max_cache_size is None
159+
else max_cache_size
127160
)

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openfeature.provider.metadata import Metadata
3030
from openfeature.provider.provider import AbstractProvider
3131

32-
from .config import Config, ResolverType
32+
from .config import CacheType, Config, ResolverType
3333
from .resolvers import AbstractResolver, GrpcResolver, InProcessResolver
3434

3535
T = typing.TypeVar("T")
@@ -50,6 +50,8 @@ def __init__( # noqa: PLR0913
5050
offline_flag_source_path: typing.Optional[str] = None,
5151
stream_deadline_ms: typing.Optional[int] = None,
5252
keep_alive_time: typing.Optional[int] = None,
53+
cache_type: typing.Optional[CacheType] = None,
54+
max_cache_size: typing.Optional[int] = None,
5355
):
5456
"""
5557
Create an instance of the FlagdProvider
@@ -82,7 +84,9 @@ def __init__( # noqa: PLR0913
8284
resolver_type=resolver_type,
8385
offline_flag_source_path=offline_flag_source_path,
8486
stream_deadline_ms=stream_deadline_ms,
85-
keep_alive_time=keep_alive_time,
87+
keep_alive=keep_alive_time,
88+
cache_type=cache_type,
89+
max_cache_size=max_cache_size,
8690
)
8791

8892
self.resolver = self.setup_resolver()

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing
55

66
import grpc
7+
from cachebox import BaseCacheImpl, LRUCache
78
from google.protobuf.json_format import MessageToDict
89
from google.protobuf.struct_pb2 import Struct
910

@@ -18,13 +19,13 @@
1819
ProviderNotReadyError,
1920
TypeMismatchError,
2021
)
21-
from openfeature.flag_evaluation import FlagResolutionDetails
22+
from openfeature.flag_evaluation import FlagResolutionDetails, Reason
2223
from openfeature.schemas.protobuf.flagd.evaluation.v1 import (
2324
evaluation_pb2,
2425
evaluation_pb2_grpc,
2526
)
2627

27-
from ..config import Config
28+
from ..config import CacheType, Config
2829
from ..flag_type import FlagType
2930

3031
if typing.TYPE_CHECKING:
@@ -57,24 +58,40 @@ def __init__(
5758
self.deadline = config.deadline * 0.001
5859
self.connected = False
5960

61+
self._cache: typing.Optional[BaseCacheImpl] = (
62+
LRUCache(maxsize=self.config.max_cache_size)
63+
if self.config.cache_type == CacheType.LRU
64+
else None
65+
)
66+
6067
def _create_stub(
6168
self,
6269
) -> typing.Tuple[evaluation_pb2_grpc.ServiceStub, grpc.Channel]:
6370
config = self.config
6471
channel_factory = grpc.secure_channel if config.tls else grpc.insecure_channel
6572
channel = channel_factory(
6673
f"{config.host}:{config.port}",
67-
options=(("grpc.keepalive_time_ms", config.keep_alive_time),),
74+
options=(("grpc.keepalive_time_ms", config.keep_alive),),
6875
)
6976
stub = evaluation_pb2_grpc.ServiceStub(channel)
7077
return stub, channel
7178

7279
def initialize(self, evaluation_context: EvaluationContext) -> None:
7380
self.connect()
81+
self.retry_backoff_seconds = 0.1
82+
self.connected = False
83+
84+
self._cache = (
85+
LRUCache(maxsize=self.config.max_cache_size)
86+
if self.config.cache_type == CacheType.LRU
87+
else None
88+
)
7489

7590
def shutdown(self) -> None:
7691
self.active = False
7792
self.channel.close()
93+
if self._cache:
94+
self._cache.clear()
7895

7996
def connect(self) -> None:
8097
self.active = True
@@ -96,7 +113,6 @@ def connect(self) -> None:
96113

97114
def listen(self) -> None:
98115
retry_delay = self.retry_backoff_seconds
99-
100116
call_args = (
101117
{"timeout": self.streamline_deadline_seconds}
102118
if self.streamline_deadline_seconds > 0
@@ -148,6 +164,10 @@ def listen(self) -> None:
148164
def handle_changed_flags(self, data: typing.Any) -> None:
149165
changed_flags = list(data["flags"].keys())
150166

167+
if self._cache:
168+
for flag in changed_flags:
169+
self._cache.pop(flag)
170+
151171
self.emit_provider_configuration_changed(ProviderEventDetails(changed_flags))
152172

153173
def resolve_boolean_details(
@@ -190,13 +210,18 @@ def resolve_object_details(
190210
) -> FlagResolutionDetails[typing.Union[dict, list]]:
191211
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)
192212

193-
def _resolve( # noqa: PLR0915
213+
def _resolve( # noqa: PLR0915 C901
194214
self,
195215
flag_key: str,
196216
flag_type: FlagType,
197217
default_value: T,
198218
evaluation_context: typing.Optional[EvaluationContext],
199219
) -> FlagResolutionDetails[T]:
220+
if self._cache is not None and flag_key in self._cache:
221+
cached_flag: FlagResolutionDetails[T] = self._cache[flag_key]
222+
cached_flag.reason = Reason.CACHED
223+
return cached_flag
224+
200225
context = self._convert_context(evaluation_context)
201226
call_args = {"timeout": self.deadline}
202227
try:
@@ -249,12 +274,17 @@ def _resolve( # noqa: PLR0915
249274
raise GeneralError(message) from e
250275

251276
# Got a valid flag and valid type. Return it.
252-
return FlagResolutionDetails(
277+
result = FlagResolutionDetails(
253278
value=value,
254279
reason=response.reason,
255280
variant=response.variant,
256281
)
257282

283+
if response.reason == Reason.STATIC and self._cache is not None:
284+
self._cache.insert(flag_key, result)
285+
286+
return result
287+
258288
def _convert_context(
259289
self, evaluation_context: typing.Optional[EvaluationContext]
260290
) -> Struct:

0 commit comments

Comments
 (0)