Skip to content

Commit 8ee1512

Browse files
authored
Merge pull request #475 from IABTechLab/aul-UID2-5349-logging-salt-stats
Logging salt age stats on rotation
2 parents 6bfb4c1 + 8dbf326 commit 8ee1512

File tree

11 files changed

+886
-609
lines changed

11 files changed

+886
-609
lines changed

src/main/java/com/uid2/admin/Main.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import com.uid2.admin.legacy.RotatingLegacyClientKeyProvider;
1515
import com.uid2.admin.managers.KeysetManager;
1616
import com.uid2.admin.monitoring.DataStoreMetrics;
17+
import com.uid2.admin.salt.SaltRotation;
1718
import com.uid2.admin.secret.*;
1819
import com.uid2.admin.store.*;
1920
import com.uid2.admin.store.reader.RotatingAdminKeysetStore;
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package com.uid2.admin.salt;
2+
3+
import com.uid2.admin.AdminConst;
4+
import com.uid2.shared.model.SaltEntry;
5+
import com.uid2.shared.secret.IKeyGenerator;
6+
7+
import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot;
8+
import io.vertx.core.json.JsonObject;
9+
import lombok.Getter;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
import java.time.*;
14+
import java.time.temporal.ChronoUnit;
15+
import java.util.*;
16+
import java.util.stream.Collectors;
17+
18+
public class SaltRotation {
19+
private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis();
20+
21+
private final IKeyGenerator keyGenerator;
22+
private final boolean isRefreshFromEnabled;
23+
private static final Logger LOGGER = LoggerFactory.getLogger(SaltRotation.class);
24+
25+
public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) {
26+
this.keyGenerator = keyGenerator;
27+
this.isRefreshFromEnabled = config.getBoolean(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, false);
28+
}
29+
30+
public Result rotateSalts(
31+
SaltSnapshot lastSnapshot,
32+
Duration[] minAges,
33+
double fraction,
34+
TargetDate targetDate
35+
) throws Exception {
36+
var preRotationSalts = lastSnapshot.getAllRotatingSalts();
37+
var nextEffective = targetDate.asInstant();
38+
var nextExpires = nextEffective.plus(7, ChronoUnit.DAYS);
39+
if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) {
40+
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot");
41+
}
42+
43+
// Salts that can be rotated based on their refreshFrom being at target date
44+
var refreshableSalts = findRefreshableSalts(preRotationSalts, targetDate);
45+
46+
var saltsToRotate = pickSaltsToRotate(
47+
refreshableSalts,
48+
targetDate,
49+
minAges,
50+
getNumSaltsToRotate(preRotationSalts, fraction)
51+
);
52+
53+
if (saltsToRotate.isEmpty()) {
54+
return Result.noSnapshot("all refreshable salts are below min rotation age");
55+
}
56+
57+
var postRotationSalts = rotateSalts(preRotationSalts, saltsToRotate, targetDate);
58+
59+
logSaltAges("refreshable-salts", targetDate, refreshableSalts);
60+
logSaltAges("rotated-salts", targetDate, saltsToRotate);
61+
logSaltAges("total-salts", targetDate, Arrays.asList(postRotationSalts));
62+
63+
var nextSnapshot = new SaltSnapshot(
64+
nextEffective,
65+
nextExpires,
66+
postRotationSalts,
67+
lastSnapshot.getFirstLevelSalt());
68+
return Result.fromSnapshot(nextSnapshot);
69+
}
70+
71+
private static int getNumSaltsToRotate(SaltEntry[] preRotationSalts, double fraction) {
72+
return (int) Math.ceil(preRotationSalts.length * fraction);
73+
}
74+
75+
private Set<SaltEntry> findRefreshableSalts(SaltEntry[] preRotationSalts, TargetDate targetDate) {
76+
return Arrays.stream(preRotationSalts).filter(s -> isRefreshable(targetDate, s)).collect(Collectors.toSet());
77+
}
78+
79+
private boolean isRefreshable(TargetDate targetDate, SaltEntry salt) {
80+
if (this.isRefreshFromEnabled) {
81+
return salt.refreshFrom().equals(targetDate.asEpochMs());
82+
}
83+
84+
return true;
85+
}
86+
87+
private SaltEntry[] rotateSalts(SaltEntry[] oldSalts, List<SaltEntry> saltsToRotate, TargetDate targetDate) throws Exception {
88+
var saltIdsToRotate = saltsToRotate.stream().map(SaltEntry::id).collect(Collectors.toSet());
89+
90+
var updatedSalts = new SaltEntry[oldSalts.length];
91+
for (int i = 0; i < oldSalts.length; i++) {
92+
var shouldRotate = saltIdsToRotate.contains(oldSalts[i].id());
93+
updatedSalts[i] = updateSalt(oldSalts[i], targetDate, shouldRotate);
94+
}
95+
return updatedSalts;
96+
}
97+
98+
private SaltEntry updateSalt(SaltEntry oldSalt, TargetDate targetDate, boolean shouldRotate) throws Exception {
99+
var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt();
100+
var lastUpdated = shouldRotate ? targetDate.asEpochMs() : oldSalt.lastUpdated();
101+
var refreshFrom = calculateRefreshFrom(oldSalt, targetDate);
102+
var previousSalt = calculatePreviousSalt(oldSalt, shouldRotate, targetDate);
103+
104+
return new SaltEntry(
105+
oldSalt.id(),
106+
oldSalt.hashedId(),
107+
lastUpdated,
108+
currentSalt,
109+
refreshFrom,
110+
previousSalt,
111+
null,
112+
null
113+
);
114+
}
115+
116+
private long calculateRefreshFrom(SaltEntry salt, TargetDate targetDate) {
117+
long multiplier = targetDate.saltAgeInDays(salt) / 30 + 1;
118+
return salt.lastUpdated() + (multiplier * THIRTY_DAYS_IN_MS);
119+
}
120+
121+
private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, TargetDate targetDate) {
122+
if (shouldRotate) {
123+
return salt.currentSalt();
124+
}
125+
if (targetDate.saltAgeInDays(salt) < 90) {
126+
return salt.previousSalt();
127+
}
128+
return null;
129+
}
130+
131+
private List<SaltEntry> pickSaltsToRotate(
132+
Set<SaltEntry> refreshableSalts,
133+
TargetDate targetDate,
134+
Duration[] minAges,
135+
int numSaltsToRotate
136+
) {
137+
var thresholds = Arrays.stream(minAges)
138+
.map(minAge -> targetDate.asInstant().minusSeconds(minAge.getSeconds()))
139+
.sorted()
140+
.toArray(Instant[]::new);
141+
var indexesToRotate = new ArrayList<SaltEntry>();
142+
143+
var minLastUpdated = Instant.ofEpochMilli(0);
144+
for (var maxLastUpdated : thresholds) {
145+
if (indexesToRotate.size() >= numSaltsToRotate) break;
146+
147+
var maxIndexes = numSaltsToRotate - indexesToRotate.size();
148+
var saltsToRotate = pickSaltsToRotateInTimeWindow(
149+
refreshableSalts,
150+
maxIndexes,
151+
minLastUpdated.toEpochMilli(),
152+
maxLastUpdated.toEpochMilli()
153+
);
154+
indexesToRotate.addAll(saltsToRotate);
155+
minLastUpdated = maxLastUpdated;
156+
}
157+
return indexesToRotate;
158+
}
159+
160+
private List<SaltEntry> pickSaltsToRotateInTimeWindow(
161+
Set<SaltEntry> refreshableSalts,
162+
int maxIndexes,
163+
long minLastUpdated,
164+
long maxLastUpdated
165+
) {
166+
ArrayList<SaltEntry> candidateSalts = refreshableSalts.stream()
167+
.filter(salt -> minLastUpdated <= salt.lastUpdated() && salt.lastUpdated() < maxLastUpdated)
168+
.collect(Collectors.toCollection(ArrayList::new));
169+
170+
if (candidateSalts.size() <= maxIndexes) {
171+
return candidateSalts;
172+
}
173+
174+
Collections.shuffle(candidateSalts);
175+
176+
return candidateSalts.stream().limit(maxIndexes).collect(Collectors.toList());
177+
}
178+
179+
private void logSaltAges(String saltCountType, TargetDate targetDate, Collection<SaltEntry> salts) {
180+
var ages = new HashMap<Long, Long>(); // salt age to count
181+
for (var salt : salts) {
182+
long ageInDays = targetDate.saltAgeInDays(salt);
183+
ages.put(ageInDays, ages.getOrDefault(ageInDays, 0L) + 1);
184+
}
185+
186+
for (var entry : ages.entrySet()) {
187+
LOGGER.info("salt-count-type={} target-date={} age={} salt-count={}",
188+
saltCountType,
189+
targetDate,
190+
entry.getKey(),
191+
entry.getValue()
192+
);
193+
}
194+
}
195+
196+
@Getter
197+
public static class Result {
198+
private final SaltSnapshot snapshot; // can be null if new snapshot is not needed
199+
private final String reason; // why you are not getting a new snapshot
200+
201+
private Result(SaltSnapshot snapshot, String reason) {
202+
this.snapshot = snapshot;
203+
this.reason = reason;
204+
}
205+
206+
public boolean hasSnapshot() {
207+
return snapshot != null;
208+
}
209+
210+
public static Result fromSnapshot(SaltSnapshot snapshot) {
211+
return new Result(snapshot, null);
212+
}
213+
214+
public static Result noSnapshot(String reason) {
215+
return new Result(null, reason);
216+
}
217+
}
218+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package com.uid2.admin.salt;
2+
3+
import com.uid2.shared.model.SaltEntry;
4+
5+
import java.time.*;
6+
import java.time.format.DateTimeFormatter;
7+
import java.util.Objects;
8+
9+
public class TargetDate {
10+
private final static long DAY_IN_MS = Duration.ofDays(1).toMillis();
11+
12+
private final LocalDate date;
13+
private final long epochMs;
14+
private final Instant instant;
15+
private final String formatted;
16+
17+
public TargetDate(LocalDate date) {
18+
this.instant = date.atStartOfDay().toInstant(ZoneOffset.UTC);
19+
this.date = date;
20+
this.epochMs = instant.toEpochMilli();
21+
this.formatted = date.format(DateTimeFormatter.ofPattern("yyyy-MM-dd"));
22+
}
23+
24+
public static TargetDate now() {
25+
return new TargetDate(LocalDate.now(Clock.systemUTC()));
26+
}
27+
28+
public static TargetDate of(int year, int month, int day) {
29+
return new TargetDate(LocalDate.of(year, month, day));
30+
}
31+
32+
public long asEpochMs() {
33+
return epochMs;
34+
}
35+
36+
public Instant asInstant() {
37+
return instant;
38+
}
39+
40+
// relative to this date
41+
public long saltAgeInDays(SaltEntry salt) {
42+
return (this.asEpochMs() - salt.lastUpdated()) / DAY_IN_MS;
43+
}
44+
45+
public TargetDate plusDays(int days) {
46+
return new TargetDate(date.plusDays(days));
47+
}
48+
49+
public TargetDate minusDays(int days) {
50+
return new TargetDate(date.minusDays(days));
51+
}
52+
53+
@Override
54+
public String toString() {
55+
return formatted;
56+
}
57+
58+
@Override
59+
public boolean equals(Object o) {
60+
if (o == null || getClass() != o.getClass()) return false;
61+
TargetDate that = (TargetDate) o;
62+
return epochMs == that.epochMs;
63+
}
64+
65+
@Override
66+
public int hashCode() {
67+
return Objects.hashCode(epochMs);
68+
}
69+
}

0 commit comments

Comments
 (0)