Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.

Commit 8b79f1f

Browse files
authored
[UFC] Add bandit to SDK (#38)
1 parent 5397cff commit 8b79f1f

12 files changed

+1012
-21
lines changed

eppo_client/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from eppo_client.configuration_store import ConfigurationStore
88
from eppo_client.http_client import HttpClient, SdkParams
9-
from eppo_client.models import Flag
9+
from eppo_client.models import BanditData, Flag
1010
from eppo_client.read_write_lock import ReadWriteLock
1111
from eppo_client.version import __version__
1212

@@ -30,9 +30,12 @@ def init(config: Config) -> EppoClient:
3030
apiKey=config.api_key, sdkName="python", sdkVersion=__version__
3131
)
3232
http_client = HttpClient(base_url=config.base_url, sdk_params=sdk_params)
33-
config_store: ConfigurationStore[Flag] = ConfigurationStore()
33+
flag_config_store: ConfigurationStore[Flag] = ConfigurationStore()
34+
bandit_config_store: ConfigurationStore[BanditData] = ConfigurationStore()
3435
config_requestor = ExperimentConfigurationRequestor(
35-
http_client=http_client, config_store=config_store
36+
http_client=http_client,
37+
flag_config_store=flag_config_store,
38+
bandit_config_store=bandit_config_store,
3639
)
3740
assignment_logger = config.assignment_logger
3841
is_graceful_mode = config.is_graceful_mode

eppo_client/assignment_logger.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class AssignmentLogger(BaseModel):
88

99
def log_assignment(self, assignment_event: Dict):
1010
pass
11+
12+
def log_bandit_action(self, bandit_event: Dict):
13+
pass

