Skip to content

Commit d8e1775

Browse files
committed
Logging salt ages on rotation
1 parent eb5d42f commit d8e1775

File tree

2 files changed

+398
-167
lines changed

2 files changed

+398
-167
lines changed

src/main/java/com/uid2/admin/secret/SaltRotation.java

Lines changed: 149 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,83 @@
33
import com.uid2.admin.AdminConst;
44
import com.uid2.shared.model.SaltEntry;
55
import com.uid2.shared.secret.IKeyGenerator;
6-
import com.uid2.shared.store.salt.RotatingSaltProvider;
76

87
import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot;
98
import io.vertx.core.json.JsonObject;
9+
import lombok.Getter;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
1012

1113
import java.time.Duration;
1214
import java.time.Instant;
1315
import java.time.LocalDate;
1416
import java.time.ZoneOffset;
1517
import java.time.temporal.ChronoUnit;
1618
import java.util.*;
17-
import java.util.stream.IntStream;
18-
19-
import static java.util.stream.Collectors.toList;
2019

2120
public class SaltRotation {
2221
private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis();
2322
private final static long DAY_IN_MS = Duration.ofDays(1).toMillis();
2423

2524
private final IKeyGenerator keyGenerator;
2625
private final boolean isRefreshFromEnabled;
26+
private static final Logger LOGGER = LoggerFactory.getLogger(SaltRotation.class);
2727

2828
public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) {
2929
this.keyGenerator = keyGenerator;
3030
this.isRefreshFromEnabled = config.getBoolean(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, false);
3131
}
3232

33-
public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
34-
Duration[] minAges,
35-
double fraction,
36-
LocalDate targetDate) throws Exception {
37-
final Instant nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC);
38-
final Instant nextExpires = nextEffective.plus(7, ChronoUnit.DAYS);
33+
public Result rotateSalts(
34+
SaltSnapshot lastSnapshot,
35+
Duration[] minAges,
36+
double fraction,
37+
LocalDate targetDate
38+
) throws Exception {
39+
var preRotationSalts = lastSnapshot.getAllRotatingSalts();
40+
var nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC);
41+
var nextExpires = nextEffective.plus(7, ChronoUnit.DAYS);
3942
if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) {
4043
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot");
4144
}
4245

43-
List<Integer> saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction);
46+
var rotatableSaltIndexes = findRotatableSaltIndexes(preRotationSalts, nextEffective.toEpochMilli());
47+
var saltIndexesToRotate = pickSaltIndexesToRotate(
48+
nextEffective,
49+
minAges,
50+
fraction,
51+
preRotationSalts,
52+
rotatableSaltIndexes
53+
);
54+
4455
if (saltIndexesToRotate.isEmpty()) {
45-
return Result.noSnapshot("all salts are below min rotation age");
56+
return Result.noSnapshot("all rotatable salts are below min rotation age");
4657
}
4758

48-
var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli());
59+
var postRotationSalts = updateSalts(preRotationSalts, saltIndexesToRotate, nextEffective.toEpochMilli());
60+
61+
new SaltAgeCounter("rotatable-salts", nextEffective).countIndexes(preRotationSalts, rotatableSaltIndexes);
62+
new SaltAgeCounter("rotated-salts", nextEffective).countIndexes(preRotationSalts, saltIndexesToRotate);
63+
new SaltAgeCounter("total-salts", nextEffective).count(postRotationSalts);
4964

50-
SaltSnapshot nextSnapshot = new SaltSnapshot(
65+
var nextSnapshot = new SaltSnapshot(
5166
nextEffective,
5267
nextExpires,
53-
updatedSalts,
68+
postRotationSalts,
5469
lastSnapshot.getFirstLevelSalt());
5570
return Result.fromSnapshot(nextSnapshot);
5671
}
5772

