Skip to content

Commit c63da12

Browse files
author
Varun Rathore
committed
Added unit testcase
1 parent 76b9b5f commit c63da12

File tree

3 files changed

+821
-83
lines changed

3 files changed

+821
-83
lines changed

firebase_admin/remote_config.py

Lines changed: 92 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@
3232
_REMOTE_CONFIG_ATTRIBUTE = '_remoteconfig'
3333
MAX_CONDITION_RECURSION_DEPTH = 10
3434
ValueSource = Literal['default', 'remote', 'static'] # Define the ValueSource type
35+
class PercentConditionOperator(Enum):
36+
"""Enum representing the available operators for percent conditions.
37+
"""
38+
LESS_OR_EQUAL = "LESS_OR_EQUAL"
39+
GREATER_THAN = "GREATER_THAN"
40+
BETWEEN = "BETWEEN"
41+
UNKNOWN = "UNKNOWN"
3542

3643
class CustomSignalOperator(Enum):
3744
"""Enum representing the available operators for custom signal conditions.
@@ -52,6 +59,7 @@ class CustomSignalOperator(Enum):
5259
SEMANTIC_VERSION_NOT_EQUAL = "SEMANTIC_VERSION_NOT_EQUAL"
5360
SEMANTIC_VERSION_GREATER_THAN = "SEMANTIC_VERSION_GREATER_THAN"
5461
SEMANTIC_VERSION_GREATER_EQUAL = "SEMANTIC_VERSION_GREATER_EQUAL"
62+
UNKNOWN = "UNKNOWN"
5563

5664
class ServerTemplateData:
5765
"""Represents a Server Template Data class."""
@@ -131,13 +139,13 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser
131139
Call load() before calling evaluate().""")
132140
context = context or {}
133141
config_values = {}
134-
135142
# Initializes config Value objects with default values.
136-
for key, value in self._stringified_default_config.items():
137-
config_values[key] = _Value('default', value)
138-
139-
self._evaluator = _ConditionEvaluator(self._cache.conditions, context,
140-
config_values, self._cache.parameters)
143+
if self._stringified_default_config is not None:
144+
for key, value in json.loads(self._stringified_default_config).items():
145+
config_values[key] = _Value('default', value)
146+
self._evaluator = _ConditionEvaluator(self._cache.conditions,
147+
self._cache.parameters, context,
148+
config_values)
141149
return ServerConfig(config_values=self._evaluator.evaluate())
142150

143151
def set(self, template):
@@ -156,13 +164,13 @@ def __init__(self, config_values):
156164
self._config_values = config_values # dictionary of param key to values
157165

158166
def get_boolean(self, key):
159-
return bool(self.get_value(key))
167+
return self.get_value(key).as_boolean()
160168

161169
def get_string(self, key):
162-
return str(self.get_value(key))
170+
return self.get_value(key).as_string()
163171

164172
def get_int(self, key):
165-
return int(self.get_value(key))
173+
return self.get_value(key).as_number()
166174

167175
def get_value(self, key):
168176
return self._config_values[key]
@@ -209,7 +217,7 @@ def _get_url_prefix(self):
209217
class _ConditionEvaluator:
210218
"""Internal class that facilitates sending requests to the Firebase Remote
211219
Config backend API."""
212-
def __init__(self, context, conditions, config_values, parameters):
220+
def __init__(self, conditions, parameters, context, config_values):
213221
self._context = context
214222
self._conditions = conditions
215223
self._parameters = parameters
@@ -221,51 +229,53 @@ def evaluate(self):
221229
evaluated_conditions = self.evaluate_conditions(self._conditions, self._context)
222230

