Skip to content

Commit 7dbead3

Browse files
committed
addressing feedback
1 parent 6229778 commit 7dbead3

File tree

4 files changed

+181
-172
lines changed

4 files changed

+181
-172
lines changed

src/main/java/com/uid2/admin/secret/SaltRotation.java

Lines changed: 112 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import java.time.format.DateTimeFormatter;
1818
import java.time.temporal.ChronoUnit;
1919
import java.util.*;
20+
import java.util.stream.Collectors;
2021

2122
public class SaltRotation {
2223
private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis();
23-
private final static long DAY_IN_MS = Duration.ofDays(1).toMillis();
2424

2525
private final IKeyGenerator keyGenerator;
2626
private final boolean isRefreshFromEnabled;
@@ -35,38 +35,34 @@ public Result rotateSalts(
3535
SaltSnapshot lastSnapshot,
3636
Duration[] minAges,
3737
double fraction,
38-
LocalDate targetDate
38+
TargetDate targetDate
3939
) throws Exception {
4040
var preRotationSalts = lastSnapshot.getAllRotatingSalts();
41-
var nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC);
41+
var nextEffective = targetDate.asInstant();
4242
var nextExpires = nextEffective.plus(7, ChronoUnit.DAYS);
4343
if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) {
4444
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot");
4545
}
4646

