|
1 | 1 | import re |
2 | 2 | from datetime import datetime, timedelta, timezone |
3 | 3 | from email.utils import parsedate_to_datetime |
4 | | -from typing import Any, Dict, List, NoReturn, Optional, Union |
| 4 | +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Type, Union |
5 | 5 | from urllib.parse import urljoin |
6 | 6 |
|
7 | 7 | import requests |
|
16 | 16 | OpenFeatureError, |
17 | 17 | ParseError, |
18 | 18 | TargetingKeyMissingError, |
| 19 | + TypeMismatchError, |
19 | 20 | ) |
20 | | -from openfeature.flag_evaluation import FlagResolutionDetails, Reason |
| 21 | +from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason |
21 | 22 | from openfeature.hook import Hook |
22 | 23 | from openfeature.provider import AbstractProvider, Metadata |
23 | 24 |
|
24 | 25 | __all__ = ["OFREPProvider"] |
25 | 26 |
|
26 | 27 |
|
| 28 | +TypeMap = Dict[ |
| 29 | + FlagType, |
| 30 | + Union[ |
| 31 | + Type[bool], |
| 32 | + Type[int], |
| 33 | + Type[float], |
| 34 | + Type[str], |
| 35 | + Tuple[Type[dict], Type[list]], |
| 36 | + ], |
| 37 | +] |
| 38 | + |
| 39 | + |
27 | 40 | class OFREPProvider(AbstractProvider): |
28 | 41 | def __init__( |
29 | 42 | self, |
@@ -53,42 +66,53 @@ def resolve_boolean_details( |
53 | 66 | default_value: bool, |
54 | 67 | evaluation_context: Optional[EvaluationContext] = None, |
55 | 68 | ) -> FlagResolutionDetails[bool]: |
56 | | - return self._resolve(flag_key, default_value, evaluation_context) |
| 69 | + return self._resolve( |
| 70 | + FlagType.BOOLEAN, flag_key, default_value, evaluation_context |
| 71 | + ) |
57 | 72 |
|
58 | 73 | def resolve_string_details( |
59 | 74 | self, |
60 | 75 | flag_key: str, |
61 | 76 | default_value: str, |
62 | 77 | evaluation_context: Optional[EvaluationContext] = None, |
63 | 78 | ) -> FlagResolutionDetails[str]: |
64 | | - return self._resolve(flag_key, default_value, evaluation_context) |
| 79 | + return self._resolve( |
| 80 | + FlagType.STRING, flag_key, default_value, evaluation_context |
| 81 | + ) |
65 | 82 |
|
66 | 83 | def resolve_integer_details( |
67 | 84 | self, |
68 | 85 | flag_key: str, |
69 | 86 | default_value: int, |
70 | 87 | evaluation_context: Optional[EvaluationContext] = None, |
71 | 88 | ) -> FlagResolutionDetails[int]: |
72 | | - return self._resolve(flag_key, default_value, evaluation_context) |
| 89 | + return self._resolve( |
| 90 | + FlagType.INTEGER, flag_key, default_value, evaluation_context |
| 91 | + ) |
73 | 92 |
|
74 | 93 | def resolve_float_details( |
75 | 94 | self, |
76 | 95 | flag_key: str, |
77 | 96 | default_value: float, |
78 | 97 | evaluation_context: Optional[EvaluationContext] = None, |
79 | 98 | ) -> FlagResolutionDetails[float]: |
80 | | - return self._resolve(flag_key, default_value, evaluation_context) |
| 99 | + return self._resolve( |
| 100 | + FlagType.FLOAT, flag_key, default_value, evaluation_context |
| 101 | + ) |
81 | 102 |
|
82 | 103 | def resolve_object_details( |
83 | 104 | self, |
84 | 105 | flag_key: str, |
85 | 106 | default_value: Union[dict, list], |
86 | 107 | evaluation_context: Optional[EvaluationContext] = None, |
87 | 108 | ) -> FlagResolutionDetails[Union[dict, list]]: |
88 | | - return self._resolve(flag_key, default_value, evaluation_context) |
| 109 | + return self._resolve( |
| 110 | + FlagType.OBJECT, flag_key, default_value, evaluation_context |
| 111 | + ) |
89 | 112 |
|
90 | 113 | def _resolve( |
91 | 114 | self, |
| 115 | + flag_type: FlagType, |
92 | 116 | flag_key: str, |
93 | 117 | default_value: Union[bool, str, int, float, dict, list], |
94 | 118 | evaluation_context: Optional[EvaluationContext] = None, |
@@ -117,6 +141,8 @@ def _resolve( |
117 | 141 | except JSONDecodeError as e: |
118 | 142 | raise ParseError(str(e)) from e |
119 | 143 |
|
| 144 | + _typecheck_flag_value(data["value"], flag_type) |
| 145 | + |
120 | 146 | return FlagResolutionDetails( |
121 | 147 | value=data["value"], |
122 | 148 | reason=Reason[data["reason"]], |
@@ -178,3 +204,18 @@ def _parse_retry_after(retry_after: Optional[str]) -> Optional[datetime]: |
178 | 204 | seconds = int(retry_after) |
179 | 205 | return datetime.now(timezone.utc) + timedelta(seconds=seconds) |
180 | 206 | return parsedate_to_datetime(retry_after) |
| 207 | + |
| 208 | + |
| 209 | +def _typecheck_flag_value(value: Any, flag_type: FlagType) -> None: |
| 210 | + type_map: TypeMap = { |
| 211 | + FlagType.BOOLEAN: bool, |
| 212 | + FlagType.STRING: str, |
| 213 | + FlagType.OBJECT: (dict, list), |
| 214 | + FlagType.FLOAT: float, |
| 215 | + FlagType.INTEGER: int, |
| 216 | + } |
| 217 | + _type = type_map.get(flag_type) |
| 218 | + if not _type: |
| 219 | + raise GeneralError(error_message="Unknown flag type") |
| 220 | + if not isinstance(value, _type): |
| 221 | + raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") |
0 commit comments