223231
# Overlays config Value objects derived by evaluating the template.
224-
for key, parameter in self._parameters.items():
225-
conditional_values = parameter.conditional_values or {}
226-
default_value = parameter.default_value or {}
227-
parameter_value_wrapper = None
228-
229-
# Iterates in order over condition list. If there is a value associated
230-
# with a condition, this checks if the condition is true.
231-
for condition_name, condition_evaluation in evaluated_conditions.items():
232-
if condition_name in conditional_values and condition_evaluation:
233-
parameter_value_wrapper = conditional_values[condition_name]
234-
break
235-
if parameter_value_wrapper and parameter_value_wrapper.get('useInAppDefault'):
236-
logger.info("Using in-app default value for key '%s'", key)
237-
continue
238-
239-
if parameter_value_wrapper:
240-
parameter_value = parameter_value_wrapper.value
241-
self._config_values[key] = _Value('remote', parameter_value)
242-
continue
243-
244-
if not default_value:
245-
logger.warning("No default value found for key '%s'", key)
246-
continue
247-
248-
if default_value.get('useInAppDefault'):
249-
logger.info("Using in-app default value for key '%s'", key)
250-
continue
251-
252-
self._config_values[key] = _Value('remote', default_value.get('value'))
232+
# evaluated_conditions = None
233+
if self._parameters is not None:
234+
for key, parameter in self._parameters.items():
235+
conditional_values = parameter.get('conditionalValues', {})
236+
default_value = parameter.get('defaultValue', {})
237+
parameter_value_wrapper = None
238+
# Iterates in order over condition list. If there is a value associated
239+
# with a condition, this checks if the condition is true.
240+
if evaluated_conditions is not None:
241+
for condition_name, condition_evaluation in evaluated_conditions.items():
242+
if condition_name in conditional_values and condition_evaluation:
243+
parameter_value_wrapper = conditional_values[condition_name]
244+
break
245+
246+
if parameter_value_wrapper and parameter_value_wrapper.get('useInAppDefault'):
247+
logger.info("Using in-app default value for key '%s'", key)
248+
continue
249+
250+
if parameter_value_wrapper:
251+
parameter_value = parameter_value_wrapper.get('value')
252+
self._config_values[key] = _Value('remote', parameter_value)
253+
continue
254+
255+
if not default_value:
256+
logger.warning("No default value found for key '%s'", key)
257+
continue
258+
259+
if default_value.get('useInAppDefault'):
260+
logger.info("Using in-app default value for key '%s'", key)
261+
continue
262+
self._config_values[key] = _Value('remote', default_value.get('value'))
253263
return self._config_values
254264

