Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint-test-sdk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
gpg-passphrase: ${{ secrets.GPG_PASSPHRASE }}

- name: Run tests
run: ./gradlew check --no-daemon
run: make test
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
src/test/resources/shared
.gradle
build/
!gradle/wrapper/gradle-wrapper.jar
Expand Down
48 changes: 48 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Make settings - @see https://tech.davis-hansson.com/p/make/
SHELL := bash
.ONESHELL:
.SHELLFLAGS := -eu -o pipefail -c
.DELETE_ON_ERROR:
MAKEFLAGS += --warn-undefined-variables
MAKEFLAGS += --no-builtin-rules

# Log levels
DEBUG := $(shell printf "\e[2D\e[35m")
INFO := $(shell printf "\e[2D\e[36m🔵 ")
OK := $(shell printf "\e[2D\e[32m🟢 ")
WARN := $(shell printf "\e[2D\e[33m🟡 ")
ERROR := $(shell printf "\e[2D\e[31m🔴 ")
END := $(shell printf "\e[0m")


.PHONY: default
default: help

## help - Print help message.
.PHONY: help
help: Makefile
@echo "usage: make <target>"
@sed -n 's/^##//p' $<

.PHONY: build
build: test-data
./gradlew assemble

## test-data
testDataDir := src/test/resources/shared
tempDir := ${testDataDir}/temp
gitDataDir := ${tempDir}/sdk-test-data
branchName := main
githubRepoLink := https://github.com/Eppo-exp/sdk-test-data.git
.PHONY: test-data
test-data:
rm -rf $(testDataDir)
mkdir -p ${tempDir}
git clone -b ${branchName} --depth 1 --single-branch ${githubRepoLink} ${gitDataDir}
cp -r ${gitDataDir}/ufc ${testDataDir}
rm ${testDataDir}/ufc/bandit-tests/*.dynamic-typing.json
rm -rf ${tempDir}

.PHONY: test
test: test-data build
./gradlew check --no-daemon
38 changes: 30 additions & 8 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ plugins {
}

group = 'cloud.eppo'
version = '2.1.0-SNAPSHOT'
version = '3.0.0-SNAPSHOT'
ext.isReleaseVersion = !version.endsWith("SNAPSHOT")

dependencies {
implementation 'com.fasterxml.jackson.core:jackson-databind:2.17.2'
implementation 'com.github.zafarkhaja:java-semver:0.10.2'
implementation "com.squareup.okhttp3:okhttp:4.12.0"

// For UFC DTOs
implementation 'commons-codec:commons-codec:1.17.0'
implementation 'org.slf4j:slf4j-api:2.0.13'
Expand All @@ -19,10 +22,19 @@ dependencies {
testImplementation 'org.skyscreamer:jsonassert:1.5.3'
testImplementation 'commons-io:commons-io:2.11.0'
testImplementation "com.google.truth:truth:1.4.4"
testImplementation 'org.mockito:mockito-core:4.11.0'
testImplementation 'org.mockito:mockito-inline:4.11.0'
}

test {
useJUnitPlatform()
testLogging {
events "started", "passed", "skipped", "failed"
exceptionFormat "full"
showExceptions true
showCauses true
showStackTraces true
}
}

spotless {
Expand Down Expand Up @@ -50,11 +62,17 @@ java {
withSourcesJar()
}

tasks.register('testJar', Jar) {
archiveClassifier.set('tests')
from sourceSets.test.output
}

publishing {
publications {
mavenJava(MavenPublication) {
artifactId = 'sdk-common-jvm'
from components.java
artifact testJar // Include the test-jar in the published artifacts
versionMapping {
usage('java-api') {
fromResolutionOf('runtimeClasspath')
Expand Down Expand Up @@ -102,7 +120,7 @@ publishing {

// Custom task to ensure we can conditionally publish either a release or snapshot artifact
// based on a command line switch. See github workflow files for more details on usage.
task checkVersion {
tasks.register('checkVersion') {
doLast {
if (!project.hasProperty('release') && !project.hasProperty('snapshot')) {
throw new GradleException("You must specify either -Prelease or -Psnapshot")
Expand All @@ -123,21 +141,25 @@ tasks.named('publish').configure {
}

// Conditionally enable or disable publishing tasks
tasks.withType(PublishToMavenRepository) {
tasks.withType(PublishToMavenRepository).configureEach {
onlyIf {
project.ext.has('shouldPublish') && project.ext.shouldPublish
}
}

signing {
sign publishing.publications.mavenJava
if (System.env['CI']) {
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
if (!project.gradle.startParameter.taskNames.contains('publishToMavenLocal')) {
signing {
sign publishing.publications.mavenJava
if (System.env['CI']) {
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
}
}
}


javadoc {
failOnError = false
options.addStringOption('Xdoclint:none', '-quiet')
options.addBooleanOption('failOnError', false)
if (JavaVersion.current().isJava9Compatible()) {
options.addBooleanOption('html5', true)
}
Expand Down
73 changes: 73 additions & 0 deletions src/main/java/cloud/eppo/BanditEvaluationResult.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package cloud.eppo;

import cloud.eppo.ufc.dto.DiscriminableAttributes;

public class BanditEvaluationResult {

private final String flagKey;
private final String subjectKey;
private final DiscriminableAttributes subjectAttributes;
private final String actionKey;
private final DiscriminableAttributes actionAttributes;
private final double actionScore;
private final double actionWeight;
private final double gamma;
private final double optimalityGap;

public BanditEvaluationResult(
String flagKey,
String subjectKey,
DiscriminableAttributes subjectAttributes,
String actionKey,
DiscriminableAttributes actionAttributes,
double actionScore,
double actionWeight,
double gamma,
double optimalityGap) {
this.flagKey = flagKey;
this.subjectKey = subjectKey;
this.subjectAttributes = subjectAttributes;
this.actionKey = actionKey;
this.actionAttributes = actionAttributes;
this.actionScore = actionScore;
this.actionWeight = actionWeight;
this.gamma = gamma;
this.optimalityGap = optimalityGap;
}

public String getFlagKey() {
return flagKey;
}

public String getSubjectKey() {
return subjectKey;
}

public DiscriminableAttributes getSubjectAttributes() {
return subjectAttributes;
}

public String getActionKey() {
return actionKey;
}

public DiscriminableAttributes getActionAttributes() {
return actionAttributes;
}

public double getActionScore() {
return actionScore;
}

public double getActionWeight() {
return actionWeight;
}

public double getGamma() {
return gamma;
}

public double getOptimalityGap() {
return optimalityGap;
}
}
168 changes: 168 additions & 0 deletions src/main/java/cloud/eppo/BanditEvaluator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package cloud.eppo;

import cloud.eppo.ufc.dto.*;
import java.util.*;
import java.util.stream.Collectors;

public class BanditEvaluator {

private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now

public static BanditEvaluationResult evaluateBandit(
String flagKey,
String subjectKey,
DiscriminableAttributes subjectAttributes,
Actions actions,
BanditModelData modelData) {
Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData);
Map<String, Double> actionWeights =
weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor());
String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights);

// Compute optimality gap in terms of score
double topScore =
actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0);
double optimalityGap = topScore - actionScores.get(selectedActionKey);

return new BanditEvaluationResult(
flagKey,
subjectKey,
subjectAttributes,
selectedActionKey,
actions.get(selectedActionKey),
actionScores.get(selectedActionKey),
actionWeights.get(selectedActionKey),
modelData.getGamma(),
optimalityGap);
}

private static Map<String, Double> scoreActions(
DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) {
return actions.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> {
String actionName = e.getKey();
DiscriminableAttributes actionAttributes = e.getValue();

// get all coefficients known to the model for this action
BanditCoefficients banditCoefficients =
modelData.getCoefficients().get(actionName);

if (banditCoefficients == null) {
// Unknown action; return the default action score
return modelData.getDefaultActionScore();
}

// Score the action using the provided attributes
double actionScore = banditCoefficients.getIntercept();
actionScore +=
scoreContextForCoefficients(
actionAttributes.getNumericAttributes(),
banditCoefficients.getActionNumericCoefficients());
actionScore +=
scoreContextForCoefficients(
actionAttributes.getCategoricalAttributes(),
banditCoefficients.getActionCategoricalCoefficients());
actionScore +=
scoreContextForCoefficients(
subjectAttributes.getNumericAttributes(),
banditCoefficients.getSubjectNumericCoefficients());
actionScore +=
scoreContextForCoefficients(
subjectAttributes.getCategoricalAttributes(),
banditCoefficients.getSubjectCategoricalCoefficients());

return actionScore;
}));
}

private static double scoreContextForCoefficients(
Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) {

double totalScore = 0.0;

for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) {
EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey());
// The coefficient implementation knows how to score
double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue);
totalScore += attributeScore;
}

return totalScore;
}

private static Map<String, Double> weighActions(
Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) {
Double highestScore = null;
String highestScoredAction = null;
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
if (highestScore == null
|| actionScore.getValue() > highestScore
|| actionScore
.getValue()
.equals(highestScore) // note: we break ties for scores by action name
&& actionScore.getKey().compareTo(highestScoredAction) < 0) {
highestScore = actionScore.getValue();
highestScoredAction = actionScore.getKey();
}
}

// Weigh all the actions using their score
Map<String, Double> actionWeights = new HashMap<>();
double totalNonHighestWeight = 0.0;
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
if (actionScore.getKey().equals(highestScoredAction)) {
// The highest scored action is weighed at the end
continue;
}

// Compute weight (probability)
double unboundedProbability =
1 / (actionScores.size() + (gamma * (highestScore - actionScore.getValue())));
double minimumProbability = actionProbabilityFloor / actionScores.size();
double boundedProbability = Math.max(unboundedProbability, minimumProbability);
totalNonHighestWeight += boundedProbability;

actionWeights.put(actionScore.getKey(), boundedProbability);
}

// Weigh the highest scoring action (defensively preventing a negative probability)
double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0);
actionWeights.put(highestScoredAction, weightForHighestScore);
return actionWeights;
}

private static String selectAction(
String flagKey, String subjectKey, Map<String, Double> actionWeights) {
// Deterministically "shuffle" the actions
// This way as action weights shift, a bunch of users who were on the edge of one action won't
// all be shifted to the same new action at the same time.
List<String> shuffledActionKeys =
actionWeights.keySet().stream()
.sorted(
Comparator.comparingInt(
(String actionKey) ->
ShardUtils.getShard(
flagKey + "-" + subjectKey + "-" + actionKey,
BANDIT_ASSIGNMENT_SHARDS))
.thenComparing(actionKey -> actionKey))
.collect(Collectors.toList());

// Select action from the shuffled actions, based on weight
double assignedShard =
ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS;
double cumulativeWeight = 0;
String assignedAction = null;
for (String actionKey : shuffledActionKeys) {
cumulativeWeight += actionWeights.get(actionKey);
if (cumulativeWeight > assignmentWeightThreshold) {
assignedAction = actionKey;
break;
}
}
return assignedAction;
}
}
Loading