diff --git a/src/main/java/com/uid2/admin/secret/SaltRotation.java b/src/main/java/com/uid2/admin/secret/SaltRotation.java index eca371c2..35c3b96e 100644 --- a/src/main/java/com/uid2/admin/secret/SaltRotation.java +++ b/src/main/java/com/uid2/admin/secret/SaltRotation.java @@ -4,7 +4,12 @@ import com.uid2.shared.secret.IKeyGenerator; import com.uid2.shared.store.salt.RotatingSaltProvider; -import java.time.*; +import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot; + +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; import java.util.*; import java.util.stream.IntStream; @@ -30,62 +35,94 @@ public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot, return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot"); } + List saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction); + if (saltIndexesToRotate.isEmpty()) { + return Result.noSnapshot("all salts are below min rotation age"); + } + + var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli()); + + SaltSnapshot nextSnapshot = new SaltSnapshot( + nextEffective, + nextExpires, + updatedSalts, + lastSnapshot.getFirstLevelSalt()); + return Result.fromSnapshot(nextSnapshot); + } + + private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List saltIndexesToRotate, long nextEffective) throws Exception { + var updatedSalts = new SaltEntry[oldSalts.length]; + + for (int i = 0; i < oldSalts.length; i++) { + var shouldRotate = saltIndexesToRotate.contains(i); + updatedSalts[i] = updateSalt(oldSalts[i], shouldRotate, nextEffective); + } + return updatedSalts; + } + + private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextEffective) throws Exception { + var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt(); + var lastUpdated = shouldRotate ? nextEffective : oldSalt.lastUpdated(); + + return new SaltEntry( + oldSalt.id(), + oldSalt.hashedId(), + lastUpdated, + currentSalt, + null, + null, + null, + null + ); + } + + private List pickSaltIndexesToRotate( + SaltSnapshot lastSnapshot, + Instant nextEffective, + Duration[] minAges, + double fraction) { final Instant[] thresholds = Arrays.stream(minAges) - .map(a -> nextEffective.minusSeconds(a.getSeconds())) + .map(age -> nextEffective.minusSeconds(age.getSeconds())) .sorted() .toArray(Instant[]::new); - final int maxSalts = (int)Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction); - final List entryIndexes = new ArrayList<>(); + final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction); + final List indexesToRotate = new ArrayList<>(); Instant minLastUpdated = Instant.ofEpochMilli(0); for (Instant threshold : thresholds) { - if (entryIndexes.size() >= maxSalts) break; - addIndexesToRotate(entryIndexes, lastSnapshot, - minLastUpdated.toEpochMilli(), threshold.toEpochMilli(), - maxSalts - entryIndexes.size()); + if (indexesToRotate.size() >= maxSalts) break; + addIndexesToRotate( + indexesToRotate, + lastSnapshot, + minLastUpdated.toEpochMilli(), + threshold.toEpochMilli(), + maxSalts - indexesToRotate.size() + ); minLastUpdated = threshold; } - - if (entryIndexes.isEmpty()) return Result.noSnapshot("all salts are below min rotation age"); - - return Result.fromSnapshot(createRotatedSnapshot(lastSnapshot, nextEffective, nextExpires, entryIndexes)); + return indexesToRotate; } private void addIndexesToRotate(List entryIndexes, - RotatingSaltProvider.SaltSnapshot lastSnapshot, + SaltSnapshot lastSnapshot, long minLastUpdated, long maxLastUpdated, int maxIndexes) { final SaltEntry[] entries = lastSnapshot.getAllRotatingSalts(); final List candidateIndexes = IntStream.range(0, entries.length) .filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated)) - .boxed().collect(toList()); + .boxed() + .collect(toList()); if (candidateIndexes.size() <= maxIndexes) { entryIndexes.addAll(candidateIndexes); return; } Collections.shuffle(candidateIndexes); - candidateIndexes.stream().limit(maxIndexes).forEachOrdered(i -> entryIndexes.add(i)); + candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add); } private static boolean isBetween(long t, long minInclusive, long maxExclusive) { return minInclusive <= t && t < maxExclusive; } - private RotatingSaltProvider.SaltSnapshot createRotatedSnapshot(RotatingSaltProvider.SaltSnapshot lastSnapshot, - Instant nextEffective, - Instant nextExpires, - List entryIndexes) throws Exception { - final long lastUpdated = nextEffective.toEpochMilli(); - final RotatingSaltProvider.SaltSnapshot nextSnapshot = new RotatingSaltProvider.SaltSnapshot( - nextEffective, nextExpires, - Arrays.copyOf(lastSnapshot.getAllRotatingSalts(), lastSnapshot.getAllRotatingSalts().length), - lastSnapshot.getFirstLevelSalt()); - for (Integer i : entryIndexes) { - final SaltEntry oldSalt = nextSnapshot.getAllRotatingSalts()[i]; - final String secret = this.keyGenerator.generateRandomKeyString(32); - nextSnapshot.getAllRotatingSalts()[i] = new SaltEntry(oldSalt.id(), oldSalt.hashedId(), lastUpdated, secret, null, null, null, null); - } - return nextSnapshot; - } } diff --git a/src/main/java/com/uid2/admin/vertx/service/SaltService.java b/src/main/java/com/uid2/admin/vertx/service/SaltService.java index 139b9994..086bbc26 100644 --- a/src/main/java/com/uid2/admin/vertx/service/SaltService.java +++ b/src/main/java/com/uid2/admin/vertx/service/SaltService.java @@ -19,7 +19,6 @@ import java.time.*; import java.time.format.DateTimeFormatter; -import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.List; import java.util.Optional; diff --git a/src/test/java/com/uid2/admin/secret/SaltRotationTest.java b/src/test/java/com/uid2/admin/secret/SaltRotationTest.java index 9c74e984..2bcf4977 100644 --- a/src/test/java/com/uid2/admin/secret/SaltRotationTest.java +++ b/src/test/java/com/uid2/admin/secret/SaltRotationTest.java @@ -19,8 +19,6 @@ import static org.mockito.Mockito.*; public class SaltRotationTest { - private AutoCloseable mocks; - @Mock private IKeyGenerator keyGenerator; private SaltRotation saltRotation; @@ -28,8 +26,8 @@ public class SaltRotationTest { private final Instant targetDateAsInstant = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); @BeforeEach - void setup() throws Exception { - mocks = MockitoAnnotations.openMocks(this); + void setup() { + MockitoAnnotations.openMocks(this); saltRotation = new SaltRotation(keyGenerator); } @@ -51,7 +49,7 @@ public SnapshotBuilder withEntries(int count, Instant lastUpdated) { public RotatingSaltProvider.SaltSnapshot build(Instant effective, Instant expires) { return new RotatingSaltProvider.SaltSnapshot( - effective, expires, entries.stream().toArray(SaltEntry[]::new), "test_first_level_salt"); + effective, expires, entries.toArray(SaltEntry[]::new), "test_first_level_salt"); } }