Skip to content

Commit 55c3fce

Browse files
authored
Add bandit support (#25)
* bandit test harness ready * add bandit result * set up for dropping in bandit evaluation * bandit deserialization wired up * loading bandit parameters * bandit stuff happening * shared bandit tests passing * bandit logger classes * bandit log test passing * more tests for logger * bandit tests for graceful mode * apply spotless formatting autofix * changes from self-review of PR so far * more changes from self-review of PR * more changes from self-review * make test less fragile * bump version; don't sign local maven * bandit logging errors should be non-fatal * use normalized probability floor * update result before even attempting to log bandit * spotless 🙄
1 parent 9e8ba3e commit 55c3fce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1884
-934
lines changed

build.gradle

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ plugins {
66
}
77

88
group = 'cloud.eppo'
9-
version = '2.1.0-SNAPSHOT'
9+
version = '3.0.0-SNAPSHOT'
1010
ext.isReleaseVersion = !version.endsWith("SNAPSHOT")
1111

1212
dependencies {
@@ -23,6 +23,7 @@ dependencies {
2323
testImplementation 'commons-io:commons-io:2.11.0'
2424
testImplementation "com.google.truth:truth:1.4.4"
2525
testImplementation 'org.mockito:mockito-core:4.11.0'
26+
testImplementation 'org.mockito:mockito-inline:4.11.0'
2627
}
2728

2829
test {
@@ -140,14 +141,15 @@ tasks.withType(PublishToMavenRepository) {
140141
}
141142
}
142143

143-
signing {
144-
sign publishing.publications.mavenJava
145-
if (System.env['CI']) {
146-
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
144+
if (!project.gradle.startParameter.taskNames.contains('publishToMavenLocal')) {
145+
signing {
146+
sign publishing.publications.mavenJava
147+
if (System.env['CI']) {
148+
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
149+
}
147150
}
148151
}
149152

150-
151153
javadoc {
152154
failOnError = false
153155
options.addStringOption('Xdoclint:none', '-quiet')
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package cloud.eppo;
2+
3+
import cloud.eppo.ufc.dto.DiscriminableAttributes;
4+
5+
public class BanditEvaluationResult {
6+
7+
private final String flagKey;
8+
private final String subjectKey;
9+
private final DiscriminableAttributes subjectAttributes;
10+
private final String actionKey;
11+
private final DiscriminableAttributes actionAttributes;
12+
private final double actionScore;
13+
private final double actionWeight;
14+
private final double gamma;
15+
private final double optimalityGap;
16+
17+
public BanditEvaluationResult(
18+
String flagKey,
19+
String subjectKey,
20+
DiscriminableAttributes subjectAttributes,
21+
String actionKey,
22+
DiscriminableAttributes actionAttributes,
23+
double actionScore,
24+
double actionWeight,
25+
double gamma,
26+
double optimalityGap) {
27+
this.flagKey = flagKey;
28+
this.subjectKey = subjectKey;
29+
this.subjectAttributes = subjectAttributes;
30+
this.actionKey = actionKey;
31+
this.actionAttributes = actionAttributes;
32+
this.actionScore = actionScore;
33+
this.actionWeight = actionWeight;
34+
this.gamma = gamma;
35+
this.optimalityGap = optimalityGap;
36+
}
37+
38+
public String getFlagKey() {
39+
return flagKey;
40+
}
41+
42+
public String getSubjectKey() {
43+
return subjectKey;
44+
}
45+
46+
public DiscriminableAttributes getSubjectAttributes() {
47+
return subjectAttributes;
48+
}
49+
50+
public String getActionKey() {
51+
return actionKey;
52+
}
53+
54+
public DiscriminableAttributes getActionAttributes() {
55+
return actionAttributes;
56+
}
57+
58+
public double getActionScore() {
59+
return actionScore;
60+
}
61+
62+
public double getActionWeight() {
63+
return actionWeight;
64+
}
65+
66+
public double getGamma() {
67+
return gamma;
68+
}
69+
70+
public double getOptimalityGap() {
71+
return optimalityGap;
72+
}
73+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package cloud.eppo;
2+
3+
import cloud.eppo.ufc.dto.*;
4+
import java.util.*;
5+
import java.util.stream.Collectors;
6+
7+
public class BanditEvaluator {
8+
9+
private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now
10+
11+
public static BanditEvaluationResult evaluateBandit(
12+
String flagKey,
13+
String subjectKey,
14+
DiscriminableAttributes subjectAttributes,
15+
Actions actions,
16+
BanditModelData modelData) {
17+
Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData);
18+
Map<String, Double> actionWeights =
19+
weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor());
20+
String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights);
21+
22+
// Compute optimality gap in terms of score
23+
double topScore =
24+
actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0);
25+
double optimalityGap = topScore - actionScores.get(selectedActionKey);
26+
27+
return new BanditEvaluationResult(
28+
flagKey,
29+
subjectKey,
30+
subjectAttributes,
31+
selectedActionKey,
32+
actions.get(selectedActionKey),
33+
actionScores.get(selectedActionKey),
34+
actionWeights.get(selectedActionKey),
35+
modelData.getGamma(),
36+
optimalityGap);
37+
}
38+
39+
private static Map<String, Double> scoreActions(
40+
DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) {
41+
return actions.entrySet().stream()
42+
.collect(
43+
Collectors.toMap(
44+
Map.Entry::getKey,
45+
e -> {
46+
String actionName = e.getKey();
47+
DiscriminableAttributes actionAttributes = e.getValue();
48+
49+
// get all coefficients known to the model for this action
50+
BanditCoefficients banditCoefficients =
51+
modelData.getCoefficients().get(actionName);
52+
53+
if (banditCoefficients == null) {
54+
// Unknown action; return the default action score
55+
return modelData.getDefaultActionScore();
56+
}
57+
58+
// Score the action using the provided attributes
59+
double actionScore = banditCoefficients.getIntercept();
60+
actionScore +=
61+
scoreContextForCoefficients(
62+
actionAttributes.getNumericAttributes(),
63+
banditCoefficients.getActionNumericCoefficients());
64+
actionScore +=
65+
scoreContextForCoefficients(
66+
actionAttributes.getCategoricalAttributes(),
67+
banditCoefficients.getActionCategoricalCoefficients());
68+
actionScore +=
69+
scoreContextForCoefficients(
70+
subjectAttributes.getNumericAttributes(),
71+
banditCoefficients.getSubjectNumericCoefficients());
72+
actionScore +=
73+
scoreContextForCoefficients(
74+
subjectAttributes.getCategoricalAttributes(),
75+
banditCoefficients.getSubjectCategoricalCoefficients());
76+
77+
return actionScore;
78+
}));
79+
}
80+
81+
private static double scoreContextForCoefficients(
82+
Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) {
83+
84+
double totalScore = 0.0;
85+
86+
for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) {
87+
EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey());
88+
// The coefficient implementation knows how to score
89+
double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue);
90+
totalScore += attributeScore;
91+
}
92+
93+
return totalScore;
94+
}
95+
96+
private static Map<String, Double> weighActions(
97+
Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) {
98+
Double highestScore = null;
99+
String highestScoredAction = null;
100+
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
101+
if (highestScore == null
102+
|| actionScore.getValue() > highestScore
103+
|| actionScore
104+
.getValue()
105+
.equals(highestScore) // note: we break ties for scores by action name
106+
&& actionScore.getKey().compareTo(highestScoredAction) < 0) {
107+
highestScore = actionScore.getValue();
108+
highestScoredAction = actionScore.getKey();
109+
}
110+
}
111+
112+
// Weigh all the actions using their score
113+
Map<String, Double> actionWeights = new HashMap<>();
114+
double totalNonHighestWeight = 0.0;
115+
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
116+
if (actionScore.getKey().equals(highestScoredAction)) {
117+
// The highest scored action is weighed at the end
118+
continue;
119+
}
120+
121+
// Compute weight (probability)
122+
double unboundedProbability =
123+
1 / (actionScores.size() + (gamma * (highestScore - actionScore.getValue())));
124+
double minimumProbability = actionProbabilityFloor / actionScores.size();
125+
double boundedProbability = Math.max(unboundedProbability, minimumProbability);
126+
totalNonHighestWeight += boundedProbability;
127+
128+
actionWeights.put(actionScore.getKey(), boundedProbability);
129+
}
130+
131+
// Weigh the highest scoring action (defensively preventing a negative probability)
132+
double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0);
133+
actionWeights.put(highestScoredAction, weightForHighestScore);
134+
return actionWeights;
135+
}
136+
137+
private static String selectAction(
138+
String flagKey, String subjectKey, Map<String, Double> actionWeights) {
139+
// Deterministically "shuffle" the actions
140+
// This way as action weights shift, a bunch of users who were on the edge of one action won't
141+
// all be shifted to the same new action at the same time.
142+
List<String> shuffledActionKeys =
143+
actionWeights.keySet().stream()
144+
.sorted(
145+
Comparator.comparingInt(
146+
(String actionKey) ->
147+
ShardUtils.getShard(
148+
flagKey + "-" + subjectKey + "-" + actionKey,
149+
BANDIT_ASSIGNMENT_SHARDS))
150+
.thenComparing(actionKey -> actionKey))
151+
.collect(Collectors.toList());
152+
153+
// Select action from the shuffled actions, based on weight
154+
double assignedShard =
155+
ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
156+
double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS;
157+
double cumulativeWeight = 0;
158+
String assignedAction = null;
159+
for (String actionKey : shuffledActionKeys) {
160+
cumulativeWeight += actionWeights.get(actionKey);
161+
if (cumulativeWeight > assignmentWeightThreshold) {
162+
assignedAction = actionKey;
163+
break;
164+
}
165+
}
166+
return assignedAction;
167+
}
168+
}

