Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.

Commit d7a4a69

Browse files
authored
[ff-2388] get_bandit_actions: supply actions as a dict instead of list (#53)
1 parent 6847a70 commit d7a4a69

File tree

5 files changed

+173
-143
lines changed

5 files changed

+173
-143
lines changed

eppo_client/bandit.py

Lines changed: 36 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
import logging
3-
from typing import Dict, List, Optional, Tuple
3+
from typing import Dict, List, Optional
44

55
from eppo_client.models import (
66
BanditCategoricalAttributeCoefficient,
@@ -23,45 +23,12 @@ class Attributes:
2323
numeric_attributes: Dict[str, float]
2424
categorical_attributes: Dict[str, str]
2525

26-
27-
@dataclass
28-
class ActionContext:
29-
action_key: str
30-
attributes: Attributes
31-
3226
@classmethod
33-
def create(
34-
cls,
35-
action_key: str,
36-
numeric_attributes: Dict[str, float],
37-
categorical_attributes: Dict[str, str],
38-
):
39-
"""
40-
Create an instance of ActionContext.
41-
42-
Args:
43-
action_key (str): The key representing the action.
44-
numeric_attributes (Dict[str, float]): A dictionary of numeric attributes.
45-
categorical_attributes (Dict[str, str]): A dictionary of categorical attributes.
46-
47-
Returns:
48-
ActionContext: An instance of ActionContext with the provided action key and attributes.
49-
"""
50-
return cls(
51-
action_key,
52-
Attributes(
53-
numeric_attributes=numeric_attributes,
54-
categorical_attributes=categorical_attributes,
55-
),
56-
)
27+
def empty(cls):
28+
return cls({}, {})
5729

58-
@property
59-
def numeric_attributes(self):
60-
return self.attributes.numeric_attributes
6130

62-
@property
63-
def categorical_attributes(self):
64-
return self.attributes.categorical_attributes
31+
ActionContexts = Dict[str, Attributes]
6532

6633

6734
@dataclass
@@ -104,101 +71,85 @@ def evaluate_bandit(
10471
flag_key: str,
10572
subject_key: str,
10673
subject_attributes: Attributes,
107-
actions_with_contexts: List[ActionContext],
74+
actions: ActionContexts,
10875
bandit_model: BanditModelData,
10976
) -> BanditEvaluation:
11077
# handle the edge case that there are no actions
111-
if not actions_with_contexts:
78+
if not actions:
11279
return null_evaluation(
11380
flag_key, subject_key, subject_attributes, bandit_model.gamma
11481
)
11582

116-
action_scores = self.score_actions(
117-
subject_attributes, actions_with_contexts, bandit_model
118-
)
119-
83+
action_scores = self.score_actions(subject_attributes, actions, bandit_model)
12084
action_weights = self.weigh_actions(
12185
action_scores,
12286
bandit_model.gamma,
12387
bandit_model.action_probability_floor,
12488
)
12589

12690
selected_action = self.select_action(flag_key, subject_key, action_weights)
127-
selected_idx = next(
128-
idx
129-
for idx, action_context in enumerate(actions_with_contexts)
130-
if action_context.action_key == selected_action
131-
)
132-
133-
optimality_gap = (
134-
max(score for _, score in action_scores) - action_scores[selected_idx][1]
135-
)
91+
optimality_gap = max(action_scores.values()) - action_scores[selected_action]
13692

13793
return BanditEvaluation(
13894
flag_key,
13995
subject_key,
14096
subject_attributes,
14197
selected_action,
142-
actions_with_contexts[selected_idx].attributes,
143-
action_scores[selected_idx][1],
144-
action_weights[selected_idx][1],
98+
actions[selected_action],
99+
action_scores[selected_action],
100+
action_weights[selected_action],
145101
bandit_model.gamma,
146102
optimality_gap,
147103
)
148104

149105
def score_actions(
150106
self,
151107
subject_attributes: Attributes,
152-
actions_with_contexts: List[ActionContext],
108+
actions: ActionContexts,
153109
bandit_model: BanditModelData,
154-
) -> List[Tuple[str, float]]:
155-
return [
156-
(
157-
action_context.action_key,
158-
(
159-
score_action(
160-
subject_attributes,
161-
action_context.attributes,
162-
bandit_model.coefficients[action_context.action_key],
163-
)
164-
if action_context.action_key in bandit_model.coefficients
165-
else bandit_model.default_action_score
166-
),
110+
) -> Dict[str, float]:
111+
return {
112+
action_key: (
113+
score_action(
114+
subject_attributes,
115+
action_attributes,
116+
bandit_model.coefficients[action_key],
117+
)
118+
if action_key in bandit_model.coefficients
119+
else bandit_model.default_action_score
167120
)
168-
for action_context in actions_with_contexts
169-
]
121+
for action_key, action_attributes in actions.items()
122+
}
170123

171124
def weigh_actions(
172125
self, action_scores, gamma, probability_floor
173-
) -> List[Tuple[str, float]]:
126+
) -> Dict[str, float]:
174127
number_of_actions = len(action_scores)
175-
best_action, best_score = max(action_scores, key=lambda t: t[1])
128+
best_action = max(action_scores, key=action_scores.get)
129+
best_score = action_scores[best_action]
176130

177131
# adjust probability floor for number of actions to control the sum
178132
min_probability = probability_floor / number_of_actions
179133

180134
# weight all but the best action
181-
weights = [
182-
(
183-
action_key,
184-
max(
185-
min_probability,
186-
1.0 / (number_of_actions + gamma * (best_score - score)),
187-
),
135+
weights = {
136+
action_key: max(
137+
min_probability,
138+
1.0 / (number_of_actions + gamma * (best_score - score)),
188139
)
189-
for action_key, score in action_scores
140+
for action_key, score in action_scores.items()
190141
if action_key != best_action
191-
]
142+
}
192143

193144
# remaining weight goes to best action
194-
remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights))
195-
weights.append((best_action, remaining_weight))
145+
remaining_weight = max(0.0, 1.0 - sum(weights.values()))
146+
weights[best_action] = remaining_weight
196147
return weights
197148

