|
| 1 | +package cloud.eppo; |
| 2 | + |
| 3 | +import cloud.eppo.ufc.dto.*; |
| 4 | +import java.util.*; |
| 5 | +import java.util.stream.Collectors; |
| 6 | + |
| 7 | +public class BanditEvaluator { |
| 8 | + |
| 9 | + private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now |
| 10 | + |
| 11 | + public static BanditEvaluationResult evaluateBandit( |
| 12 | + String flagKey, |
| 13 | + String subjectKey, |
| 14 | + DiscriminableAttributes subjectAttributes, |
| 15 | + Actions actions, |
| 16 | + BanditModelData modelData) { |
| 17 | + Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData); |
| 18 | + Map<String, Double> actionWeights = |
| 19 | + weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor()); |
| 20 | + String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights); |
| 21 | + |
| 22 | + // Compute optimality gap in terms of score |
| 23 | + double topScore = |
| 24 | + actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0); |
| 25 | + double optimalityGap = topScore - actionScores.get(selectedActionKey); |
| 26 | + |
| 27 | + return new BanditEvaluationResult( |
| 28 | + flagKey, |
| 29 | + subjectKey, |
| 30 | + subjectAttributes, |
| 31 | + selectedActionKey, |
| 32 | + actions.get(selectedActionKey), |
| 33 | + actionScores.get(selectedActionKey), |
| 34 | + actionWeights.get(selectedActionKey), |
| 35 | + modelData.getGamma(), |
| 36 | + optimalityGap); |
| 37 | + } |
| 38 | + |
| 39 | + private static Map<String, Double> scoreActions( |
| 40 | + DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) { |
| 41 | + return actions.entrySet().stream() |
| 42 | + .collect( |
| 43 | + Collectors.toMap( |
| 44 | + Map.Entry::getKey, |
| 45 | + e -> { |
| 46 | + String actionName = e.getKey(); |
| 47 | + DiscriminableAttributes actionAttributes = e.getValue(); |
| 48 | + |
| 49 | + // get all coefficients known to the model for this action |
| 50 | + BanditCoefficients banditCoefficients = |
| 51 | + modelData.getCoefficients().get(actionName); |
| 52 | + |
| 53 | + if (banditCoefficients == null) { |
| 54 | + // Unknown action; return the default action score |
| 55 | + return modelData.getDefaultActionScore(); |
| 56 | + } |
| 57 | + |
| 58 | + // Score the action using the provided attributes |
| 59 | + double actionScore = banditCoefficients.getIntercept(); |
| 60 | + actionScore += |
| 61 | + scoreContextForCoefficients( |
| 62 | + actionAttributes.getNumericAttributes(), |
| 63 | + banditCoefficients.getActionNumericCoefficients()); |
| 64 | + actionScore += |
| 65 | + scoreContextForCoefficients( |
| 66 | + actionAttributes.getCategoricalAttributes(), |
| 67 | + banditCoefficients.getActionCategoricalCoefficients()); |
| 68 | + actionScore += |
| 69 | + scoreContextForCoefficients( |
| 70 | + subjectAttributes.getNumericAttributes(), |
| 71 | + banditCoefficients.getSubjectNumericCoefficients()); |
| 72 | + actionScore += |
| 73 | + scoreContextForCoefficients( |
| 74 | + subjectAttributes.getCategoricalAttributes(), |
| 75 | + banditCoefficients.getSubjectCategoricalCoefficients()); |
| 76 | + |
| 77 | + return actionScore; |
| 78 | + })); |
| 79 | + } |
| 80 | + |
| 81 | + private static double scoreContextForCoefficients( |
| 82 | + Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) { |
| 83 | + |
| 84 | + double totalScore = 0.0; |
| 85 | + |
| 86 | + for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) { |
| 87 | + EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey()); |
| 88 | + // The coefficient implementation knows how to score |
| 89 | + double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue); |
| 90 | + totalScore += attributeScore; |
| 91 | + } |
| 92 | + |
| 93 | + return totalScore; |
| 94 | + } |
| 95 | + |
| 96 | + private static Map<String, Double> weighActions( |
| 97 | + Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) { |
| 98 | + Double highestScore = null; |
| 99 | + String highestScoredAction = null; |
| 100 | + for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) { |
| 101 | + if (highestScore == null |
| 102 | + || actionScore.getValue() > highestScore |
| 103 | + || actionScore |
| 104 | + .getValue() |
| 105 | + .equals(highestScore) // note: we break ties for scores by action name |
| 106 | + && actionScore.getKey().compareTo(highestScoredAction) < 0) { |
| 107 | + highestScore = actionScore.getValue(); |
| 108 | + highestScoredAction = actionScore.getKey(); |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + // Weigh all the actions using their score |
| 113 | + Map<String, Double> actionWeights = new HashMap<>(); |
| 114 | + double totalNonHighestWeight = 0.0; |
| 115 | + for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) { |
| 116 | + if (actionScore.getKey().equals(highestScoredAction)) { |
| 117 | + // The highest scored action is weighed at the end |
| 118 | + continue; |
| 119 | + } |
| 120 | + |
| 121 | + // Compute weight (probability) |
| 122 | + double unboundedProbability = |
| 123 | + 1 / (actionScores.size() + (gamma * (highestScore - actionScore.getValue()))); |
| 124 | + double minimumProbability = actionProbabilityFloor / actionScores.size(); |
| 125 | + double boundedProbability = Math.max(unboundedProbability, minimumProbability); |
| 126 | + totalNonHighestWeight += boundedProbability; |
| 127 | + |
| 128 | + actionWeights.put(actionScore.getKey(), boundedProbability); |
| 129 | + } |
| 130 | + |
| 131 | + // Weigh the highest scoring action (defensively preventing a negative probability) |
| 132 | + double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0); |
| 133 | + actionWeights.put(highestScoredAction, weightForHighestScore); |
| 134 | + return actionWeights; |
| 135 | + } |
| 136 | + |
| 137 | + private static String selectAction( |
| 138 | + String flagKey, String subjectKey, Map<String, Double> actionWeights) { |
| 139 | + // Deterministically "shuffle" the actions |
| 140 | + // This way as action weights shift, a bunch of users who were on the edge of one action won't |
| 141 | + // all be shifted to the same new action at the same time. |
| 142 | + List<String> shuffledActionKeys = |
| 143 | + actionWeights.keySet().stream() |
| 144 | + .sorted( |
| 145 | + Comparator.comparingInt( |
| 146 | + (String actionKey) -> |
| 147 | + ShardUtils.getShard( |
| 148 | + flagKey + "-" + subjectKey + "-" + actionKey, |
| 149 | + BANDIT_ASSIGNMENT_SHARDS)) |
| 150 | + .thenComparing(actionKey -> actionKey)) |
| 151 | + .collect(Collectors.toList()); |
| 152 | + |
| 153 | + // Select action from the shuffled actions, based on weight |
| 154 | + double assignedShard = |
| 155 | + ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS); |
| 156 | + double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS; |
| 157 | + double cumulativeWeight = 0; |
| 158 | + String assignedAction = null; |
| 159 | + for (String actionKey : shuffledActionKeys) { |
| 160 | + cumulativeWeight += actionWeights.get(actionKey); |
| 161 | + if (cumulativeWeight > assignmentWeightThreshold) { |
| 162 | + assignedAction = actionKey; |
| 163 | + break; |
| 164 | + } |
| 165 | + } |
| 166 | + return assignedAction; |
| 167 | + } |
| 168 | +} |
0 commit comments