src/main/java/cloud/eppo/ConfigurationRequestor.java

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import cloud.eppo.ufc.dto.FlagConfig;
44
import java.io.IOException;
5+
import java.util.HashSet;
6+
import java.util.Set;
57
import okhttp3.Response;
68
import org.slf4j.Logger;
79
import org.slf4j.LoggerFactory;
@@ -12,28 +14,44 @@ public class ConfigurationRequestor {
1214

1315
private final EppoHttpClient client;
1416
private final ConfigurationStore configurationStore;
17+
private final Set<String> loadedBanditModelVersions;
1518

1619
public ConfigurationRequestor(ConfigurationStore configurationStore, EppoHttpClient client) {
1720
this.configurationStore = configurationStore;
1821
this.client = client;
22+
this.loadedBanditModelVersions = new HashSet<>();
1923
}
2024

25+
// TODO: async loading for android
2126
public void load() {
2227
log.debug("Fetching configuration");
23-
Response response = client.get("/api/flag-config/v1/config");
28+
String flagConfigurationJsonString = requestBody("/api/flag-config/v1/config");
29+
configurationStore.setFlagsFromJsonString(flagConfigurationJsonString);
30+
31+
Set<String> neededModelVersions = configurationStore.banditModelVersions();
32+
boolean needBanditParameters = !loadedBanditModelVersions.containsAll(neededModelVersions);
33+
if (needBanditParameters) {
34+
String banditParametersJsonString = requestBody("/api/flag-config/v1/bandits");
35+
configurationStore.setBanditParametersFromJsonString(banditParametersJsonString);
36+
// Record the model versions that we just loaded, so we can compare when the store is later
37+
// updated
38+
loadedBanditModelVersions.clear();
39+
loadedBanditModelVersions.addAll(configurationStore.banditModelVersions());
40+
}
41+
}
42+
43+
private String requestBody(String route) {
44+
Response response = client.get(route);
45+
if (!response.isSuccessful() || response.body() == null) {
46+
throw new RuntimeException("Failed to fetch from " + route);
47+
}
2448
try {
25-
if (!response.isSuccessful()) {
26-
throw new RuntimeException("Failed to fetch configuration");
27-
}
28-
configurationStore.setFlagsFromJsonString(response.body().string());
49+
return response.body().string();
2950
} catch (IOException e) {
30-
// TODO: better exception handling?
3151
throw new RuntimeException(e);
3252
}
3353
}
3454

35-
// TODO: async loading for android
36-
3755
public FlagConfig getConfiguration(String flagKey) {
3856
return configurationStore.getFlag(flagKey);
3957
}

0 commit comments

Comments
 (0)