5
5
import json
6
6
import os
7
7
from time import sleep
8
- from typing import Dict
8
+ from typing import Dict , List
9
9
from eppo_client .bandit import BanditResult , ActionContext , Attributes
10
10
11
11
import httpretty # type: ignore
34
34
35
35
36
36
class MockAssignmentLogger (AssignmentLogger ):
37
+ assignment_events : List [Dict ] = []
38
+ bandit_events : List [Dict ] = []
39
+
37
40
def log_assignment (self , assignment_event : Dict ):
38
- print ( f"Assignment Event: { assignment_event } " )
41
+ self . assignment_events . append ( assignment_event )
39
42
40
43
def log_bandit_action (self , bandit_event : Dict ):
41
- print (f"Bandit Event: { bandit_event } " )
44
+ self .bandit_events .append (bandit_event )
45
+
46
+
47
+ mock_assignment_logger = MockAssignmentLogger ()
42
48
43
49
44
50
@pytest .fixture (scope = "session" , autouse = True )
@@ -64,7 +70,7 @@ def init_fixture():
64
70
Config (
65
71
base_url = MOCK_BASE_URL ,
66
72
api_key = "dummy" ,
67
- assignment_logger = AssignmentLogger () ,
73
+ assignment_logger = mock_assignment_logger ,
68
74
)
69
75
)
70
76
sleep (0.1 ) # wait for initialization
@@ -102,16 +108,55 @@ def test_get_bandit_action_flag_without_bandit():
102
108
def test_get_bandit_action_with_subject_attributes ():
103
109
# tests that allocation filtering based on subject attributes works correctly
104
110
client = get_instance ()
111
+ actions = [
112
+ ActionContext .create ("adidas" , {"discount" : 0.1 }, {"from" : "germany" }),
113
+ ActionContext .create ("nike" , {"discount" : 0.2 }, {"from" : "usa" }),
114
+ ]
105
115
result = client .get_bandit_action (
106
116
"banner_bandit_flag_uk_only" ,
107
- "subject_key " ,
117
+ "alice " ,
108
118
DEFAULT_SUBJECT_ATTRIBUTES ,
109
- [ ActionContext . create ( "adidas" , {}, {}), ActionContext . create ( "nike" , {}, {})] ,
119
+ actions ,
110
120
"default_variation" ,
111
121
)
112
122
assert result .variation == "banner_bandit"
113
123
assert result .action in ["adidas" , "nike" ]
114
124
125
+ # testing assignment logger
126
+ assignment_log_statement = mock_assignment_logger .assignment_events [- 1 ]
127
+ assert assignment_log_statement ["featureFlag" ] == "banner_bandit_flag_uk_only"
128
+ assert assignment_log_statement ["variation" ] == "banner_bandit"
129
+ assert assignment_log_statement ["subject" ] == "alice"
130
+
131
+ # testing bandit logger
132
+ bandit_log_statement = mock_assignment_logger .bandit_events [- 1 ]
133
+ assert bandit_log_statement ["flagKey" ] == "banner_bandit_flag_uk_only"
134
+ assert bandit_log_statement ["banditKey" ] == "banner_bandit"
135
+ assert bandit_log_statement ["subject" ] == "alice"
136
+ assert (
137
+ bandit_log_statement ["subjectNumericAttributes" ]
138
+ == DEFAULT_SUBJECT_ATTRIBUTES .numeric_attributes
139
+ )
140
+ assert (
141
+ bandit_log_statement ["subjectCategoricalAttributes" ]
142
+ == DEFAULT_SUBJECT_ATTRIBUTES .categorical_attributes
143
+ )
144
+ assert bandit_log_statement ["action" ] == result .action
145
+ assert bandit_log_statement ["optimalityGap" ] >= 0
146
+ assert bandit_log_statement ["actionProbability" ] >= 0
147
+
148
+ chosen_action = next (
149
+ action for action in actions if action .action_key == result .action
150
+ )
151
+ assert (
152
+ bandit_log_statement ["actionNumericAttributes" ]
153
+ == chosen_action .attributes .numeric_attributes
154
+ )
155
+ assert (
156
+ bandit_log_statement ["actionCategoricalAttributes" ]
157
+ == chosen_action .attributes .categorical_attributes
158
+ )
159
+
115
160
116
161
@pytest .mark .parametrize ("test_case" , test_data )
117
162
def test_bandit_generic_test_cases (test_case ):
0 commit comments