|
| 1 | +import logging |
| 2 | +import threading |
| 3 | +import time |
1 | 4 | import typing |
2 | 5 |
|
3 | 6 | import grpc |
| 7 | +from cachebox import LRUCache # type:ignore[import-not-found] |
4 | 8 | from google.protobuf.json_format import MessageToDict |
5 | 9 | from google.protobuf.struct_pb2 import Struct |
6 | 10 | from schemas.protobuf.flagd.evaluation.v1 import ( # type:ignore[import-not-found] |
|
9 | 13 | ) |
10 | 14 |
|
11 | 15 | from openfeature.evaluation_context import EvaluationContext |
| 16 | +from openfeature.event import ProviderEventDetails |
12 | 17 | from openfeature.exception import ( |
| 18 | + ErrorCode, |
13 | 19 | FlagNotFoundError, |
14 | 20 | GeneralError, |
15 | 21 | InvalidContextError, |
16 | 22 | ParseError, |
17 | 23 | TypeMismatchError, |
18 | 24 | ) |
19 | | -from openfeature.flag_evaluation import FlagResolutionDetails |
| 25 | +from openfeature.flag_evaluation import FlagResolutionDetails, Reason |
20 | 26 |
|
21 | | -from ..config import Config |
| 27 | +from ..config import CacheType, Config |
22 | 28 | from ..flag_type import FlagType |
23 | 29 | from .protocol import AbstractResolver |
24 | 30 |
|
25 | 31 | T = typing.TypeVar("T") |
26 | 32 |
|
| 33 | +logger = logging.getLogger("openfeature.contrib") |
| 34 | + |
27 | 35 |
|
28 | 36 | class GrpcResolver(AbstractResolver): |
29 | | - def __init__(self, config: Config): |
| 37 | + MAX_BACK_OFF = 120 |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + config: Config, |
| 42 | + emit_provider_ready: typing.Callable[[ProviderEventDetails], None], |
| 43 | + emit_provider_error: typing.Callable[[ProviderEventDetails], None], |
| 44 | + emit_provider_configuration_changed: typing.Callable[ |
| 45 | + [ProviderEventDetails], None |
| 46 | + ], |
| 47 | + ): |
30 | 48 | self.config = config |
| 49 | + self.emit_provider_ready = emit_provider_ready |
| 50 | + self.emit_provider_error = emit_provider_error |
| 51 | + self.emit_provider_configuration_changed = emit_provider_configuration_changed |
31 | 52 | channel_factory = ( |
32 | 53 | grpc.secure_channel if self.config.tls else grpc.insecure_channel |
33 | 54 | ) |
34 | 55 | self.channel = channel_factory(f"{self.config.host}:{self.config.port}") |
35 | 56 | self.stub = evaluation_pb2_grpc.ServiceStub(self.channel) |
| 57 | + self.retry_backoff_seconds = 0.1 |
| 58 | + self.connected = False |
| 59 | + |
| 60 | + self._cache = ( |
| 61 | + LRUCache(maxsize=self.config.max_cache_size) |
| 62 | + if self.config.cache_type == CacheType.LRU |
| 63 | + else None |
| 64 | + ) |
| 65 | + |
| 66 | + def initialize(self, evaluation_context: EvaluationContext) -> None: |
| 67 | + self.connect() |
36 | 68 |
|
37 | 69 | def shutdown(self) -> None: |
| 70 | + self.active = False |
38 | 71 | self.channel.close() |
| 72 | + if self._cache: |
| 73 | + self._cache.clear() |
| 74 | + |
| 75 | + def connect(self) -> None: |
| 76 | + self.active = True |
| 77 | + self.thread = threading.Thread( |
| 78 | + target=self.listen, daemon=True, name="FlagdGrpcServiceWorkerThread" |
| 79 | + ) |
| 80 | + self.thread.start() |
| 81 | + |
| 82 | + def listen(self) -> None: |
| 83 | + retry_delay = self.retry_backoff_seconds |
| 84 | + while self.active: |
| 85 | + request = evaluation_pb2.EventStreamRequest() |
| 86 | + try: |
| 87 | + logger.debug("Setting up gRPC sync flags connection") |
| 88 | + for message in self.stub.EventStream(request): |
| 89 | + if message.type == "provider_ready": |
| 90 | + if not self.connected: |
| 91 | + self.emit_provider_ready( |
| 92 | + ProviderEventDetails( |
| 93 | + message="gRPC sync connection established" |
| 94 | + ) |
| 95 | + ) |
| 96 | + self.connected = True |
| 97 | + # reset retry delay after successsful read |
| 98 | + retry_delay = self.retry_backoff_seconds |
| 99 | + |
| 100 | + elif message.type == "configuration_change": |
| 101 | + data = MessageToDict(message)["data"] |
| 102 | + self.handle_changed_flags(data) |
| 103 | + |
| 104 | + if not self.active: |
| 105 | + logger.info("Terminating gRPC sync thread") |
| 106 | + return |
| 107 | + except grpc.RpcError as e: |
| 108 | + logger.error(f"SyncFlags stream error, {e.code()=} {e.details()=}") |
| 109 | + except ParseError: |
| 110 | + logger.exception( |
| 111 | + f"Could not parse flag data using flagd syntax: {message=}" |
| 112 | + ) |
| 113 | + |
| 114 | + self.connected = False |
| 115 | + self.emit_provider_error( |
| 116 | + ProviderEventDetails( |
| 117 | + message=f"gRPC sync disconnected, reconnecting in {retry_delay}s", |
| 118 | + error_code=ErrorCode.GENERAL, |
| 119 | + ) |
| 120 | + ) |
| 121 | + logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s") |
| 122 | + time.sleep(retry_delay) |
| 123 | + retry_delay = min(2 * retry_delay, self.MAX_BACK_OFF) |
| 124 | + |
| 125 | + def handle_changed_flags(self, data: typing.Any) -> None: |
| 126 | + changed_flags = list(data["flags"].keys()) |
| 127 | + |
| 128 | + if self._cache: |
| 129 | + for flag in changed_flags: |
| 130 | + self._cache.pop(flag) |
| 131 | + |
| 132 | + self.emit_provider_configuration_changed(ProviderEventDetails(changed_flags)) |
39 | 133 |
|
40 | 134 | def resolve_boolean_details( |
41 | 135 | self, |
@@ -77,13 +171,18 @@ def resolve_object_details( |
77 | 171 | ) -> FlagResolutionDetails[typing.Union[dict, list]]: |
78 | 172 | return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context) |
79 | 173 |
|
80 | | - def _resolve( # noqa: PLR0915 |
| 174 | + def _resolve( # noqa: PLR0915 C901 |
81 | 175 | self, |
82 | 176 | flag_key: str, |
83 | 177 | flag_type: FlagType, |
84 | 178 | default_value: T, |
85 | 179 | evaluation_context: typing.Optional[EvaluationContext], |
86 | 180 | ) -> FlagResolutionDetails[T]: |
| 181 | + if self._cache is not None and flag_key in self._cache: |
| 182 | + cached_flag: FlagResolutionDetails[T] = self._cache[flag_key] |
| 183 | + cached_flag.reason = Reason.CACHED |
| 184 | + return cached_flag |
| 185 | + |
87 | 186 | context = self._convert_context(evaluation_context) |
88 | 187 | call_args = {"timeout": self.config.timeout} |
89 | 188 | try: |
@@ -135,12 +234,17 @@ def _resolve( # noqa: PLR0915 |
135 | 234 | raise GeneralError(message) from e |
136 | 235 |
|
137 | 236 | # Got a valid flag and valid type. Return it. |
138 | | - return FlagResolutionDetails( |
| 237 | + result = FlagResolutionDetails( |
139 | 238 | value=value, |
140 | 239 | reason=response.reason, |
141 | 240 | variant=response.variant, |
142 | 241 | ) |
143 | 242 |
|
| 243 | + if response.reason == Reason.STATIC and self._cache is not None: |
| 244 | + self._cache.insert(flag_key, result) |
| 245 | + |
| 246 | + return result |
| 247 | + |
144 | 248 | def _convert_context( |
145 | 249 | self, evaluation_context: typing.Optional[EvaluationContext] |
146 | 250 | ) -> Struct: |
|
0 commit comments