198149
def select_action(self, flag_key, subject_key, action_weights) -> str:
199150
# deterministic ordering
200151
sorted_action_weights = sorted(
201-
action_weights,
152+
action_weights.items(),
202153
key=lambda t: (
203154
self.sharder.get_shard(
204155
f"{flag_key}-{subject_key}-{t[0]}", self.total_shards

eppo_client/client.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import datetime
22
import logging
33
import json
4-
from typing import Any, Dict, List, Optional
4+
from typing import Any, Dict, Optional
55
from eppo_client.assignment_logger import AssignmentLogger
6-
from eppo_client.bandit import BanditEvaluator, BanditResult, ActionContext, Attributes
6+
from eppo_client.bandit import BanditEvaluator, BanditResult, Attributes, ActionContexts
77
from eppo_client.configuration_requestor import (
88
ExperimentConfigurationRequestor,
99
)
@@ -226,7 +226,7 @@ def get_bandit_action(
226226
flag_key: str,
227227
subject_key: str,
228228
subject_context: Attributes,
229-
actions_with_contexts: List[ActionContext],
229+
actions: ActionContexts,
230230
default: str,
231231
) -> BanditResult:
232232
"""
@@ -245,20 +245,41 @@ def get_bandit_action(
245245
flag_key (str): The feature flag key that contains the bandit as one of the variations.
246246
subject_key (str): The key identifying the subject.
247247
subject_context (Attributes): The subject context
248-
actions_with_contexts (List[ActionContext]): The list of actions with their contexts.
248+
actions (Dict[str, Attributes]): The dictionary that maps action keys
249+
to their contexts (the attributes describing each action).
250+
default (str): The default variation to use if the subject is not part of the bandit.
249251
250252
Returns:
251253
BanditResult: The result containing either the bandit action if the subject is part of the bandit,
252254
or the assignment if they are not. The BanditResult includes:
253255
- variation (str): The assignment key indicating the subject's variation.
254256
- action (str): The key of the selected action if the subject is part of the bandit.
257+
258+
Example:
259+
result = client.get_bandit_action(
260+
"flag_key",
261+
"subject_key",
262+
Attributes(
263+
numeric_attributes={"age": 25},
264+
categorical_attributes={"country": "USA"}),
265+
{
266+
"action1": Attributes(numeric_attributes={"price": 10.0}, categorical_attributes={"category": "A"}),
267+
"action2": Attributes.empty()
268+
},
269+
"default"
270+
)
271+
if result.action is None:
272+
do_variation(result.variation)
273+
else:
274+
do_action(result.action)
255275
"""
276+
256277
try:
257278
return self.get_bandit_action_detail(
258279
flag_key,
259280
subject_key,
260281
subject_context,
261-
actions_with_contexts,
282+
actions,
262283
default,
263284
)
264285
except Exception as e:
@@ -272,7 +293,7 @@ def get_bandit_action_detail(
272293
flag_key: str,
273294
subject_key: str,
274295
subject_context: Attributes,
275-
actions_with_contexts: List[ActionContext],
296+
actions: ActionContexts,
276297
default: str,
277298
) -> BanditResult:
278299
# get experiment assignment
@@ -298,7 +319,7 @@ def get_bandit_action_detail(
298319
flag_key,
299320
subject_key,
300321
subject_context,
301-
actions_with_contexts,
322+
actions,
302323
bandit_data.model_data,
303324
)
304325

eppo_client/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.1.4"
1+
__version__ = "3.2.0"

0 commit comments

Comments
 (0)