73+
private List<Integer> findRotatableSaltIndexes(SaltEntry[] preRotationSalts, long nextEffective) {
74+
var rotatableSalts = new ArrayList<Integer>();
75+
for (int i = 0; i < preRotationSalts.length; i++) {
76+
if (isRotatable(nextEffective, preRotationSalts[i])) {
77+
rotatableSalts.add(i);
78+
}
79+
}
80+
return rotatableSalts;
81+
}
82+
5883
private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List<Integer> saltIndexesToRotate, long nextEffective) throws Exception {
5984
var updatedSalts = new SaltEntry[oldSalts.length];
6085

@@ -94,88 +119,155 @@ private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, long
94119
return salt.currentSalt();
95120
}
96121
long age = nextEffective - salt.lastUpdated();
97-
if ( age / DAY_IN_MS < 90) {
122+
if (age / DAY_IN_MS < 90) {
98123
return salt.previousSalt();
99124
}
100125
return null;
101126
}
102127

103128
private List<Integer> pickSaltIndexesToRotate(
104-
SaltSnapshot lastSnapshot,
105129
Instant nextEffective,
106130
Duration[] minAges,
107-
double fraction) {
108-
final Instant[] thresholds = Arrays.stream(minAges)
131+
double fraction,
132+
SaltEntry[] preRotationSalts,
133+
List<Integer> rotatableSaltIndexes
134+
) {
135+
var thresholds = Arrays.stream(minAges)
109136
.map(age -> nextEffective.minusSeconds(age.getSeconds()))
110137
.sorted()
111138
.toArray(Instant[]::new);
112-
final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction);
113-
final SaltEntry[] rotatableSalts = getRotatableSalts(lastSnapshot, nextEffective.toEpochMilli());
114-
final List<Integer> indexesToRotate = new ArrayList<>();
139+
var maxSalts = (int) Math.ceil(preRotationSalts.length * fraction);
140+
var indexesToRotate = new ArrayList<Integer>();
115141

