Skip to content

Commit e7ecbda

Browse files
authored
Optimize assignment getting (#30)
* Eppo Client with shared UFC tests passing (#23) * tests passing for rule evaluator, flag evaluator, and eppo value * work in progress * shared UFC tests passing * don't check in test data * changes from self-review of PR * apply spotless linter * working on tests * better test logging * use make test for tests * Add bandit support (#25) * bandit test harness ready * add bandit result * set up for dropping in bandit evaluation * bandit deserialization wired up * loading bandit parameters * bandit stuff happening * shared bandit tests passing * bandit logger classes * bandit log test passing * more tests for logger * bandit tests for graceful mode * apply spotless formatting autofix * changes from self-review of PR so far * more changes from self-review of PR * more changes from self-review * make test less fragile * bump version; don't sign local maven * bandit logging errors should be non-fatal * use normalized probability floor * update result before even attempting to log bandit * spotless 🙄 * do the rename (#26) * work in progress * remove singleton for base client * linter * expose bandit test harnesses * expose test uilities * changes from self-review of PR * make base client constructor protected * add simple logger for tests * basic profiling test * improvement not using bigint * faster getShard() * linter * more helpful failure message * increase CPU time allowance to account for slower machines * bump version
1 parent d5552e4 commit e7ecbda

File tree

7 files changed

+177
-36
lines changed

7 files changed

+177
-36
lines changed

build.gradle

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ plugins {
66
}
77

88
group = 'cloud.eppo'
9-
version = '3.0.0-SNAPSHOT'
9+
version = '3.0.1-SNAPSHOT'
1010
ext.isReleaseVersion = !version.endsWith("SNAPSHOT")
1111

1212
dependencies {
@@ -17,6 +17,7 @@ dependencies {
1717
// For UFC DTOs
1818
implementation 'commons-codec:commons-codec:1.17.0'
1919
implementation 'org.slf4j:slf4j-api:2.0.13'
20+
testImplementation 'org.slf4j:slf4j-simple:2.0.16'
2021
testImplementation platform('org.junit:junit-bom:5.10.3')
2122
testImplementation 'org.junit.jupiter:junit-jupiter'
2223
testImplementation 'org.skyscreamer:jsonassert:1.5.3'

src/main/java/cloud/eppo/BanditEvaluator.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package cloud.eppo;
22

3+
import static cloud.eppo.Utils.getShard;
4+
35
import cloud.eppo.ufc.dto.*;
46
import java.util.*;
57
import java.util.stream.Collectors;
@@ -144,15 +146,14 @@ private static String selectAction(
144146
.sorted(
145147
Comparator.comparingInt(
146148
(String actionKey) ->
147-
ShardUtils.getShard(
149+
getShard(
148150
flagKey + "-" + subjectKey + "-" + actionKey,
149151
BANDIT_ASSIGNMENT_SHARDS))
150152
.thenComparing(actionKey -> actionKey))
151153
.collect(Collectors.toList());
152154

153155
// Select action from the shuffled actions, based on weight
154-
double assignedShard =
155-
ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
156+
double assignedShard = getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
156157
double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS;
157158
double cumulativeWeight = 0;
158159
String assignedAction = null;

src/main/java/cloud/eppo/FlagEvaluator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package cloud.eppo;
22

3-
import static cloud.eppo.ShardUtils.getShard;
43
import static cloud.eppo.Utils.base64Decode;
4+
import static cloud.eppo.Utils.getShard;
55

66
import cloud.eppo.model.ShardRange;
77
import cloud.eppo.ufc.dto.Allocation;

src/main/java/cloud/eppo/ShardUtils.java

Lines changed: 0 additions & 18 deletions
This file was deleted.

src/main/java/cloud/eppo/Utils.java

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cloud.eppo;
22

33
import com.fasterxml.jackson.databind.JsonNode;
4-
import java.math.BigInteger;
54
import java.nio.charset.StandardCharsets;
65
import java.security.MessageDigest;
76
import java.security.NoSuchAlgorithmException;
@@ -16,28 +15,60 @@
1615
public final class Utils {
1716
private static final SimpleDateFormat UTC_ISO_DATE_FORMAT = buildUtcIsoDateFormat();
1817
private static final Logger log = LoggerFactory.getLogger(Utils.class);
18+
private static final MessageDigest md = buildMd5MessageDigest();
19+
20+
private static MessageDigest buildMd5MessageDigest() {
21+
try {
22+
return MessageDigest.getInstance("MD5");
23+
} catch (NoSuchAlgorithmException e) {
24+
throw new RuntimeException("Error computing md5 hash", e);
25+
}
26+
}
1927

2028
public static void throwIfEmptyOrNull(String input, String errorMessage) {
2129
if (input == null || input.isEmpty()) {
2230
throw new IllegalArgumentException(errorMessage);
2331
}
2432
}
2533

34+
/**
35+
* Return the String representation of the zero-padded hexadecimal hash of a string input. This is
36+
* useful for comparing against other string hashes, such as obfuscated flag names.
37+
*/
2638
public static String getMD5Hex(String input) {
27-
MessageDigest md;
28-
try {
29-
md = MessageDigest.getInstance("MD5");
30-
} catch (NoSuchAlgorithmException e) {
31-
throw new RuntimeException("Error computing md5 hash", e);
39+
// md5 the input
40+
md.reset();
41+
byte[] md5Bytes = md.digest(input.getBytes());
42+
// Pre-allocate a StringBuilder with a capacity of 32 characters
43+
StringBuilder hexString = new StringBuilder(32);
44+
45+
for (byte b : md5Bytes) {
46+
// Append the two hex digits corresponding to the byte
47+
hexString.append(Character.forDigit((b >> 4) & 0xF, 16));
48+
hexString.append(Character.forDigit(b & 0xF, 16));
3249
}
33-
byte[] messageDigest = md.digest(input.getBytes());
34-
BigInteger no = new BigInteger(1, messageDigest);
35-
StringBuilder hashText = new StringBuilder(no.toString(16));
36-
while (hashText.length() < 32) {
37-
hashText.insert(0, "0");
50+
51+
return hexString.toString();
52+
}
53+
54+
/**
55+
* Return a deterministic pseudo-random integer based on the input that falls between 0
56+
* (inclusive) and a max value (exclusive). This is useful for randomly bucketing subjects or
57+
* shuffling bandit actions.
58+
*/
59+
public static int getShard(String input, int maxShardValue) {
60+
// md5 the input
61+
md.reset();
62+
byte[] md5Bytes = md.digest(input.getBytes());
63+
64+
// Extract the first 4 bytes (8 hex digits) and convert to a long
65+
long value = 0;
66+
for (int i = 0; i < 4; i++) {
67+
value = (value << 8) | (md5Bytes[i] & 0xFF);
3868
}
3969

40-
return hashText.toString();
70+
// Modulo into the shard space
71+
return (int) (value % maxShardValue);
4172
}
4273

4374
public static Date parseUtcISODateNode(JsonNode isoDateStringElement) {
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package cloud.eppo;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import cloud.eppo.logging.Assignment;
7+
import cloud.eppo.logging.AssignmentLogger;
8+
import cloud.eppo.ufc.dto.Attributes;
9+
import java.lang.management.ManagementFactory;
10+
import java.lang.management.ThreadMXBean;
11+
import java.util.HashMap;
12+
import java.util.Map;
13+
import java.util.concurrent.atomic.AtomicInteger;
14+
import org.junit.jupiter.api.BeforeAll;
15+
import org.junit.jupiter.api.Test;
16+
import org.slf4j.Logger;
17+
import org.slf4j.LoggerFactory;
18+
19+
public class ProfileBaseEppoClientTest {
20+
private static final Logger log = LoggerFactory.getLogger(ProfileBaseEppoClientTest.class);
21+
22+
private static final String DUMMY_FLAG_API_KEY = "dummy-flags-api-key"; // Will load flags-v1
23+
private static final String TEST_HOST =
24+
"https://us-central1-eppo-qa.cloudfunctions.net/serveGitHubRacTestFile";
25+
26+
private static BaseEppoClient eppoClient;
27+
private static final AssignmentLogger noOpAssignmentLogger =
28+
new AssignmentLogger() {
29+
@Override
30+
public void logAssignment(Assignment assignment) {
31+
/* no-op */
32+
}
33+
};
34+
35+
@BeforeAll
36+
public static void initClient() {
37+
eppoClient =
38+
new BaseEppoClient(
39+
DUMMY_FLAG_API_KEY,
40+
"java",
41+
"3.0.0",
42+
TEST_HOST,
43+
noOpAssignmentLogger,
44+
null,
45+
false,
46+
false);
47+
48+
eppoClient.loadConfiguration();
49+
50+
log.info("Test client initialized");
51+
}
52+
53+
@Test
54+
public void testGetStringAssignmentPerformance() {
55+
Map<String, AtomicInteger> variationCounts = new HashMap<>();
56+
Attributes subjectAttributes = new Attributes();
57+
subjectAttributes.put("country", "FR");
58+
59+
ThreadMXBean threadBean = ManagementFactory.getThreadMXBean();
60+
long startTime = threadBean.getCurrentThreadCpuTime();
61+
62+
int numIterations = 10000;
63+
64+
for (int i = 0; i < numIterations; i++) {
65+
String subjectKey = "subject" + i;
66+
String assignedVariation =
67+
eppoClient.getStringAssignment(
68+
"new-user-onboarding", subjectKey, subjectAttributes, "default");
69+
AtomicInteger existingCount =
70+
variationCounts.computeIfAbsent(assignedVariation, k -> new AtomicInteger(0));
71+
existingCount.incrementAndGet();
72+
}
73+
long endTime = threadBean.getCurrentThreadCpuTime();
74+
long elapsedTime = endTime - startTime;
75+
76+
log.info("Assignment counts: {}", variationCounts);
77+
log.info("CPU Time: {}", elapsedTime);
78+
79+
// Assert assignments shook out as expected based on the shard ranges
80+
assertEquals(4, variationCounts.keySet().size());
81+
// Expect ~40% default
82+
assertEquals(0.40, variationCounts.get("default").doubleValue() / numIterations, 0.02);
83+
// Expect ~30% control (50% of 60%)
84+
assertEquals(0.30, variationCounts.get("control").doubleValue() / numIterations, 0.02);
85+
// Expect ~18% red (30% of 60%)
86+
assertEquals(0.18, variationCounts.get("red").doubleValue() / numIterations, 0.02);
87+
// Expect ~12% yellow (20% of 60%)
88+
assertEquals(0.12, variationCounts.get("yellow").doubleValue() / numIterations, 0.02);
89+
90+
// Seeing ~48,000,000 - ~54,000,000 ns of CPU time for 10k iterations on an M2 MacBook Pro; let's fail if more
91+
// than 150,000,000; giving a generous allowance for slower systems (like GitHub) but will still
92+
// catch if things slow down considerably
93+
long maxAllowedTime = 15000 * numIterations;
94+
assertTrue(
95+
elapsedTime < maxAllowedTime,
96+
"Cpu time of " + elapsedTime + " is more than the " + maxAllowedTime + " allowed");
97+
}
98+
}

src/test/java/cloud/eppo/UtilsTest.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package cloud.eppo;
22

3-
import static cloud.eppo.Utils.parseUtcISODateNode;
3+
import static cloud.eppo.Utils.*;
44
import static org.junit.jupiter.api.Assertions.assertEquals;
55
import static org.junit.jupiter.api.Assertions.assertNull;
66

@@ -11,6 +11,34 @@
1111
import org.junit.jupiter.api.Test;
1212

1313
public class UtilsTest {
14+
@Test
15+
public void testGetMd5Hash() {
16+
// empty string
17+
assertEquals("d41d8cd98f00b204e9800998ecf8427e", getMD5Hex(""));
18+
// leading zero
19+
assertEquals("0212de0d90804f17d2b7bab512cd2f0f", getMD5Hex("input-59"));
20+
// zero first byte
21+
assertEquals("00dd33988da4202fc1990a4dfa7ee18b", getMD5Hex("input-411"));
22+
// zero middle byte
23+
assertEquals("448614887a99f16179b400cfccceb72d", getMD5Hex("input-62"));
24+
// zero last byte
25+
assertEquals("429fb7196ccb2978443a0de8da180e00", getMD5Hex("input-34"));
26+
}
27+
28+
@Test
29+
public void testGetShard() {
30+
// Shard is the first 8 hex digits of the MD5 hash read as a number, taken modulo the shard space
31+
int computedShard = (int) (Long.parseLong(getMD5Hex("shard me").substring(0, 8), 16) % 10000);
32+
int shardFromGetShard = getShard("shard me", 10000);
33+
assertEquals(computedShard, shardFromGetShard);
34+
35+
// Total shards is respected
36+
assertEquals(538, getShard("shard me", 10000));
37+
assertEquals(538, getShard("shard me", 1000));
38+
assertEquals(38, getShard("shard me", 100));
39+
assertEquals(8, getShard("shard me", 10));
40+
}
41+
1442
@Test
1543
public void testParseUtcISODateNode() throws JsonProcessingException {
1644
ObjectMapper mapper = new ObjectMapper();

0 commit comments

Comments
 (0)