eppo_client/bandit.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
from dataclasses import dataclass
2+
import logging
3+
from typing import Dict, List, Optional, Tuple
4+
5+
from eppo_client.models import (
6+
BanditCategoricalAttributeCoefficient,
7+
BanditCoefficients,
8+
BanditModelData,
9+
BanditNumericAttributeCoefficient,
10+
)
11+
from eppo_client.sharders import Sharder
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class BanditEvaluationError(Exception):
18+
pass
19+
20+
21+
@dataclass
22+
class Attributes:
23+
numeric_attributes: Dict[str, float]
24+
categorical_attributes: Dict[str, str]
25+
26+
27+
@dataclass
28+
class ActionContext:
29+
action_key: str
30+
attributes: Attributes
31+
32+
@classmethod
33+
def create(
34+
cls,
35+
action_key: str,
36+
numeric_attributes: Dict[str, float],
37+
categorical_attributes: Dict[str, str],
38+
):
39+
"""
40+
Create an instance of ActionContext.
41+
42+
Args:
43+
action_key (str): The key representing the action.
44+
numeric_attributes (Dict[str, float]): A dictionary of numeric attributes.
45+
categorical_attributes (Dict[str, str]): A dictionary of categorical attributes.
46+
47+
Returns:
48+
ActionContext: An instance of ActionContext with the provided action key and attributes.
49+
"""
50+
return cls(
51+
action_key,
52+
Attributes(
53+
numeric_attributes=numeric_attributes,
54+
categorical_attributes=categorical_attributes,
55+
),
56+
)
57+
58+
@property
59+
def numeric_attributes(self):
60+
return self.attributes.numeric_attributes
61+
62+
@property
63+
def categorical_attributes(self):
64+
return self.attributes.categorical_attributes
65+
66+
67+
@dataclass
68+
class BanditEvaluation:
69+
flag_key: str
70+
subject_key: str
71+
subject_attributes: Attributes
72+
action_key: Optional[str]
73+
action_attributes: Optional[Attributes]
74+
action_score: float
75+
action_weight: float
76+
gamma: float
77+
78+
79+
@dataclass
80+
class BanditResult:
81+
variation: str
82+
action: Optional[str]
83+
84+
def to_string(self) -> str:
85+
return coalesce(self.action, self.variation)
86+
87+
88+
def null_evaluation(
89+
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
90+
):
91+
return BanditEvaluation(
92+
flag_key,
93+
subject_key,
94+
subject_attributes,
95+
None,
96+
None,
97+
0.0,
98+
0.0,
99+
gamma,
100+
)
101+
102+
103+
@dataclass
104+
class BanditEvaluator:
105+
sharder: Sharder
106+
total_shards: int = 10_000
107+
108+
def evaluate_bandit(
109+
self,
110+
flag_key: str,
111+
subject_key: str,
112+
subject_attributes: Attributes,
113+
actions_with_contexts: List[ActionContext],
114+
bandit_model: BanditModelData,
115+
) -> BanditEvaluation:
116+
# handle the edge case that there are no actions
117+
if not actions_with_contexts:
118+
return null_evaluation(
119+
flag_key, subject_key, subject_attributes, bandit_model.gamma
120+
)
121+
122+
action_scores = self.score_actions(
123+
subject_attributes, actions_with_contexts, bandit_model
124+
)
125+
126+
action_weights = self.weigh_actions(
127+
action_scores,
128+
bandit_model.gamma,
129+
bandit_model.action_probability_floor,
130+
)
131+
132+
selected_idx, selected_action = self.select_action(
133+
flag_key, subject_key, action_weights
134+
)
135+
return BanditEvaluation(
136+
flag_key,
137+
subject_key,
138+
subject_attributes,
139+
selected_action,
140+
actions_with_contexts[selected_idx].attributes,
141+
action_scores[selected_idx][1],
142+
action_weights[selected_idx][1],
143+
bandit_model.gamma,
144+
)
145+
146+
def score_actions(
147+
self,
148+
subject_attributes: Attributes,
149+
actions_with_contexts: List[ActionContext],
150+
bandit_model: BanditModelData,
151+
) -> List[Tuple[str, float]]:
152+
return [
153+
(
154+
action_context.action_key,
155+
(
156+
score_action(
157+
subject_attributes,
158+
action_context.attributes,
159+
bandit_model.coefficients[action_context.action_key],
160+
)
161+
if action_context.action_key in bandit_model.coefficients
162+
else bandit_model.default_action_score
163+
),
164+
)
165+
for action_context in actions_with_contexts
166+
]
167+
168+
def weigh_actions(
169+
self, action_scores, gamma, probability_floor
170+
) -> List[Tuple[str, float]]:
171+
number_of_actions = len(action_scores)
172+
best_action, best_score = max(action_scores, key=lambda t: t[1])
173+
174+
# adjust probability floor for number of actions to control the sum
175+
min_probability = probability_floor / number_of_actions
176+
177+
# weight all but the best action
178+
weights = [
179+
(
180+
action_key,
181+
max(
182+
min_probability,
183+
1.0 / (number_of_actions + gamma * (best_score - score)),
184+
),
185+
)
186+
for action_key, score in action_scores
187+
if action_key != best_action
188+
]
189+
190+
# remaining weight goes to best action
191+
remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights))
192+
weights.append((best_action, remaining_weight))
193+
return weights
194+
195+
def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]:
196+
# deterministic ordering
197+
sorted_action_weights = sorted(
198+
action_weights,
199+
key=lambda t: (
200+
self.sharder.get_shard(
201+
f"{flag_key}-{subject_key}-{t[0]}", self.total_shards
202+
),
203+
t[0], # tie-break using action name
204+
),
205+
)
206+
207+
# select action based on weights
208+
shard = self.sharder.get_shard(f"{flag_key}-{subject_key}", self.total_shards)
209+
cumulative_weight = 0.0
210+
shard_value = shard / self.total_shards
211+
212+
for idx, (action_key, weight) in enumerate(sorted_action_weights):
213+
cumulative_weight += weight
214+
if cumulative_weight > shard_value:
215+
return idx, action_key
216+
217+
# If no action is selected, return the last action (fallback)
218+
raise BanditEvaluationError(
219+
f"[Eppo SDK] No action selected for {flag_key} {subject_key}"
220+
)
221+
222+
223+
def score_action(
224+
subject_attributes: Attributes,
225+
action_attributes: Attributes,
226+
coefficients: BanditCoefficients,
227+
) -> float:
228+
score = coefficients.intercept
229+
score += score_numeric_attributes(
230+
coefficients.subject_numeric_coefficients,
231+
subject_attributes.numeric_attributes,
232+
)
233+
score += score_categorical_attributes(
234+
coefficients.subject_categorical_coefficients,
235+
subject_attributes.categorical_attributes,
236+
)
237+
score += score_numeric_attributes(
238+
coefficients.action_numeric_coefficients,
239+
action_attributes.numeric_attributes,
240+
)
241+
score += score_categorical_attributes(
242+
coefficients.action_categorical_coefficients,
243+
action_attributes.categorical_attributes,
244+
)
245+
return score
246+
247+
248+
def coalesce(value, default=0):
249+
return value if value is not None else default
250+
251+
252+
def score_numeric_attributes(
253+
coefficients: List[BanditNumericAttributeCoefficient],
254+
attributes: Dict[str, float],
255+
) -> float:
256+
score = 0.0
257+
for coefficient in coefficients:
258+
if (
259+
coefficient.attribute_key in attributes
260+
and attributes[coefficient.attribute_key] is not None
261+
):
262+
score += coefficient.coefficient * attributes[coefficient.attribute_key]
263+
else:
264+
score += coefficient.missing_value_coefficient
265+
266+
return score
267+
268+
269+
def score_categorical_attributes(
270+
coefficients: List[BanditCategoricalAttributeCoefficient],
271+
attributes: Dict[str, str],
272+
) -> float:
273+
score = 0.0
274+
for coefficient in coefficients:
275+
if coefficient.attribute_key in attributes:
276+
score += coefficient.value_coefficients.get(
277+
attributes[coefficient.attribute_key],
278+
coefficient.missing_value_coefficient,
279+
)
280+
else:
281+
score += coefficient.missing_value_coefficient
282+
return score

0 commit comments

Comments
 (0)