Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.

Commit c6ce44d

Browse files
authored
[ff-2352] Add optimality gap to bandit logging (#50)
1 parent 061fb5a commit c6ce44d

File tree

5 files changed

+71
-20
lines changed

5 files changed

+71
-20
lines changed

eppo_client/bandit.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class BanditEvaluation:
7474
action_score: float
7575
action_weight: float
7676
gamma: float
77+
optimality_gap: float
7778

7879

7980
@dataclass
@@ -89,14 +90,7 @@ def null_evaluation(
8990
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
9091
):
9192
return BanditEvaluation(
92-
flag_key,
93-
subject_key,
94-
subject_attributes,
95-
None,
96-
None,
97-
0.0,
98-
0.0,
99-
gamma,
93+
flag_key, subject_key, subject_attributes, None, None, 0.0, 0.0, gamma, 0.0
10094
)
10195

10296

@@ -129,9 +123,17 @@ def evaluate_bandit(
129123
bandit_model.action_probability_floor,
130124
)
131125

132-
selected_idx, selected_action = self.select_action(
133-
flag_key, subject_key, action_weights
126+
selected_action = self.select_action(flag_key, subject_key, action_weights)
127+
selected_idx = next(
128+
idx
129+
for idx, action_context in enumerate(actions_with_contexts)
130+
if action_context.action_key == selected_action
134131
)
132+
133+
optimality_gap = (
134+
max(score for _, score in action_scores) - action_scores[selected_idx][1]
135+
)
136+
135137
return BanditEvaluation(
136138
flag_key,
137139
subject_key,
@@ -141,6 +143,7 @@ def evaluate_bandit(
141143
action_scores[selected_idx][1],
142144
action_weights[selected_idx][1],
143145
bandit_model.gamma,
146+
optimality_gap,
144147
)
145148

146149
def score_actions(
@@ -192,7 +195,7 @@ def weigh_actions(
192195
weights.append((best_action, remaining_weight))
193196
return weights
194197

195-
def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]:
198+
def select_action(self, flag_key, subject_key, action_weights) -> str:
196199
# deterministic ordering
197200
sorted_action_weights = sorted(
198201
action_weights,
@@ -209,10 +212,10 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str
209212
cumulative_weight = 0.0
210213
shard_value = shard / self.total_shards
211214

212-
for idx, (action_key, weight) in enumerate(sorted_action_weights):
215+
for action_key, weight in sorted_action_weights:
213216
cumulative_weight += weight
214217
if cumulative_weight > shard_value:
215-
return idx, action_key
218+
return action_key
216219

217220
# If no action is selected, return the last action (fallback)
218221
raise BanditEvaluationError(

eppo_client/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def get_bandit_action_detail(
309309
"subject": subject_key,
310310
"action": evaluation.action_key if evaluation else None,
311311
"actionProbability": evaluation.action_weight if evaluation else None,
312+
"optimalityGap": evaluation.optimality_gap if evaluation else None,
312313
"modelVersion": bandit_data.model_version if evaluation else None,
313314
"timestamp": datetime.datetime.utcnow().isoformat(),
314315
"subjectNumericAttributes": (

eppo_client/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.1.3"
1+
__version__ = "3.1.4"

test/client_bandit_test.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import os
77
from time import sleep
8-
from typing import Dict
8+
from typing import Dict, List
99
from eppo_client.bandit import BanditResult, ActionContext, Attributes
1010

1111
import httpretty # type: ignore
@@ -34,11 +34,17 @@
3434

3535

3636
class MockAssignmentLogger(AssignmentLogger):
37+
assignment_events: List[Dict] = []
38+
bandit_events: List[Dict] = []
39+
3740
def log_assignment(self, assignment_event: Dict):
38-
print(f"Assignment Event: {assignment_event}")
41+
self.assignment_events.append(assignment_event)
3942

4043
def log_bandit_action(self, bandit_event: Dict):
41-
print(f"Bandit Event: {bandit_event}")
44+
self.bandit_events.append(bandit_event)
45+
46+
47+
mock_assignment_logger = MockAssignmentLogger()
4248

4349

4450
@pytest.fixture(scope="session", autouse=True)
@@ -64,7 +70,7 @@ def init_fixture():
6470
Config(
6571
base_url=MOCK_BASE_URL,
6672
api_key="dummy",
67-
assignment_logger=AssignmentLogger(),
73+
assignment_logger=mock_assignment_logger,
6874
)
6975
)
7076
sleep(0.1) # wait for initialization
@@ -102,16 +108,55 @@ def test_get_bandit_action_flag_without_bandit():
102108
def test_get_bandit_action_with_subject_attributes():
103109
# tests that allocation filtering based on subject attributes works correctly
104110
client = get_instance()
111+
actions = [
112+
ActionContext.create("adidas", {"discount": 0.1}, {"from": "germany"}),
113+
ActionContext.create("nike", {"discount": 0.2}, {"from": "usa"}),
114+
]
105115
result = client.get_bandit_action(
106116
"banner_bandit_flag_uk_only",
107-
"subject_key",
117+
"alice",
108118
DEFAULT_SUBJECT_ATTRIBUTES,
109-
[ActionContext.create("adidas", {}, {}), ActionContext.create("nike", {}, {})],
119+
actions,
110120
"default_variation",
111121
)
112122
assert result.variation == "banner_bandit"
113123
assert result.action in ["adidas", "nike"]
114124

125+
# testing assignment logger
126+
assignment_log_statement = mock_assignment_logger.assignment_events[-1]
127+
assert assignment_log_statement["featureFlag"] == "banner_bandit_flag_uk_only"
128+
assert assignment_log_statement["variation"] == "banner_bandit"
129+
assert assignment_log_statement["subject"] == "alice"
130+
131+
# testing bandit logger
132+
bandit_log_statement = mock_assignment_logger.bandit_events[-1]
133+
assert bandit_log_statement["flagKey"] == "banner_bandit_flag_uk_only"
134+
assert bandit_log_statement["banditKey"] == "banner_bandit"
135+
assert bandit_log_statement["subject"] == "alice"
136+
assert (
137+
bandit_log_statement["subjectNumericAttributes"]
138+
== DEFAULT_SUBJECT_ATTRIBUTES.numeric_attributes
139+
)
140+
assert (
141+
bandit_log_statement["subjectCategoricalAttributes"]
142+
== DEFAULT_SUBJECT_ATTRIBUTES.categorical_attributes
143+
)
144+
assert bandit_log_statement["action"] == result.action
145+
assert bandit_log_statement["optimalityGap"] >= 0
146+
assert bandit_log_statement["actionProbability"] >= 0
147+
148+
chosen_action = next(
149+
action for action in actions if action.action_key == result.action
150+
)
151+
assert (
152+
bandit_log_statement["actionNumericAttributes"]
153+
== chosen_action.attributes.numeric_attributes
154+
)
155+
assert (
156+
bandit_log_statement["actionCategoricalAttributes"]
157+
== chosen_action.attributes.categorical_attributes
158+
)
159+
115160

116161
@pytest.mark.parametrize("test_case", test_data)
117162
def test_bandit_generic_test_cases(test_case):

test/eval_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def test_flag_target_on_id():
167167
assert result.variation == Variation(key="control", value="control")
168168
result = evaluator.evaluate_flag(flag, "user-3", {})
169169
assert result.variation is None
170+
result = evaluator.evaluate_flag(flag, "user-1", {"id": "do-not-overwrite-me"})
171+
assert result.variation is None
170172

171173

172174
def test_catch_all_allocation():

0 commit comments

Comments
 (0)