116-
Instant minLastUpdated = Instant.ofEpochMilli(0);
117-
for (Instant threshold : thresholds) {
142+
var minLastUpdated = Instant.ofEpochMilli(0);
143+
for (var maxLastUpdated : thresholds) {
118144
if (indexesToRotate.size() >= maxSalts) break;
119-
addIndexesToRotate(
120-
indexesToRotate,
121-
rotatableSalts,
145+
146+
var maxIndexes = maxSalts - indexesToRotate.size();
147+
var saltsToRotate = selectIndexesToRotate(
148+
preRotationSalts,
122149
minLastUpdated.toEpochMilli(),
123-
threshold.toEpochMilli(),
124-
maxSalts - indexesToRotate.size()
150+
maxLastUpdated.toEpochMilli(),
151+
maxIndexes,
152+
rotatableSaltIndexes
125153
);
126-
minLastUpdated = threshold;
154+
indexesToRotate.addAll(saltsToRotate);
155+
minLastUpdated = maxLastUpdated;
127156
}
128157
return indexesToRotate;
129158
}
130159

131-
private SaltEntry[] getRotatableSalts(SaltSnapshot lastSnapshot, long nextEffective) {
132-
SaltEntry[] salts = lastSnapshot.getAllRotatingSalts();
133-
if (isRefreshFromEnabled) {
134-
return Arrays.stream(salts).filter(s -> s.refreshFrom() == nextEffective).toArray(SaltEntry[]::new);
160+
private List<Integer> selectIndexesToRotate(
161+
SaltEntry[] salts,
162+
long minLastUpdated,
163+
long maxLastUpdated,
164+
int maxIndexes,
165+
List<Integer> rotatableSaltIndexes
166+
) {
167+
var candidateIndexes = indexesForRotation(salts, minLastUpdated, maxLastUpdated, rotatableSaltIndexes);
168+
169+
if (candidateIndexes.size() <= maxIndexes) {
170+
return candidateIndexes;
135171
}
136-
return salts;
172+
Collections.shuffle(candidateIndexes);
173+
return candidateIndexes.subList(0, Math.min(maxIndexes, candidateIndexes.size()));
137174
}
138175

176+
private List<Integer> indexesForRotation(
177+
SaltEntry[] salts,
178+
long minLastUpdated,
179+
long maxLastUpdated,
180+
List<Integer> rotatableSaltIndexes
181+
) {
182+
var candidateIndexes = new ArrayList<Integer>();
183+
for (int i = 0; i < salts.length; i++) {
184+
var salt = salts[i];
185+
var lastUpdated = salt.lastUpdated();
186+
var isInTimeWindow = minLastUpdated <= lastUpdated && lastUpdated < maxLastUpdated;
187+
var isRotatable = rotatableSaltIndexes.contains(i);
139188

140-
private void addIndexesToRotate(List<Integer> entryIndexes,
141-
SaltEntry[] entries,
142-
long minLastUpdated,
143-
long maxLastUpdated,
144-
int maxIndexes) {
145-
final List<Integer> candidateIndexes = IntStream.range(0, entries.length)
146-
.filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated))
147-
.boxed()
148-
.collect(toList());
149-
if (candidateIndexes.size() <= maxIndexes) {
150-
entryIndexes.addAll(candidateIndexes);
151-
return;
189+
if (isInTimeWindow && isRotatable) {
190+
candidateIndexes.add(i);
191+
}
152192
}
153-
Collections.shuffle(candidateIndexes);
154-
candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add);
193+
return candidateIndexes;
155194
}
156195

157-
private static boolean isBetween(long t, long minInclusive, long maxExclusive) {
158-
return minInclusive <= t && t < maxExclusive;
196+
private boolean isRotatable(long nextEffective, SaltEntry salt) {
197+
if (this.isRefreshFromEnabled) {
198+
if (salt.refreshFrom() == null) { // TODO: remove once refreshFrom is no longer optional
199+
return true;
200+
}
201+
return salt.refreshFrom() == nextEffective;
202+
}
203+
204+
return true;
159205
}
160206

207+
@Getter
161208
public static class Result {
162-
private final RotatingSaltProvider.SaltSnapshot snapshot; // can be null if new snapshot is not needed
209+
private final SaltSnapshot snapshot; // can be null if new snapshot is not needed
163210
private final String reason; // why you are not getting a new snapshot
164211

165-
private Result(RotatingSaltProvider.SaltSnapshot snapshot, String reason) {
212+
private Result(SaltSnapshot snapshot, String reason) {
166213
this.snapshot = snapshot;
167214
this.reason = reason;
168215
}
169216

170-
public boolean hasSnapshot() { return snapshot != null; }
171-
public RotatingSaltProvider.SaltSnapshot getSnapshot() { return snapshot; }
172-
public String getReason() { return reason; }
217+
public boolean hasSnapshot() {
218+
return snapshot != null;
219+
}
173220

174-
public static Result fromSnapshot(RotatingSaltProvider.SaltSnapshot snapshot) {
221+
public static Result fromSnapshot(SaltSnapshot snapshot) {
175222
return new Result(snapshot, null);
176223
}
224+
177225
public static Result noSnapshot(String reason) {
178226
return new Result(null, reason);
179227
}
180228
}
229+
230+
private static class SaltAgeCounter {
231+
private final String logEvent;
232+
private final long nextEffective;
233+
private final HashMap<Long, Long> ages = new HashMap<>(); // salt age to count
234+
235+
public SaltAgeCounter(String logEvent, Instant nextEffective) {
236+
this.logEvent = logEvent;
237+
this.nextEffective = nextEffective.toEpochMilli();
238+
}
239+
240+
public void count(SaltEntry[] salts) {
241+
try {
242+
for (var salt : salts) {
243+
count(salt);
244+
}
245+
logCounts();
246+
} catch (Exception e) {
247+
LOGGER.error("Error counting salts for {}", logEvent, e);
248+
}
249+
}
250+
251+
public void countIndexes(SaltEntry[] salts, List<Integer> saltIndexes) {
252+
try {
253+
for (var index : saltIndexes) {
254+
count(salts[index]);
255+
}
256+
logCounts();
257+
} catch (Exception e) {
258+
LOGGER.error("Error counting salts for {}", logEvent, e);
259+
}
260+
}
261+
262+
private void count(SaltEntry salt) {
263+
long age = (nextEffective - salt.lastUpdated()) / DAY_IN_MS;
264+
ages.put(age, ages.getOrDefault(age, 0L) + 1);
265+
}
266+
267+
private void logCounts() {
268+
for (var entry : ages.entrySet()) {
269+
LOGGER.info("{} age={} salts={}", logEvent, entry.getKey(), entry.getValue());
270+
}
271+
}
272+
}
181273
}

0 commit comments

Comments
 (0)