47-
var rotatableSaltIndexes = findRotatableSaltIndexes(preRotationSalts, nextEffective.toEpochMilli());
48-
var saltIndexesToRotate = pickSaltIndexesToRotate(
49-
nextEffective,
47+
// Salts that can be rotated based on their refreshFrom being at target date
48+
var rotatableSalts = findRotatableSalts(preRotationSalts, targetDate);
49+
50+
var saltsToRotate = pickSaltsToRotate(
51+
rotatableSalts,
52+
targetDate,
5053
minAges,
51-
fraction,
52-
preRotationSalts,
53-
rotatableSaltIndexes
54+
getNumSaltsToRotate(preRotationSalts, fraction)
5455
);
5556

56-
if (saltIndexesToRotate.isEmpty()) {
57+
if (saltsToRotate.isEmpty()) {
5758
return Result.noSnapshot("all rotatable salts are below min rotation age");
5859
}
5960

60-
var postRotationSalts = updateSalts(preRotationSalts, saltIndexesToRotate, nextEffective.toEpochMilli());
61+
var postRotationSalts = rotateSalts(preRotationSalts, saltsToRotate, targetDate);
6162

62-
logSaltCounts(
63-
targetDate,
64-
nextEffective,
65-
preRotationSalts,
66-
postRotationSalts,
67-
rotatableSaltIndexes,
68-
saltIndexesToRotate
69-
);
63+
logSaltAgeCounts("rotatable-salts", targetDate, rotatableSalts);
64+
logSaltAgeCounts("rotated-salts", targetDate, saltsToRotate);
65+
logSaltAgeCounts("total-salts", targetDate, Arrays.asList(postRotationSalts));
7066

7167
var nextSnapshot = new SaltSnapshot(
7268
nextEffective,
@@ -76,31 +72,38 @@ public Result rotateSalts(
7672
return Result.fromSnapshot(nextSnapshot);
7773
}
7874

79-
private List<Integer> findRotatableSaltIndexes(SaltEntry[] preRotationSalts, long nextEffective) {
80-
var rotatableSalts = new ArrayList<Integer>();
81-
for (int i = 0; i < preRotationSalts.length; i++) {
82-
if (isRotatable(nextEffective, preRotationSalts[i])) {
83-
rotatableSalts.add(i);
84-
}
75+
private static int getNumSaltsToRotate(SaltEntry[] preRotationSalts, double fraction) {
76+
return (int) Math.ceil(preRotationSalts.length * fraction);
77+
}
78+
79+
private Set<SaltEntry> findRotatableSalts(SaltEntry[] preRotationSalts, TargetDate targetDate) {
80+
return Arrays.stream(preRotationSalts).filter(s -> isRotatable(targetDate, s)).collect(Collectors.toSet());
81+
}
82+
83+
private boolean isRotatable(TargetDate targetDate, SaltEntry salt) {
84+
if (this.isRefreshFromEnabled) {
85+
return salt.refreshFrom().equals(targetDate.asEpochMs());
8586
}
86-
return rotatableSalts;
87+
88+
return true;
8789
}
8890

89-
private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List<Integer> saltIndexesToRotate, long nextEffective) throws Exception {
90-
var updatedSalts = new SaltEntry[oldSalts.length];
91+
private SaltEntry[] rotateSalts(SaltEntry[] oldSalts, List<SaltEntry> saltsToRotate, TargetDate targetDate) throws Exception {
92+
var saltIdsToRotate = saltsToRotate.stream().map(SaltEntry::id).collect(Collectors.toSet());
9193

94+
var updatedSalts = new SaltEntry[oldSalts.length];
9295
for (int i = 0; i < oldSalts.length; i++) {
93-
var shouldRotate = saltIndexesToRotate.contains(i);
94-
updatedSalts[i] = updateSalt(oldSalts[i], shouldRotate, nextEffective);
96+
var shouldRotate = saltIdsToRotate.contains(oldSalts[i].id());
97+
updatedSalts[i] = updateSalt(oldSalts[i], targetDate, shouldRotate);
9598
}
9699
return updatedSalts;
97100
}
98101

99-
private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextEffective) throws Exception {
102+
private SaltEntry updateSalt(SaltEntry oldSalt, TargetDate targetDate, boolean shouldRotate) throws Exception {
100103
var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt();
101-
var lastUpdated = shouldRotate ? nextEffective : oldSalt.lastUpdated();
102-
var refreshFrom = calculateRefreshFrom(oldSalt.lastUpdated(), nextEffective);
103-
var previousSalt = calculatePreviousSalt(oldSalt, shouldRotate, nextEffective);
104+
var lastUpdated = shouldRotate ? targetDate.asEpochMs() : oldSalt.lastUpdated();
105+
var refreshFrom = calculateRefreshFrom(oldSalt, targetDate);
106+
var previousSalt = calculatePreviousSalt(oldSalt, shouldRotate, targetDate);
104107

105108
return new SaltEntry(
106109
oldSalt.id(),
@@ -114,132 +117,133 @@ private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextE
114117
);
115118
}
116119

117-
private long calculateRefreshFrom(long lastUpdated, long nextEffective) {
118-
long age = nextEffective - lastUpdated;
119-
long multiplier = age / THIRTY_DAYS_IN_MS + 1;
120-
return lastUpdated + (multiplier * THIRTY_DAYS_IN_MS);
120+
private long calculateRefreshFrom(SaltEntry salt, TargetDate targetDate) {
121+
long multiplier = targetDate.ageOfSaltInMs(salt) / THIRTY_DAYS_IN_MS + 1;
122+
return salt.lastUpdated() + (multiplier * THIRTY_DAYS_IN_MS);
121123
}
122124

123-
private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, long nextEffective) throws Exception {
125+
private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, TargetDate targetDate) {
124126
if (shouldRotate) {
125127
return salt.currentSalt();
126128
}
127-
long age = nextEffective - salt.lastUpdated();
128-
if (age / DAY_IN_MS < 90) {
129+
if (targetDate.ageOfSaltInDays(salt) < 90) {
129130
return salt.previousSalt();
130131
}
131132
return null;
132133
}
133134

134-
private List<Integer> pickSaltIndexesToRotate(
135-
Instant nextEffective,
135+
private List<SaltEntry> pickSaltsToRotate(
136+
Set<SaltEntry> rotatableSalts,
137+
TargetDate targetDate,
136138
Duration[] minAges,
137-
double fraction,
138-
SaltEntry[] preRotationSalts,
139-
List<Integer> rotatableSaltIndexes
139+
int numSaltsToRotate
140140
) {
141141
var thresholds = Arrays.stream(minAges)
142-
.map(age -> nextEffective.minusSeconds(age.getSeconds()))
142+
.map(minAge -> targetDate.asInstant().minusSeconds(minAge.getSeconds()))
143143
.sorted()
144144
.toArray(Instant[]::new);
145-
var maxSalts = (int) Math.ceil(preRotationSalts.length * fraction);
146-
var indexesToRotate = new ArrayList<Integer>();
145+
var indexesToRotate = new ArrayList<SaltEntry>();
147146

148147
var minLastUpdated = Instant.ofEpochMilli(0);
149148
for (var maxLastUpdated : thresholds) {
150-
if (indexesToRotate.size() >= maxSalts) break;
149+
if (indexesToRotate.size() >= numSaltsToRotate) break;
151150

152-
var maxIndexes = maxSalts - indexesToRotate.size();
153-
var saltsToRotate = selectIndexesToRotate(
154-
preRotationSalts,
155-
minLastUpdated.toEpochMilli(),
156-
maxLastUpdated.toEpochMilli(),
151+
var maxIndexes = numSaltsToRotate - indexesToRotate.size();
152+
var saltsToRotate = pickSaltsToRotateInTimeWindow(
153+
rotatableSalts,
157154
maxIndexes,
158-
rotatableSaltIndexes
155+
minLastUpdated.toEpochMilli(),
156+
maxLastUpdated.toEpochMilli()
159157
);
160158
indexesToRotate.addAll(saltsToRotate);
161159
minLastUpdated = maxLastUpdated;
162160
}
163161
return indexesToRotate;
164162
}
165163

166-
private List<Integer> selectIndexesToRotate(
167-
SaltEntry[] salts,
168-
long minLastUpdated,
169-
long maxLastUpdated,
164+
private List<SaltEntry> pickSaltsToRotateInTimeWindow(
165+
Set<SaltEntry> rotatableSalts,
170166
int maxIndexes,
171-
List<Integer> rotatableSaltIndexes
172-
) {
173-
var candidateIndexes = indexesForRotation(salts, minLastUpdated, maxLastUpdated, rotatableSaltIndexes);
174-
175-
if (candidateIndexes.size() <= maxIndexes) {
176-
return candidateIndexes;
177-
}
178-
Collections.shuffle(candidateIndexes);
179-
return candidateIndexes.subList(0, Math.min(maxIndexes, candidateIndexes.size()));
180-
}
181-
182-
private List<Integer> indexesForRotation(
183-
SaltEntry[] salts,
184167
long minLastUpdated,
185-
long maxLastUpdated,
186-
List<Integer> rotatableSaltIndexes
168+
long maxLastUpdated
187169
) {
188-
var candidateIndexes = new ArrayList<Integer>();
189-
for (int i = 0; i < salts.length; i++) {
190-
var salt = salts[i];
170+
var candidateSalts = new ArrayList<SaltEntry>();
171+
for (SaltEntry salt : rotatableSalts) {
191172
var lastUpdated = salt.lastUpdated();
192173
var isInTimeWindow = minLastUpdated <= lastUpdated && lastUpdated < maxLastUpdated;
193-
var isRotatable = rotatableSaltIndexes.contains(i);
194174

195-
if (isInTimeWindow && isRotatable) {
196-
candidateIndexes.add(i);
175+
if (isInTimeWindow) {
176+
candidateSalts.add(salt);
197177
}
198178
}
199-
return candidateIndexes;
200-
}
201179

202-
private boolean isRotatable(long nextEffective, SaltEntry salt) {
203-
if (this.isRefreshFromEnabled) {
204-
if (salt.refreshFrom() == null) { // TODO: remove once refreshFrom is no longer optional
205-
return true;
206-
}
207-
return salt.refreshFrom() == nextEffective;
180+
if (candidateSalts.size() <= maxIndexes) {
181+
return candidateSalts;
208182
}
209183

210-
return true;
184+
Collections.shuffle(candidateSalts);
185+
return candidateSalts.subList(0, Math.min(maxIndexes, candidateSalts.size()));
211186
}
212187

213-
private void logSaltAgeCounts(String logEvent, LocalDate targetDate, Instant nextEffective, SaltEntry[] salts) {
214-
var formattedDate = DateTimeFormatter.ofPattern("yyyy-MM-dd").format(targetDate);
215-
188+
private void logSaltAgeCounts(String saltCountType, TargetDate targetDate, Collection<SaltEntry> salts) {
216189
var ages = new HashMap<Long, Long>(); // salt age to count
217190
for (var salt : salts) {
218-
long age = (nextEffective.toEpochMilli() - salt.lastUpdated()) / DAY_IN_MS;
219-
ages.put(age, ages.getOrDefault(age, 0L) + 1);
191+
long ageInDays = targetDate.ageOfSaltInDays(salt);
192+
ages.put(ageInDays, ages.getOrDefault(ageInDays, 0L) + 1);
220193
}
221194

222195
for (var entry : ages.entrySet()) {
223-
LOGGER.info("{} target-date={} age={} salts={}", logEvent, formattedDate, entry.getKey(), entry.getValue());
196+
LOGGER.info("salt-count-type={} target-date={} age={} salt-count={}", saltCountType, targetDate, entry.getKey(), entry.getValue());
224197
}
225198
}
226199

227-
private static SaltEntry[] onlySaltsAtIndexes(SaltEntry[] salts, List<Integer> saltIndexes) {
228-
SaltEntry[] selected = new SaltEntry[saltIndexes.size()];
229-
for (int i = 0; i < saltIndexes.size(); i++) {
230-
selected[i] = salts[saltIndexes.get(i)];
200+
public static class TargetDate {
201+
private final static long DAY_IN_MS = Duration.ofDays(1).toMillis();
202+
203+
private final LocalDate date;
204+
private final long epochMs;
205+
private final Instant instant;
206+
private final String formatted;
207+
208+
public TargetDate(LocalDate date) {
209+
this.instant = date.atStartOfDay().toInstant(ZoneOffset.UTC);
210+
this.date = date;
211+
this.epochMs = instant.toEpochMilli();
212+
this.formatted = date.format(DateTimeFormatter.ofPattern("yyyy-MM-dd"));
231213
}
232-
return selected;
233-
}
234214

235-
private void logSaltCounts(LocalDate targetDate, Instant nextEffective, SaltEntry[] preRotationSalts, SaltEntry[] postRotationSalts, List<Integer> rotatableSaltIndexes, List<Integer> rotatedSaltIndexes) {
236-
var rotatableSalts = onlySaltsAtIndexes(preRotationSalts, rotatableSaltIndexes);
237-
logSaltAgeCounts("rotatable-salts", targetDate, nextEffective, rotatableSalts);
215+
public static TargetDate of(int year, int month, int day) {
216+
return new TargetDate(LocalDate.of(year, month, day));
217+
}
238218

239-
var rotatedSalts = onlySaltsAtIndexes(preRotationSalts, rotatedSaltIndexes);
240-
logSaltAgeCounts("rotated-salts", targetDate, nextEffective, rotatedSalts);
219+
public long asEpochMs() {
220+
return epochMs;
221+
}
241222

242-
logSaltAgeCounts("total-salts", targetDate, nextEffective, postRotationSalts);
223+
public Instant asInstant() {
224+
return instant;
225+
}
226+
227+
public long ageOfSaltInMs(SaltEntry salt) {
228+
return this.asEpochMs() - salt.lastUpdated();
229+
}
230+
231+
public long ageOfSaltInDays(SaltEntry salt) {
232+
return ageOfSaltInMs(salt) / DAY_IN_MS;
233+
}
234+
235+
public TargetDate plusDays(int days) {
236+
return new TargetDate(date.plusDays(days));
237+
}
238+
239+
public TargetDate minusDays(int days) {
240+
return new TargetDate(date.minusDays(days));
241+
}
242+
243+
@Override
244+
public String toString() {
245+
return formatted;
246+
}
243247
}
244248

245249
@Getter

src/main/java/com/uid2/admin/vertx/service/SaltService.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public SaltService(AdminAuthMiddleware auth,
4747
@Override
4848
public void setupRoutes(Router router) {
4949
router.get("/api/salt/snapshots").handler(
50-
auth.handle(this::handleSaltSnapshots, Role.MAINTAINER));
50+
auth.handle(this::handleSaltSnapshots, Role.MAINTAINER));
5151

5252
router.post("/api/salt/rotate").blockingHandler(auth.handle((ctx) -> {
5353
synchronized (writeLock) {
@@ -77,8 +77,10 @@ private void handleSaltRotate(RoutingContext rc) {
7777
if (!fraction.isPresent()) return;
7878
final Duration[] minAges = RequestUtil.getDurations(rc, "min_ages_in_seconds");
7979
if (minAges == null) return;
80-
final LocalDate targetDate = RequestUtil.getDate(rc, "target_date", DateTimeFormatter.ISO_LOCAL_DATE)
81-
.orElse(LocalDate.now(Clock.systemUTC()).plusDays(1));
80+
final SaltRotation.TargetDate targetDate = new SaltRotation.TargetDate(
81+
RequestUtil.getDate(rc, "target_date", DateTimeFormatter.ISO_LOCAL_DATE)
82+
.orElse(LocalDate.now(Clock.systemUTC()).plusDays(1))
83+
);
8284

8385
// force refresh
8486
this.saltProvider.loadContent();
@@ -87,7 +89,7 @@ private void handleSaltRotate(RoutingContext rc) {
8789
storageManager.archiveSaltLocations();
8890

8991
final List<RotatingSaltProvider.SaltSnapshot> snapshots = this.saltProvider.getSnapshots();
90-
final RotatingSaltProvider.SaltSnapshot lastSnapshot = snapshots.get(snapshots.size() - 1);
92+
final RotatingSaltProvider.SaltSnapshot lastSnapshot = snapshots.getLast();
9193

9294
final SaltRotation.Result result = saltRotation.rotateSalts(
9395
lastSnapshot, minAges, fraction.get(), targetDate);

0 commit comments

Comments
 (0)