|
3 | 3 | import com.uid2.admin.AdminConst; |
4 | 4 | import com.uid2.shared.model.SaltEntry; |
5 | 5 | import com.uid2.shared.secret.IKeyGenerator; |
6 | | -import com.uid2.shared.store.salt.RotatingSaltProvider; |
7 | 6 |
|
8 | 7 | import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot; |
9 | 8 | import io.vertx.core.json.JsonObject; |
| 9 | +import lombok.Getter; |
| 10 | +import org.slf4j.Logger; |
| 11 | +import org.slf4j.LoggerFactory; |
10 | 12 |
|
11 | 13 | import java.time.Duration; |
12 | 14 | import java.time.Instant; |
13 | 15 | import java.time.LocalDate; |
14 | 16 | import java.time.ZoneOffset; |
15 | 17 | import java.time.temporal.ChronoUnit; |
16 | 18 | import java.util.*; |
17 | | -import java.util.stream.IntStream; |
18 | | - |
19 | | -import static java.util.stream.Collectors.toList; |
20 | 19 |
|
21 | 20 | public class SaltRotation { |
22 | 21 | private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis(); |
23 | 22 | private final static long DAY_IN_MS = Duration.ofDays(1).toMillis(); |
24 | 23 |
|
25 | 24 | private final IKeyGenerator keyGenerator; |
26 | 25 | private final boolean isRefreshFromEnabled; |
| 26 | + private static final Logger LOGGER = LoggerFactory.getLogger(SaltRotation.class); |
27 | 27 |
|
28 | 28 | public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) { |
29 | 29 | this.keyGenerator = keyGenerator; |
30 | 30 | this.isRefreshFromEnabled = config.getBoolean(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, false); |
31 | 31 | } |
32 | 32 |
|
33 | | - public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot, |
34 | | - Duration[] minAges, |
35 | | - double fraction, |
36 | | - LocalDate targetDate) throws Exception { |
37 | | - final Instant nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); |
38 | | - final Instant nextExpires = nextEffective.plus(7, ChronoUnit.DAYS); |
| 33 | + public Result rotateSalts( |
| 34 | + SaltSnapshot lastSnapshot, |
| 35 | + Duration[] minAges, |
| 36 | + double fraction, |
| 37 | + LocalDate targetDate |
| 38 | + ) throws Exception { |
| 39 | + var preRotationSalts = lastSnapshot.getAllRotatingSalts(); |
| 40 | + var nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); |
| 41 | + var nextExpires = nextEffective.plus(7, ChronoUnit.DAYS); |
39 | 42 | if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) { |
40 | 43 | return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot"); |
41 | 44 | } |
42 | 45 |
|
43 | | - List<Integer> saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction); |
| 46 | + var rotatableSaltIndexes = findRotatableSaltIndexes(preRotationSalts, nextEffective.toEpochMilli()); |
| 47 | + var saltIndexesToRotate = pickSaltIndexesToRotate( |
| 48 | + nextEffective, |
| 49 | + minAges, |
| 50 | + fraction, |
| 51 | + preRotationSalts, |
| 52 | + rotatableSaltIndexes |
| 53 | + ); |
| 54 | + |
44 | 55 | if (saltIndexesToRotate.isEmpty()) { |
45 | | - return Result.noSnapshot("all salts are below min rotation age"); |
| 56 | + return Result.noSnapshot("all rotatable salts are below min rotation age"); |
46 | 57 | } |
47 | 58 |
|
48 | | - var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli()); |
| 59 | + var postRotationSalts = updateSalts(preRotationSalts, saltIndexesToRotate, nextEffective.toEpochMilli()); |
| 60 | + |
| 61 | + new SaltAgeCounter("rotatable-salts", nextEffective).countIndexes(preRotationSalts, rotatableSaltIndexes); |
| 62 | + new SaltAgeCounter("rotated-salts", nextEffective).countIndexes(preRotationSalts, saltIndexesToRotate); |
| 63 | + new SaltAgeCounter("total-salts", nextEffective).count(postRotationSalts); |
49 | 64 |
|
50 | | - SaltSnapshot nextSnapshot = new SaltSnapshot( |
| 65 | + var nextSnapshot = new SaltSnapshot( |
51 | 66 | nextEffective, |
52 | 67 | nextExpires, |
53 | | - updatedSalts, |
| 68 | + postRotationSalts, |
54 | 69 | lastSnapshot.getFirstLevelSalt()); |
55 | 70 | return Result.fromSnapshot(nextSnapshot); |
56 | 71 | } |
57 | 72 |
|
| 73 | + private List<Integer> findRotatableSaltIndexes(SaltEntry[] preRotationSalts, long nextEffective) { |
| 74 | + var rotatableSalts = new ArrayList<Integer>(); |
| 75 | + for (int i = 0; i < preRotationSalts.length; i++) { |
| 76 | + if (isRotatable(nextEffective, preRotationSalts[i])) { |
| 77 | + rotatableSalts.add(i); |
| 78 | + } |
| 79 | + } |
| 80 | + return rotatableSalts; |
| 81 | + } |
| 82 | + |
58 | 83 | private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List<Integer> saltIndexesToRotate, long nextEffective) throws Exception { |
59 | 84 | var updatedSalts = new SaltEntry[oldSalts.length]; |
60 | 85 |
|
@@ -94,88 +119,155 @@ private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, long |
94 | 119 | return salt.currentSalt(); |
95 | 120 | } |
96 | 121 | long age = nextEffective - salt.lastUpdated(); |
97 | | - if ( age / DAY_IN_MS < 90) { |
| 122 | + if (age / DAY_IN_MS < 90) { |
98 | 123 | return salt.previousSalt(); |
99 | 124 | } |
100 | 125 | return null; |
101 | 126 | } |
102 | 127 |
|
103 | 128 | private List<Integer> pickSaltIndexesToRotate( |
104 | | - SaltSnapshot lastSnapshot, |
105 | 129 | Instant nextEffective, |
106 | 130 | Duration[] minAges, |
107 | | - double fraction) { |
108 | | - final Instant[] thresholds = Arrays.stream(minAges) |
| 131 | + double fraction, |
| 132 | + SaltEntry[] preRotationSalts, |
| 133 | + List<Integer> rotatableSaltIndexes |
| 134 | + ) { |
| 135 | + var thresholds = Arrays.stream(minAges) |
109 | 136 | .map(age -> nextEffective.minusSeconds(age.getSeconds())) |
110 | 137 | .sorted() |
111 | 138 | .toArray(Instant[]::new); |
112 | | - final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction); |
113 | | - final SaltEntry[] rotatableSalts = getRotatableSalts(lastSnapshot, nextEffective.toEpochMilli()); |
114 | | - final List<Integer> indexesToRotate = new ArrayList<>(); |
| 139 | + var maxSalts = (int) Math.ceil(preRotationSalts.length * fraction); |
| 140 | + var indexesToRotate = new ArrayList<Integer>(); |
115 | 141 |
|
116 | | - Instant minLastUpdated = Instant.ofEpochMilli(0); |
117 | | - for (Instant threshold : thresholds) { |
| 142 | + var minLastUpdated = Instant.ofEpochMilli(0); |
| 143 | + for (var maxLastUpdated : thresholds) { |
118 | 144 | if (indexesToRotate.size() >= maxSalts) break; |
119 | | - addIndexesToRotate( |
120 | | - indexesToRotate, |
121 | | - rotatableSalts, |
| 145 | + |
| 146 | + var maxIndexes = maxSalts - indexesToRotate.size(); |
| 147 | + var saltsToRotate = selectIndexesToRotate( |
| 148 | + preRotationSalts, |
122 | 149 | minLastUpdated.toEpochMilli(), |
123 | | - threshold.toEpochMilli(), |
124 | | - maxSalts - indexesToRotate.size() |
| 150 | + maxLastUpdated.toEpochMilli(), |
| 151 | + maxIndexes, |
| 152 | + rotatableSaltIndexes |
125 | 153 | ); |
126 | | - minLastUpdated = threshold; |
| 154 | + indexesToRotate.addAll(saltsToRotate); |
| 155 | + minLastUpdated = maxLastUpdated; |
127 | 156 | } |
128 | 157 | return indexesToRotate; |
129 | 158 | } |
130 | 159 |
|
131 | | - private SaltEntry[] getRotatableSalts(SaltSnapshot lastSnapshot, long nextEffective) { |
132 | | - SaltEntry[] salts = lastSnapshot.getAllRotatingSalts(); |
133 | | - if (isRefreshFromEnabled) { |
134 | | - return Arrays.stream(salts).filter(s -> s.refreshFrom() == nextEffective).toArray(SaltEntry[]::new); |
| 160 | + private List<Integer> selectIndexesToRotate( |
| 161 | + SaltEntry[] salts, |
| 162 | + long minLastUpdated, |
| 163 | + long maxLastUpdated, |
| 164 | + int maxIndexes, |
| 165 | + List<Integer> rotatableSaltIndexes |
| 166 | + ) { |
| 167 | + var candidateIndexes = indexesForRotation(salts, minLastUpdated, maxLastUpdated, rotatableSaltIndexes); |
| 168 | + |
| 169 | + if (candidateIndexes.size() <= maxIndexes) { |
| 170 | + return candidateIndexes; |
135 | 171 | } |
136 | | - return salts; |
| 172 | + Collections.shuffle(candidateIndexes); |
| 173 | + return candidateIndexes.subList(0, Math.min(maxIndexes, candidateIndexes.size())); |
137 | 174 | } |
138 | 175 |
|
| 176 | + private List<Integer> indexesForRotation( |
| 177 | + SaltEntry[] salts, |
| 178 | + long minLastUpdated, |
| 179 | + long maxLastUpdated, |
| 180 | + List<Integer> rotatableSaltIndexes |
| 181 | + ) { |
| 182 | + var candidateIndexes = new ArrayList<Integer>(); |
| 183 | + for (int i = 0; i < salts.length; i++) { |
| 184 | + var salt = salts[i]; |
| 185 | + var lastUpdated = salt.lastUpdated(); |
| 186 | + var isInTimeWindow = minLastUpdated <= lastUpdated && lastUpdated < maxLastUpdated; |
| 187 | + var isRotatable = rotatableSaltIndexes.contains(i); |
139 | 188 |
|
140 | | - private void addIndexesToRotate(List<Integer> entryIndexes, |
141 | | - SaltEntry[] entries, |
142 | | - long minLastUpdated, |
143 | | - long maxLastUpdated, |
144 | | - int maxIndexes) { |
145 | | - final List<Integer> candidateIndexes = IntStream.range(0, entries.length) |
146 | | - .filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated)) |
147 | | - .boxed() |
148 | | - .collect(toList()); |
149 | | - if (candidateIndexes.size() <= maxIndexes) { |
150 | | - entryIndexes.addAll(candidateIndexes); |
151 | | - return; |
| 189 | + if (isInTimeWindow && isRotatable) { |
| 190 | + candidateIndexes.add(i); |
| 191 | + } |
152 | 192 | } |
153 | | - Collections.shuffle(candidateIndexes); |
154 | | - candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add); |
| 193 | + return candidateIndexes; |
155 | 194 | } |
156 | 195 |
|
157 | | - private static boolean isBetween(long t, long minInclusive, long maxExclusive) { |
158 | | - return minInclusive <= t && t < maxExclusive; |
| 196 | + private boolean isRotatable(long nextEffective, SaltEntry salt) { |
| 197 | + if (this.isRefreshFromEnabled) { |
| 198 | + if (salt.refreshFrom() == null) { // TODO: remove once refreshFrom is no longer optional |
| 199 | + return true; |
| 200 | + } |
| 201 | + return salt.refreshFrom() == nextEffective; |
| 202 | + } |
| 203 | + |
| 204 | + return true; |
159 | 205 | } |
160 | 206 |
|
| 207 | + @Getter |
161 | 208 | public static class Result { |
162 | | - private final RotatingSaltProvider.SaltSnapshot snapshot; // can be null if new snapshot is not needed |
| 209 | + private final SaltSnapshot snapshot; // can be null if new snapshot is not needed |
163 | 210 | private final String reason; // why you are not getting a new snapshot |
164 | 211 |
|
165 | | - private Result(RotatingSaltProvider.SaltSnapshot snapshot, String reason) { |
| 212 | + private Result(SaltSnapshot snapshot, String reason) { |
166 | 213 | this.snapshot = snapshot; |
167 | 214 | this.reason = reason; |
168 | 215 | } |
169 | 216 |
|
170 | | - public boolean hasSnapshot() { return snapshot != null; } |
171 | | - public RotatingSaltProvider.SaltSnapshot getSnapshot() { return snapshot; } |
172 | | - public String getReason() { return reason; } |
| 217 | + public boolean hasSnapshot() { |
| 218 | + return snapshot != null; |
| 219 | + } |
173 | 220 |
|
174 | | - public static Result fromSnapshot(RotatingSaltProvider.SaltSnapshot snapshot) { |
| 221 | + public static Result fromSnapshot(SaltSnapshot snapshot) { |
175 | 222 | return new Result(snapshot, null); |
176 | 223 | } |
| 224 | + |
177 | 225 | public static Result noSnapshot(String reason) { |
178 | 226 | return new Result(null, reason); |
179 | 227 | } |
180 | 228 | } |
| 229 | + |
| 230 | + private static class SaltAgeCounter { |
| 231 | + private final String logEvent; |
| 232 | + private final long nextEffective; |
| 233 | + private final HashMap<Long, Long> ages = new HashMap<>(); // salt age to count |
| 234 | + |
| 235 | + public SaltAgeCounter(String logEvent, Instant nextEffective) { |
| 236 | + this.logEvent = logEvent; |
| 237 | + this.nextEffective = nextEffective.toEpochMilli(); |
| 238 | + } |
| 239 | + |
| 240 | + public void count(SaltEntry[] salts) { |
| 241 | + try { |
| 242 | + for (var salt : salts) { |
| 243 | + count(salt); |
| 244 | + } |
| 245 | + logCounts(); |
| 246 | + } catch (Exception e) { |
| 247 | + LOGGER.error("Error counting salts for {}", logEvent, e); |
| 248 | + } |
| 249 | + } |
| 250 | + |
| 251 | + public void countIndexes(SaltEntry[] salts, List<Integer> saltIndexes) { |
| 252 | + try { |
| 253 | + for (var index : saltIndexes) { |
| 254 | + count(salts[index]); |
| 255 | + } |
| 256 | + logCounts(); |
| 257 | + } catch (Exception e) { |
| 258 | + LOGGER.error("Error counting salts for {}", logEvent, e); |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + private void count(SaltEntry salt) { |
| 263 | + long age = (nextEffective - salt.lastUpdated()) / DAY_IN_MS; |
| 264 | + ages.put(age, ages.getOrDefault(age, 0L) + 1); |
| 265 | + } |
| 266 | + |
| 267 | + private void logCounts() { |
| 268 | + for (var entry : ages.entrySet()) { |
| 269 | + LOGGER.info("{} age={} salts={}", logEvent, entry.getKey(), entry.getValue()); |
| 270 | + } |
| 271 | + } |
| 272 | + } |
181 | 273 | } |
0 commit comments