255-
def evaluate_conditions(self, named_conditions, context)-> Dict[str, bool]:
256-
"""Evaluates a list of named conditions and returns a dictionary of results.
265+
def evaluate_conditions(self, conditions, context)-> Dict[str, bool]:
266+
"""Evaluates a list of conditions and returns a dictionary of results.
257267
258268
Args:
259-
named_conditions: A list of NamedCondition objects.
269+
conditions: A list of NamedCondition objects.
260270
context: An EvaluationContext object.
261271
262272
Returns:
263273
A dictionary mapping condition names to boolean evaluation results.
264274
"""
265275
evaluated_conditions = {}
266-
for named_condition in named_conditions:
267-
evaluated_conditions[named_condition.name] = self.evaluate_condition(
268-
named_condition.condition, context
276+
for condition in conditions:
277+
evaluated_conditions[condition.get('name')] = self.evaluate_condition(
278+
condition.get('condition'), context
269279
)
270280
return evaluated_conditions
271281

@@ -284,18 +294,20 @@ def evaluate_condition(self, condition, context,
284294
if nesting_level >= MAX_CONDITION_RECURSION_DEPTH:
285295
logger.warning("Maximum condition recursion depth exceeded.")
286296
return False
287-
if condition.or_condition:
288-
return self.evaluate_or_condition(condition.or_condition, context, nesting_level + 1)
289-
if condition.and_condition:
290-
return self.evaluate_and_condition(condition.and_condition, context, nesting_level + 1)
291-
if condition.true_condition:
297+
if condition.get('orCondition') is not None:
298+
return self.evaluate_or_condition(condition.get('orCondition'),
299+
context, nesting_level + 1)
300+
if condition.get('andCondition') is not None:
301+
return self.evaluate_and_condition(condition.get('andCondition'),
302+
context, nesting_level + 1)
303+
if condition.get('true') is not None:
292304
return True
293-
if condition.false_condition:
305+
if condition.get('false') is not None:
294306
return False
295-
if condition.percent_condition:
296-
return self.evaluate_percent_condition(condition.percent_condition, context)
297-
if condition.custom_signal_condition:
298-
return self.evaluate_custom_signal_condition(condition.custom_signal_condition, context)
307+
if condition.get('percent') is not None:
308+
return self.evaluate_percent_condition(condition.get('percent'), context)
309+
if condition.get('customSignal') is not None:
310+
return self.evaluate_custom_signal_condition(condition.get('customSignal'), context)
299311
logger.warning("Unknown condition type encountered.")
300312
return False
301313

@@ -312,7 +324,7 @@ def evaluate_or_condition(self, or_condition,
312324
Returns:
313325
True if any of the subconditions are true, False otherwise.
314326
"""
315-
sub_conditions = or_condition.conditions or []
327+
sub_conditions = or_condition.get('conditions') or []
316328
for sub_condition in sub_conditions:
317329
result = self.evaluate_condition(sub_condition, context, nesting_level + 1)
318330
if result:
@@ -332,7 +344,7 @@ def evaluate_and_condition(self, and_condition,
332344
Returns:
333345
True if all of the subconditions are true, False otherwise.
334346
"""
335-
sub_conditions = and_condition.conditions or []
347+
sub_conditions = and_condition.get('conditions') or []
336348
for sub_condition in sub_conditions:
337349
result = self.evaluate_condition(sub_condition, context, nesting_level + 1)
338350
if not result:
@@ -350,36 +362,33 @@ def evaluate_percent_condition(self, percent_condition,
350362
Returns:
351363
True if the condition is met, False otherwise.
352364
"""
353-
if not context.randomization_id:
365+
if not context.get('randomization_id'):
354366
logger.warning("Missing randomization ID for percent condition.")
355367
return False
356368

357-
seed = percent_condition.seed
358-
percent_operator = percent_condition.percent_operator
359-
micro_percent = percent_condition.micro_percent or 0
360-
micro_percent_range = percent_condition.micro_percent_range
361-
369+
seed = percent_condition.get('seed')
370+
percent_operator = percent_condition.get('percentOperator')
371+
micro_percent = percent_condition.get('microPercent')
372+
micro_percent_range = percent_condition.get('microPercentRange')
362373
if not percent_operator:
363374
logger.warning("Missing percent operator for percent condition.")
364375
return False
365376
if micro_percent_range:
366-
norm_percent_upper_bound = micro_percent_range.micro_percent_upper_bound
367-
norm_percent_lower_bound = micro_percent_range.micro_percent_lower_bound
377+
norm_percent_upper_bound = micro_percent_range.get('microPercentUpperBound')
378+
norm_percent_lower_bound = micro_percent_range.get('microPercentLowerBound')
368379
else:
369380
norm_percent_upper_bound = 0
370381
norm_percent_lower_bound = 0
371382
seed_prefix = f"{seed}." if seed else ""
372-
string_to_hash = f"{seed_prefix}{context.randomization_id}"
383+
string_to_hash = f"{seed_prefix}{context.get('randomization_id')}"
373384

374385
hash64 = self.hash_seeded_randomization_id(string_to_hash)
375-
376-
instance_micro_percentile = hash64 % (100 * 1_000_000)
377-
378-
if percent_operator == "LESS_OR_EQUAL":
386+
instance_micro_percentile = hash64 % (100 * 1000000)
387+
if percent_operator == PercentConditionOperator.LESS_OR_EQUAL:
379388
return instance_micro_percentile <= micro_percent
380-
if percent_operator == "GREATER_THAN":
389+
if percent_operator == PercentConditionOperator.GREATER_THAN:
381390
return instance_micro_percentile > micro_percent
382-
if percent_operator == "BETWEEN":
391+
if percent_operator == PercentConditionOperator.BETWEEN:
383392
return norm_percent_lower_bound < instance_micro_percentile <= norm_percent_upper_bound
384393
logger.warning("Unknown percent operator: %s", percent_operator)
385394
return False
@@ -393,9 +402,9 @@ def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int:
393402
The hashed value.
394403
"""
395404
hash_object = hashlib.sha256()
396-
hash_object.update(seeded_randomization_id)
405+
hash_object.update(seeded_randomization_id.encode('utf-8'))
397406
hash64 = hash_object.hexdigest()
398-
return abs(hash64)
407+
return abs(int(hash64, 16))
399408
def evaluate_custom_signal_condition(self, custom_signal_condition,
400409
context) -> bool:
401410
"""Evaluates a custom signal condition.
@@ -407,15 +416,15 @@ def evaluate_custom_signal_condition(self, custom_signal_condition,
407416
Returns:
408417
True if the condition is met, False otherwise.
409418
"""
410-
custom_signal_operator = custom_signal_condition.custom_signal_operator
411-
custom_signal_key = custom_signal_condition.custom_signal_key
412-
target_custom_signal_values = custom_signal_condition.target_custom_signal_values
419+
custom_signal_operator = custom_signal_condition.get('custom_signal_operator') or {}
420+
custom_signal_key = custom_signal_condition.get('custom_signal_key') or {}
421+
tgt_custom_signal_values = custom_signal_condition.get('target_custom_signal_values') or {}
413422

414-
if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]):
423+
if not all([custom_signal_operator, custom_signal_key, tgt_custom_signal_values]):
415424
logger.warning("Missing operator, key, or target values for custom signal condition.")
416425
return False
417426

418-
if not target_custom_signal_values:
427+
if not tgt_custom_signal_values:
419428
return False
420429
actual_custom_signal_value = getattr(context, custom_signal_key, None)
421430
if actual_custom_signal_value is None:
@@ -466,14 +475,14 @@ def compare_strings(predicate_fn: Callable[[str, str], bool]) -> bool:
466475
bool: True if the predicate function returns True for any target value in the list,
467476
False otherwise.
468477
"""
469-
for target in target_custom_signal_values:
478+
for target in tgt_custom_signal_values:
470479
if predicate_fn(target, str(actual_custom_signal_value)):
471480
return True
472481
return False
473482

474483
def compare_numbers(predicate_fn: Callable[[int], bool]) -> bool:
475484
try:
476-
target = float(target_custom_signal_values[0])
485+
target = float(tgt_custom_signal_values[0])
477486
actual = float(actual_custom_signal_value)
478487
result = -1 if actual < target else 1 if actual > target else 0
479488
return predicate_fn(result)
@@ -494,7 +503,7 @@ def compare_semantic_versions(predicate_fn: Callable[[int], bool]) -> bool:
494503
False otherwise.
495504
"""
496505
return compare_versions(str(actual_custom_signal_value),
497-
str(target_custom_signal_values[0]), predicate_fn)
506+
str(tgt_custom_signal_values[0]), predicate_fn)
498507
def compare_versions(version1: str, version2: str,
499508
predicate_fn: Callable[[int], bool]) -> bool:
500509
"""Compares two semantic version strings.
@@ -587,7 +596,7 @@ def as_boolean(self) -> bool:
587596
"""Returns the value as a boolean."""
588597
if self.source == 'static':
589598
return self.DEFAULT_VALUE_FOR_BOOLEAN
590-
return self.value.lower() in self.BOOLEAN_TRUTHY_VALUES
599+
return str(self.value).lower() in self.BOOLEAN_TRUTHY_VALUES
591600

592601
def as_number(self) -> float:
593602
"""Returns the value as a number."""

0 commit comments

Comments
 (0)