Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.

Commit 00fb6ae

Browse files
authored
[ff-2397] Allow attributes to be specified as dicts (#55)
1 parent 57e4606 commit 00fb6ae

File tree

6 files changed

+100
-33
lines changed

6 files changed

+100
-33
lines changed

eppo_client/bandit.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
BanditModelData,
99
BanditNumericAttributeCoefficient,
1010
)
11+
from eppo_client.rules import to_string
1112
from eppo_client.sharders import Sharder
13+
from eppo_client.types import AttributesDict
1214

1315

1416
logger = logging.getLogger(__name__)
@@ -25,10 +27,42 @@ class Attributes:
2527

2628
@classmethod
2729
def empty(cls):
30+
"""
31+
Create an empty Attributes instance with no numeric or categorical attributes.
32+
33+
Returns:
34+
Attributes: An instance of the Attributes class with empty dictionaries
35+
for numeric and categorical attributes.
36+
"""
2837
return cls({}, {})
2938

39+
@classmethod
40+
def from_dict(cls, attributes: AttributesDict):
41+
"""
42+
Create an Attributes instance from a dictionary of attributes.
43+
44+
Args:
45+
attributes (Dict[str, Union[float, int, bool, str]]): A dictionary where keys are attribute names
46+
and values are attribute values which can be of type float, int, bool, or str.
47+
48+
Returns:
49+
Attributes: An instance of the Attributes class with numeric and categorical attributes separated.
50+
"""
51+
numeric_attributes = {
52+
key: float(value)
53+
for key, value in attributes.items()
54+
if isinstance(value, (int, float))
55+
}
56+
categorical_attributes = {
57+
key: to_string(value)
58+
for key, value in attributes.items()
59+
if isinstance(value, (str, bool))
60+
}
61+
return cls(numeric_attributes, categorical_attributes)
62+
3063

3164
ActionContexts = Dict[str, Attributes]
65+
ActionContextsDict = Dict[str, AttributesDict]
3266

3367

3468
@dataclass

