1
1
from dataclasses import dataclass
2
2
import logging
3
- from typing import Dict , List , Optional , Tuple
3
+ from typing import Dict , List , Optional
4
4
5
5
from eppo_client .models import (
6
6
BanditCategoricalAttributeCoefficient ,
@@ -23,45 +23,12 @@ class Attributes:
23
23
numeric_attributes : Dict [str , float ]
24
24
categorical_attributes : Dict [str , str ]
25
25
26
-
27
- @dataclass
28
- class ActionContext :
29
- action_key : str
30
- attributes : Attributes
31
-
32
26
@classmethod
def empty(cls):
    """Alternate constructor: an instance with no numeric and no categorical attributes."""
    return cls(numeric_attributes={}, categorical_attributes={})
57
29
58
- @property
59
- def numeric_attributes (self ):
60
- return self .attributes .numeric_attributes
61
30
62
- @property
63
- def categorical_attributes (self ):
64
- return self .attributes .categorical_attributes
31
# Maps each action key to that action's Attributes (numeric + categorical);
# the action key doubles as the lookup key into the bandit model coefficients.
ActionContexts = Dict[str, Attributes]
65
32
66
33
67
34
@dataclass
def evaluate_bandit(
    self,
    flag_key: str,
    subject_key: str,
    subject_attributes: Attributes,
    actions: ActionContexts,
    bandit_model: BanditModelData,
) -> BanditEvaluation:
    """Score, weigh, and select one action for a subject under a bandit model.

    Args:
        flag_key: Key of the flag this bandit evaluation belongs to.
        subject_key: Key identifying the subject being evaluated.
        subject_attributes: Numeric/categorical context for the subject.
        actions: Mapping of action key -> that action's attributes.
        bandit_model: Coefficients plus gamma and the action probability floor.

    Returns:
        BanditEvaluation carrying the selected action, its attributes, score,
        weight, gamma, and the optimality gap versus the best-scoring action.
    """
    # handle the edge case that there are no actions
    if not actions:
        return null_evaluation(
            flag_key, subject_key, subject_attributes, bandit_model.gamma
        )

    action_scores = self.score_actions(subject_attributes, actions, bandit_model)

    action_weights = self.weigh_actions(
        action_scores,
        bandit_model.gamma,
        bandit_model.action_probability_floor,
    )

    # Selection may pick a non-best action (exploration), so record how far
    # its score falls short of the best achievable score (0.0 when best).
    selected_action = self.select_action(flag_key, subject_key, action_weights)
    optimality_gap = max(action_scores.values()) - action_scores[selected_action]

    return BanditEvaluation(
        flag_key,
        subject_key,
        subject_attributes,
        selected_action,
        actions[selected_action],
        action_scores[selected_action],
        action_weights[selected_action],
        bandit_model.gamma,
        optimality_gap,
    )
148
104
149
105
def score_actions (
150
106
self ,
151
107
subject_attributes : Attributes ,
152
- actions_with_contexts : List [ ActionContext ] ,
108
+ actions : ActionContexts ,
153
109
bandit_model : BanditModelData ,
154
- ) -> List [Tuple [str , float ]]:
155
- return [
156
- (
157
- action_context .action_key ,
158
- (
159
- score_action (
160
- subject_attributes ,
161
- action_context .attributes ,
162
- bandit_model .coefficients [action_context .action_key ],
163
- )
164
- if action_context .action_key in bandit_model .coefficients
165
- else bandit_model .default_action_score
166
- ),
110
+ ) -> Dict [str , float ]:
111
+ return {
112
+ action_key : (
113
+ score_action (
114
+ subject_attributes ,
115
+ action_attributes ,
116
+ bandit_model .coefficients [action_key ],
117
+ )
118
+ if action_key in bandit_model .coefficients
119
+ else bandit_model .default_action_score
167
120
)
168
- for action_context in actions_with_contexts
169
- ]
121
+ for action_key , action_attributes in actions . items ()
122
+ }
170
123
171
124
def weigh_actions(
    self, action_scores, gamma, probability_floor
) -> Dict[str, float]:
    """Turn per-action scores into selection probabilities.

    Every non-best action gets weight 1 / (N + gamma * score_gap), floored at
    probability_floor / N; the best-scoring action absorbs the remaining mass.
    """
    number_of_actions = len(action_scores)
    best_action = max(action_scores, key=lambda k: action_scores[k])
    best_score = action_scores[best_action]

    # Spread the floor across actions so the floored weights cannot sum past
    # probability_floor on their own.
    floor_per_action = probability_floor / number_of_actions

    weights: Dict[str, float] = {}
    for action_key, score in action_scores.items():
        if action_key == best_action:
            continue
        raw = 1.0 / (number_of_actions + gamma * (best_score - score))
        weights[action_key] = max(floor_per_action, raw)

    # Best action takes whatever probability mass is left (never negative).
    weights[best_action] = max(0.0, 1.0 - sum(weights.values()))
    return weights
197
148
198
149
def select_action (self , flag_key , subject_key , action_weights ) -> str :
199
150
# deterministic ordering
200
151
sorted_action_weights = sorted (
201
- action_weights ,
152
+ action_weights . items () ,
202
153
key = lambda t : (
203
154
self .sharder .get_shard (
204
155
f"{ flag_key } -{ subject_key } -{ t [0 ]} " , self .total_shards
0 commit comments