eppo_client/client.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
import datetime
22
import logging
33
import json
4-
from typing import Any, Dict, Optional
4+
from typing import Any, Dict, Optional, Union
55
from eppo_client.assignment_logger import AssignmentLogger
6-
from eppo_client.bandit import BanditEvaluator, BanditResult, Attributes, ActionContexts
6+
from eppo_client.bandit import (
7+
ActionContextsDict,
8+
BanditEvaluator,
9+
BanditResult,
10+
Attributes,
11+
ActionContexts,
12+
)
713
from eppo_client.configuration_requestor import (
814
ExperimentConfigurationRequestor,
915
)
1016
from eppo_client.constants import POLL_INTERVAL_MILLIS, POLL_JITTER_MILLIS
1117
from eppo_client.models import VariationType
1218
from eppo_client.poller import Poller
1319
from eppo_client.sharders import MD5Sharder
14-
from eppo_client.types import SubjectAttributes, ValueType
20+
from eppo_client.types import AttributesDict, ValueType
1521
from eppo_client.validation import validate_not_blank
1622
from eppo_client.eval import FlagEvaluation, Evaluator, none_result
1723
from eppo_client.version import __version__
@@ -43,7 +49,7 @@ def get_string_assignment(
4349
self,
4450
flag_key: str,
4551
subject_key: str,
46-
subject_attributes: SubjectAttributes,
52+
subject_attributes: AttributesDict,
4753
default: str,
4854
) -> str:
4955
return self.get_assignment_variation(
@@ -58,7 +64,7 @@ def get_integer_assignment(
5864
self,
5965
flag_key: str,
6066
subject_key: str,
61-
subject_attributes: SubjectAttributes,
67+
subject_attributes: AttributesDict,
6268
default: int,
6369
) -> int:
6470
return self.get_assignment_variation(
@@ -73,7 +79,7 @@ def get_numeric_assignment(
7379
self,
7480
flag_key: str,
7581
subject_key: str,
76-
subject_attributes: SubjectAttributes,
82+
subject_attributes: AttributesDict,
7783
default: float,
7884
) -> float:
7985
# convert to float in case we get an int
@@ -91,7 +97,7 @@ def get_boolean_assignment(
9197
self,
9298
flag_key: str,
9399
subject_key: str,
94-
subject_attributes: SubjectAttributes,
100+
subject_attributes: AttributesDict,
95101
default: bool,
96102
) -> bool:
97103
return self.get_assignment_variation(
@@ -106,7 +112,7 @@ def get_json_assignment(
106112
self,
107113
flag_key: str,
108114
subject_key: str,
109-
subject_attributes: SubjectAttributes,
115+
subject_attributes: AttributesDict,
110116
default: Dict[Any, Any],
111117
) -> Dict[Any, Any]:
112118
json_value = self.get_assignment_variation(
@@ -125,7 +131,7 @@ def get_assignment_variation(
125131
self,
126132
flag_key: str,
127133
subject_key: str,
128-
subject_attributes: SubjectAttributes,
134+
subject_attributes: AttributesDict,
129135
default: Optional[ValueType],
130136
expected_variation_type: VariationType,
131137
):
@@ -149,7 +155,7 @@ def get_assignment_detail(
149155
self,
150156
flag_key: str,
151157
subject_key: str,
152-
subject_attributes: SubjectAttributes,
158+
subject_attributes: AttributesDict,
153159
expected_variation_type: VariationType,
154160
) -> FlagEvaluation:
155161
"""Maps a subject to a variation for a given flag
@@ -225,8 +231,8 @@ def get_bandit_action(
225231
self,
226232
flag_key: str,
227233
subject_key: str,
228-
subject_context: Attributes,
229-
actions: ActionContexts,
234+
subject_context: Union[Attributes, AttributesDict],
235+
actions: Union[ActionContexts, ActionContextsDict],
230236
default: str,
231237
) -> BanditResult:
232238
"""
@@ -244,9 +250,11 @@ def get_bandit_action(
244250
Args:
245251
flag_key (str): The feature flag key that contains the bandit as one of the variations.
246252
subject_key (str): The key identifying the subject.
247-
subject_context (Attributes): The subject context
248-
actions (Dict[str, Attributes]): The dictionary that maps action keys
253+
subject_context (Attributes | AttributesDict): The subject context.
254+
If supplying an AttributesDict, it gets converted to an Attributes instance
255+
actions (ActionContexts | ActionContextsDict): The dictionary that maps action keys
249256
to their context of actions with their contexts.
257+
If supplying an AttributesDict, it gets converted to an Attributes instance.
250258
default (str): The default variation to use if the subject is not part of the bandit.
251259
252260
Returns:
@@ -264,7 +272,8 @@ def get_bandit_action(
264272
categorical_attributes={"country": "USA"}),
265273
{
266274
"action1": Attributes(numeric_attributes={"price": 10.0}, categorical_attributes={"category": "A"}),
267-
"action2": Attributes.empty()
275+
"action2": {"price": 10.0, "category": "B"}
276+
"action3": Attributes.empty(),
268277
},
269278
"default"
270279
)
@@ -273,7 +282,6 @@ def get_bandit_action(
273282
else:
274283
do_action(result.action)
275284
"""
276-
277285
try:
278286
return self.get_bandit_action_detail(
279287
flag_key,
@@ -292,14 +300,21 @@ def get_bandit_action_detail(
292300
self,
293301
flag_key: str,
294302
subject_key: str,
295-
subject_context: Attributes,
296-
actions: ActionContexts,
303+
subject_context: Union[Attributes, AttributesDict],
304+
actions: Union[ActionContexts, ActionContextsDict],
297305
default: str,
298306
) -> BanditResult:
307+
subject_attributes = convert_subject_context_to_attributes(subject_context)
308+
action_contexts = convert_actions_to_action_contexts(actions)
309+
299310
# get experiment assignment
300311
# ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand
301312
variation = self.get_string_assignment(
302-
flag_key, subject_key, subject_context.categorical_attributes, default # type: ignore
313+
flag_key,
314+
subject_key,
315+
subject_attributes.categorical_attributes
316+
| subject_attributes.numeric_attributes, # type: ignore
317+
default,
303318
)
304319

305320
# if the variation is not the bandit key, then the subject is not allocated in the bandit
@@ -318,8 +333,8 @@ def get_bandit_action_detail(
318333
evaluation = self.__bandit_evaluator.evaluate_bandit(
319334
flag_key,
320335
subject_key,
321-
subject_context,
322-
actions,
336+
subject_attributes,
337+
action_contexts,
323338
bandit_data.model_data,
324339
)
325340

@@ -334,12 +349,12 @@ def get_bandit_action_detail(
334349
"modelVersion": bandit_data.model_version if evaluation else None,
335350
"timestamp": datetime.datetime.utcnow().isoformat(),
336351
"subjectNumericAttributes": (
337-
subject_context.numeric_attributes
352+
subject_attributes.numeric_attributes
338353
if evaluation.subject_attributes
339354
else {}
340355
),
341356
"subjectCategoricalAttributes": (
342-
subject_context.categorical_attributes
357+
subject_attributes.categorical_attributes
343358
if evaluation.subject_attributes
344359
else {}
345360
),
@@ -410,3 +425,20 @@ def check_value_type_match(
410425
if expected_type == VariationType.BOOLEAN:
411426
return isinstance(value, bool)
412427
return False
428+
429+
430+
def convert_subject_context_to_attributes(
431+
subject_context: Union[Attributes, AttributesDict]
432+
) -> Attributes:
433+
if isinstance(subject_context, dict):
434+
return Attributes.from_dict(subject_context)
435+
return subject_context
436+
437+
438+
def convert_actions_to_action_contexts(
439+
actions: Union[ActionContexts, ActionContextsDict]
440+
) -> ActionContexts:
441+
return {
442+
k: Attributes.from_dict(v) if isinstance(v, dict) else v
443+
for k, v in actions.items()
444+
}

eppo_client/eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from dataclasses import dataclass
66
import datetime
77

8-
from eppo_client.types import SubjectAttributes
8+
from eppo_client.types import AttributesDict
99

1010

1111
@dataclass
1212
class FlagEvaluation:
1313
flag_key: str
1414
variation_type: VariationType
1515
subject_key: str
16-
subject_attributes: SubjectAttributes
16+
subject_attributes: AttributesDict
1717
allocation_key: Optional[str]
1818
variation: Optional[Variation]
1919
extra_logging: Dict[str, str]
@@ -28,7 +28,7 @@ def evaluate_flag(
2828
self,
2929
flag: Flag,
3030
subject_key: str,
31-
subject_attributes: SubjectAttributes,
31+
subject_attributes: AttributesDict,
3232
) -> FlagEvaluation:
3333
if not flag.enabled:
3434
return none_result(
@@ -93,7 +93,7 @@ def none_result(
9393
flag_key: str,
9494
variation_type: VariationType,
9595
subject_key: str,
96-
subject_attributes: SubjectAttributes,
96+
subject_attributes: AttributesDict,
9797
) -> FlagEvaluation:
9898
return FlagEvaluation(
9999
flag_key=flag_key,

eppo_client/rules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import semver
88

99
from eppo_client.models import SdkBaseModel
10-
from eppo_client.types import AttributeType, ConditionValueType, SubjectAttributes
10+
from eppo_client.types import AttributeType, ConditionValueType, AttributesDict
1111

1212

1313
class OperatorType(Enum):
@@ -32,15 +32,15 @@ class Rule(SdkBaseModel):
3232
conditions: List[Condition]
3333

3434

35-
def matches_rule(rule: Rule, subject_attributes: SubjectAttributes) -> bool:
35+
def matches_rule(rule: Rule, subject_attributes: AttributesDict) -> bool:
3636
return all(
3737
evaluate_condition(condition, subject_attributes)
3838
for condition in rule.conditions
3939
)
4040

4141

4242
def evaluate_condition(
43-
condition: Condition, subject_attributes: SubjectAttributes
43+
condition: Condition, subject_attributes: AttributesDict
4444
) -> bool:
4545
subject_value = subject_attributes.get(condition.attribute, None)
4646
if condition.operator == OperatorType.IS_NULL:

eppo_client/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
ValueType = Union[str, int, float, bool]
44
AttributeType = Union[str, int, float, bool, None]
55
ConditionValueType = Union[AttributeType, List[AttributeType]]
6-
SubjectAttributes = Dict[str, AttributeType]
6+
AttributesDict = Dict[str, AttributeType]
77
Action = str

test/client_bandit_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def init_fixture():
7171
base_url=MOCK_BASE_URL,
7272
api_key="dummy",
7373
assignment_logger=mock_assignment_logger,
74+
is_graceful_mode=False,
7475
)
7576
)
7677
sleep(0.1) # wait for initialization
@@ -91,7 +92,7 @@ def test_get_bandit_action_bandit_does_not_exist():
9192
"nonexistent_bandit",
9293
"subject_key",
9394
DEFAULT_SUBJECT_ATTRIBUTES,
94-
[],
95+
{},
9596
"default_variation",
9697
)
9798
assert result == BanditResult("default_variation", None)
@@ -100,7 +101,7 @@ def test_get_bandit_action_bandit_does_not_exist():
100101
def test_get_bandit_action_flag_without_bandit():
101102
client = get_instance()
102103
result = client.get_bandit_action(
103-
"a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, [], "default_variation"
104+
"a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, {}, "default_variation"
104105
)
105106
assert result == BanditResult("default_variation", None)
106107

0 commit comments

Comments
 (0)