diff --git a/src/main/java/com/uid2/admin/Main.java b/src/main/java/com/uid2/admin/Main.java index 2544f889..2e320a01 100644 --- a/src/main/java/com/uid2/admin/Main.java +++ b/src/main/java/com/uid2/admin/Main.java @@ -14,6 +14,7 @@ import com.uid2.admin.legacy.RotatingLegacyClientKeyProvider; import com.uid2.admin.managers.KeysetManager; import com.uid2.admin.monitoring.DataStoreMetrics; +import com.uid2.admin.salt.SaltRotation; import com.uid2.admin.secret.*; import com.uid2.admin.store.*; import com.uid2.admin.store.reader.RotatingAdminKeysetStore; diff --git a/src/main/java/com/uid2/admin/salt/SaltRotation.java b/src/main/java/com/uid2/admin/salt/SaltRotation.java new file mode 100644 index 00000000..aa8945d2 --- /dev/null +++ b/src/main/java/com/uid2/admin/salt/SaltRotation.java @@ -0,0 +1,218 @@ +package com.uid2.admin.salt; + +import com.uid2.admin.AdminConst; +import com.uid2.shared.model.SaltEntry; +import com.uid2.shared.secret.IKeyGenerator; + +import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot; +import io.vertx.core.json.JsonObject; +import lombok.Getter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.*; +import java.time.temporal.ChronoUnit; +import java.util.*; +import java.util.stream.Collectors; + +public class SaltRotation { + private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis(); + + private final IKeyGenerator keyGenerator; + private final boolean isRefreshFromEnabled; + private static final Logger LOGGER = LoggerFactory.getLogger(SaltRotation.class); + + public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) { + this.keyGenerator = keyGenerator; + this.isRefreshFromEnabled = config.getBoolean(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, false); + } + + public Result rotateSalts( + SaltSnapshot lastSnapshot, + Duration[] minAges, + double fraction, + TargetDate targetDate + ) throws Exception { + var preRotationSalts = lastSnapshot.getAllRotatingSalts(); + var nextEffective = targetDate.asInstant(); + var nextExpires = nextEffective.plus(7, ChronoUnit.DAYS); + if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) { + return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot"); + } + + // Salts that can be rotated based on their refreshFrom being at target date + var refreshableSalts = findRefreshableSalts(preRotationSalts, targetDate); + + var saltsToRotate = pickSaltsToRotate( + refreshableSalts, + targetDate, + minAges, + getNumSaltsToRotate(preRotationSalts, fraction) + ); + + if (saltsToRotate.isEmpty()) { + return Result.noSnapshot("all refreshable salts are below min rotation age"); + } + + var postRotationSalts = rotateSalts(preRotationSalts, saltsToRotate, targetDate); + + logSaltAges("refreshable-salts", targetDate, refreshableSalts); + logSaltAges("rotated-salts", targetDate, saltsToRotate); + logSaltAges("total-salts", targetDate, Arrays.asList(postRotationSalts)); + + var nextSnapshot = new SaltSnapshot( + nextEffective, + nextExpires, + postRotationSalts, + lastSnapshot.getFirstLevelSalt()); + return Result.fromSnapshot(nextSnapshot); + } + + private static int getNumSaltsToRotate(SaltEntry[] preRotationSalts, double fraction) { + return (int) Math.ceil(preRotationSalts.length * fraction); + } + + private Set findRefreshableSalts(SaltEntry[] preRotationSalts, TargetDate targetDate) { + return Arrays.stream(preRotationSalts).filter(s -> isRefreshable(targetDate, s)).collect(Collectors.toSet()); + } + + private boolean isRefreshable(TargetDate targetDate, SaltEntry salt) { + if (this.isRefreshFromEnabled) { + return salt.refreshFrom().equals(targetDate.asEpochMs()); + } + + return true; + } + + private SaltEntry[] rotateSalts(SaltEntry[] oldSalts, List saltsToRotate, TargetDate targetDate) throws Exception { + var saltIdsToRotate = saltsToRotate.stream().map(SaltEntry::id).collect(Collectors.toSet()); + + var updatedSalts = new SaltEntry[oldSalts.length]; + for (int i = 0; i < oldSalts.length; i++) { + var shouldRotate = saltIdsToRotate.contains(oldSalts[i].id()); + updatedSalts[i] = updateSalt(oldSalts[i], targetDate, shouldRotate); + } + return updatedSalts; + } + + private SaltEntry updateSalt(SaltEntry oldSalt, TargetDate targetDate, boolean shouldRotate) throws Exception { + var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt(); + var lastUpdated = shouldRotate ? targetDate.asEpochMs() : oldSalt.lastUpdated(); + var refreshFrom = calculateRefreshFrom(oldSalt, targetDate); + var previousSalt = calculatePreviousSalt(oldSalt, shouldRotate, targetDate); + + return new SaltEntry( + oldSalt.id(), + oldSalt.hashedId(), + lastUpdated, + currentSalt, + refreshFrom, + previousSalt, + null, + null + ); + } + + private long calculateRefreshFrom(SaltEntry salt, TargetDate targetDate) { + long multiplier = targetDate.saltAgeInDays(salt) / 30 + 1; + return salt.lastUpdated() + (multiplier * THIRTY_DAYS_IN_MS); + } + + private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, TargetDate targetDate) { + if (shouldRotate) { + return salt.currentSalt(); + } + if (targetDate.saltAgeInDays(salt) < 90) { + return salt.previousSalt(); + } + return null; + } + + private List pickSaltsToRotate( + Set refreshableSalts, + TargetDate targetDate, + Duration[] minAges, + int numSaltsToRotate + ) { + var thresholds = Arrays.stream(minAges) + .map(minAge -> targetDate.asInstant().minusSeconds(minAge.getSeconds())) + .sorted() + .toArray(Instant[]::new); + var indexesToRotate = new ArrayList(); + + var minLastUpdated = Instant.ofEpochMilli(0); + for (var maxLastUpdated : thresholds) { + if (indexesToRotate.size() >= numSaltsToRotate) break; + + var maxIndexes = numSaltsToRotate - indexesToRotate.size(); + var saltsToRotate = pickSaltsToRotateInTimeWindow( + refreshableSalts, + maxIndexes, + minLastUpdated.toEpochMilli(), + maxLastUpdated.toEpochMilli() + ); + indexesToRotate.addAll(saltsToRotate); + minLastUpdated = maxLastUpdated; + } + return indexesToRotate; + } + + private List pickSaltsToRotateInTimeWindow( + Set refreshableSalts, + int maxIndexes, + long minLastUpdated, + long maxLastUpdated + ) { + ArrayList candidateSalts = refreshableSalts.stream() + .filter(salt -> minLastUpdated <= salt.lastUpdated() && salt.lastUpdated() < maxLastUpdated) + .collect(Collectors.toCollection(ArrayList::new)); + + if (candidateSalts.size() <= maxIndexes) { + return candidateSalts; + } + + Collections.shuffle(candidateSalts); + + return candidateSalts.stream().limit(maxIndexes).collect(Collectors.toList()); + } + + private void logSaltAges(String saltCountType, TargetDate targetDate, Collection salts) { + var ages = new HashMap(); // salt age to count + for (var salt : salts) { + long ageInDays = targetDate.saltAgeInDays(salt); + ages.put(ageInDays, ages.getOrDefault(ageInDays, 0L) + 1); + } + + for (var entry : ages.entrySet()) { + LOGGER.info("salt-count-type={} target-date={} age={} salt-count={}", + saltCountType, + targetDate, + entry.getKey(), + entry.getValue() + ); + } + } + + @Getter + public static class Result { + private final SaltSnapshot snapshot; // can be null if new snapshot is not needed + private final String reason; // why you are not getting a new snapshot + + private Result(SaltSnapshot snapshot, String reason) { + this.snapshot = snapshot; + this.reason = reason; + } + + public boolean hasSnapshot() { + return snapshot != null; + } + + public static Result fromSnapshot(SaltSnapshot snapshot) { + return new Result(snapshot, null); + } + + public static Result noSnapshot(String reason) { + return new Result(null, reason); + } + } +} diff --git a/src/main/java/com/uid2/admin/salt/TargetDate.java b/src/main/java/com/uid2/admin/salt/TargetDate.java new file mode 100644 index 00000000..8a503c5c --- /dev/null +++ b/src/main/java/com/uid2/admin/salt/TargetDate.java @@ -0,0 +1,69 @@ +package com.uid2.admin.salt; + +import com.uid2.shared.model.SaltEntry; + +import java.time.*; +import java.time.format.DateTimeFormatter; +import java.util.Objects; + +public class TargetDate { + private final static long DAY_IN_MS = Duration.ofDays(1).toMillis(); + + private final LocalDate date; + private final long epochMs; + private final Instant instant; + private final String formatted; + + public TargetDate(LocalDate date) { + this.instant = date.atStartOfDay().toInstant(ZoneOffset.UTC); + this.date = date; + this.epochMs = instant.toEpochMilli(); + this.formatted = date.format(DateTimeFormatter.ofPattern("yyyy-MM-dd")); + } + + public static TargetDate now() { + return new TargetDate(LocalDate.now(Clock.systemUTC())); + } + + public static TargetDate of(int year, int month, int day) { + return new TargetDate(LocalDate.of(year, month, day)); + } + + public long asEpochMs() { + return epochMs; + } + + public Instant asInstant() { + return instant; + } + + // relative to this date + public long saltAgeInDays(SaltEntry salt) { + return (this.asEpochMs() - salt.lastUpdated()) / DAY_IN_MS; + } + + public TargetDate plusDays(int days) { + return new TargetDate(date.plusDays(days)); + } + + public TargetDate minusDays(int days) { + return new TargetDate(date.minusDays(days)); + } + + @Override + public String toString() { + return formatted; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + TargetDate that = (TargetDate) o; + return epochMs == that.epochMs; + } + + @Override + public int hashCode() { + return Objects.hashCode(epochMs); + } +} diff --git a/src/main/java/com/uid2/admin/secret/SaltRotation.java b/src/main/java/com/uid2/admin/secret/SaltRotation.java deleted file mode 100644 index 657bf23e..00000000 --- a/src/main/java/com/uid2/admin/secret/SaltRotation.java +++ /dev/null @@ -1,181 +0,0 @@ -package com.uid2.admin.secret; - -import com.uid2.admin.AdminConst; -import com.uid2.shared.model.SaltEntry; -import com.uid2.shared.secret.IKeyGenerator; -import com.uid2.shared.store.salt.RotatingSaltProvider; - -import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot; -import io.vertx.core.json.JsonObject; - -import java.time.Duration; -import java.time.Instant; -import java.time.LocalDate; -import java.time.ZoneOffset; -import java.time.temporal.ChronoUnit; -import java.util.*; -import java.util.stream.IntStream; - -import static java.util.stream.Collectors.toList; - -public class SaltRotation { - private final static long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis(); - private final static long DAY_IN_MS = Duration.ofDays(1).toMillis(); - - private final IKeyGenerator keyGenerator; - private final boolean isRefreshFromEnabled; - - public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) { - this.keyGenerator = keyGenerator; - this.isRefreshFromEnabled = config.getBoolean(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, false); - } - - public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot, - Duration[] minAges, - double fraction, - LocalDate targetDate) throws Exception { - final Instant nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); - final Instant nextExpires = nextEffective.plus(7, ChronoUnit.DAYS); - if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) { - return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot"); - } - - List saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction); - if (saltIndexesToRotate.isEmpty()) { - return Result.noSnapshot("all salts are below min rotation age"); - } - - var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli()); - - SaltSnapshot nextSnapshot = new SaltSnapshot( - nextEffective, - nextExpires, - updatedSalts, - lastSnapshot.getFirstLevelSalt()); - return Result.fromSnapshot(nextSnapshot); - } - - private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List saltIndexesToRotate, long nextEffective) throws Exception { - var updatedSalts = new SaltEntry[oldSalts.length]; - - for (int i = 0; i < oldSalts.length; i++) { - var shouldRotate = saltIndexesToRotate.contains(i); - updatedSalts[i] = updateSalt(oldSalts[i], shouldRotate, nextEffective); - } - return updatedSalts; - } - - private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextEffective) throws Exception { - var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt(); - var lastUpdated = shouldRotate ? nextEffective : oldSalt.lastUpdated(); - var refreshFrom = calculateRefreshFrom(oldSalt.lastUpdated(), nextEffective); - var previousSalt = calculatePreviousSalt(oldSalt, shouldRotate, nextEffective); - - return new SaltEntry( - oldSalt.id(), - oldSalt.hashedId(), - lastUpdated, - currentSalt, - refreshFrom, - previousSalt, - null, - null - ); - } - - private long calculateRefreshFrom(long lastUpdated, long nextEffective) { - long age = nextEffective - lastUpdated; - long multiplier = age / THIRTY_DAYS_IN_MS + 1; - return lastUpdated + (multiplier * THIRTY_DAYS_IN_MS); - } - - private String calculatePreviousSalt(SaltEntry salt, boolean shouldRotate, long nextEffective) throws Exception { - if (shouldRotate) { - return salt.currentSalt(); - } - long age = nextEffective - salt.lastUpdated(); - if ( age / DAY_IN_MS < 90) { - return salt.previousSalt(); - } - return null; - } - - private List pickSaltIndexesToRotate( - SaltSnapshot lastSnapshot, - Instant nextEffective, - Duration[] minAges, - double fraction) { - final Instant[] thresholds = Arrays.stream(minAges) - .map(age -> nextEffective.minusSeconds(age.getSeconds())) - .sorted() - .toArray(Instant[]::new); - final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction); - final SaltEntry[] rotatableSalts = getRotatableSalts(lastSnapshot, nextEffective.toEpochMilli()); - final List indexesToRotate = new ArrayList<>(); - - Instant minLastUpdated = Instant.ofEpochMilli(0); - for (Instant threshold : thresholds) { - if (indexesToRotate.size() >= maxSalts) break; - addIndexesToRotate( - indexesToRotate, - rotatableSalts, - minLastUpdated.toEpochMilli(), - threshold.toEpochMilli(), - maxSalts - indexesToRotate.size() - ); - minLastUpdated = threshold; - } - return indexesToRotate; - } - - private SaltEntry[] getRotatableSalts(SaltSnapshot lastSnapshot, long nextEffective) { - SaltEntry[] salts = lastSnapshot.getAllRotatingSalts(); - if (isRefreshFromEnabled) { - return Arrays.stream(salts).filter(s -> s.refreshFrom() == nextEffective).toArray(SaltEntry[]::new); - } - return salts; - } - - - private void addIndexesToRotate(List entryIndexes, - SaltEntry[] entries, - long minLastUpdated, - long maxLastUpdated, - int maxIndexes) { - final List candidateIndexes = IntStream.range(0, entries.length) - .filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated)) - .boxed() - .collect(toList()); - if (candidateIndexes.size() <= maxIndexes) { - entryIndexes.addAll(candidateIndexes); - return; - } - Collections.shuffle(candidateIndexes); - candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add); - } - - private static boolean isBetween(long t, long minInclusive, long maxExclusive) { - return minInclusive <= t && t < maxExclusive; - } - - public static class Result { - private final RotatingSaltProvider.SaltSnapshot snapshot; // can be null if new snapshot is not needed - private final String reason; // why you are not getting a new snapshot - - private Result(RotatingSaltProvider.SaltSnapshot snapshot, String reason) { - this.snapshot = snapshot; - this.reason = reason; - } - - public boolean hasSnapshot() { return snapshot != null; } - public RotatingSaltProvider.SaltSnapshot getSnapshot() { return snapshot; } - public String getReason() { return reason; } - - public static Result fromSnapshot(RotatingSaltProvider.SaltSnapshot snapshot) { - return new Result(snapshot, null); - } - public static Result noSnapshot(String reason) { - return new Result(null, reason); - } - } -} diff --git a/src/main/java/com/uid2/admin/vertx/service/SaltService.java b/src/main/java/com/uid2/admin/vertx/service/SaltService.java index 2f1dc14f..2841dbf0 100644 --- a/src/main/java/com/uid2/admin/vertx/service/SaltService.java +++ b/src/main/java/com/uid2/admin/vertx/service/SaltService.java @@ -1,7 +1,8 @@ package com.uid2.admin.vertx.service; import com.uid2.admin.auth.AdminAuthMiddleware; -import com.uid2.admin.secret.SaltRotation; +import com.uid2.admin.salt.SaltRotation; +import com.uid2.admin.salt.TargetDate; import com.uid2.admin.store.writer.SaltStoreWriter; import com.uid2.admin.vertx.RequestUtil; import com.uid2.admin.vertx.ResponseUtil; @@ -47,7 +48,7 @@ public SaltService(AdminAuthMiddleware auth, @Override public void setupRoutes(Router router) { router.get("/api/salt/snapshots").handler( - auth.handle(this::handleSaltSnapshots, Role.MAINTAINER)); + auth.handle(this::handleSaltSnapshots, Role.MAINTAINER)); router.post("/api/salt/rotate").blockingHandler(auth.handle((ctx) -> { synchronized (writeLock) { @@ -74,11 +75,16 @@ private void handleSaltSnapshots(RoutingContext rc) { private void handleSaltRotate(RoutingContext rc) { try { final Optional fraction = RequestUtil.getDouble(rc, "fraction"); - if (!fraction.isPresent()) return; + if (fraction.isEmpty()) return; final Duration[] minAges = RequestUtil.getDurations(rc, "min_ages_in_seconds"); if (minAges == null) return; - final LocalDate targetDate = RequestUtil.getDate(rc, "target_date", DateTimeFormatter.ISO_LOCAL_DATE) - .orElse(LocalDate.now(Clock.systemUTC()).plusDays(1)); + + + final TargetDate targetDate = + RequestUtil.getDate(rc, "target_date", DateTimeFormatter.ISO_LOCAL_DATE) + .map(TargetDate::new) + .orElse(TargetDate.now().plusDays(1)) + ; // force refresh this.saltProvider.loadContent(); @@ -87,10 +93,9 @@ private void handleSaltRotate(RoutingContext rc) { storageManager.archiveSaltLocations(); final List snapshots = this.saltProvider.getSnapshots(); - final RotatingSaltProvider.SaltSnapshot lastSnapshot = snapshots.get(snapshots.size() - 1); + final RotatingSaltProvider.SaltSnapshot lastSnapshot = snapshots.getLast(); - final SaltRotation.Result result = saltRotation.rotateSalts( - lastSnapshot, minAges, fraction.get(), targetDate); + final SaltRotation.Result result = saltRotation.rotateSalts(lastSnapshot, minAges, fraction.get(), targetDate); if (!result.hasSnapshot()) { ResponseUtil.error(rc, 200, result.getReason()); return; diff --git a/src/test/java/com/uid2/admin/salt/SaltRotationTest.java b/src/test/java/com/uid2/admin/salt/SaltRotationTest.java new file mode 100644 index 00000000..c96a647a --- /dev/null +++ b/src/test/java/com/uid2/admin/salt/SaltRotationTest.java @@ -0,0 +1,397 @@ +package com.uid2.admin.salt; + +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.uid2.admin.AdminConst; +import com.uid2.admin.salt.helper.SaltBuilder; +import com.uid2.admin.salt.helper.SaltSnapshotBuilder; +import com.uid2.shared.model.SaltEntry; +import com.uid2.shared.secret.IKeyGenerator; +import io.vertx.core.json.JsonObject; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.time.*; +import java.util.*; +import java.util.stream.Collectors; + +import static com.uid2.admin.salt.helper.TargetDateUtil.*; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import ch.qos.logback.classic.Logger; +import org.slf4j.LoggerFactory; + +public class SaltRotationTest { + @Mock + private IKeyGenerator keyGenerator; + private SaltRotation saltRotation; + + private ListAppender appender; + private AutoCloseable mocks; + + @BeforeEach + void setup() { + mocks = MockitoAnnotations.openMocks(this); + + appender = new ListAppender<>(); + appender.start(); + ((Logger) LoggerFactory.getLogger(SaltRotation.class)).addAppender(appender); + + JsonObject config = new JsonObject(); + saltRotation = new SaltRotation(config, keyGenerator); + } + + @AfterEach + void tearDown() throws Exception { + appender.stop(); + mocks.close(); + } + + @Test + void rotateSaltsLastSnapshotIsUpToDate() throws Exception { + final Duration[] minAges = { + Duration.ofDays(1), + Duration.ofDays(2), + }; + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(10, targetDate()) + .effective(targetDate()) + .expires(daysLater(7)) + .build(); + + var result1 = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate()); + assertFalse(result1.hasSnapshot()); + var result2 = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate().minusDays(1)); + assertFalse(result2.hasSnapshot()); + } + + @Test + void rotateSaltsAllSaltsUpToDate() throws Exception { + final Duration[] minAges = { + Duration.ofDays(1), + Duration.ofDays(2), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(10, targetDate()) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate()); + assertFalse(result.hasSnapshot()); + verify(keyGenerator, times(0)).generateRandomKeyString(anyInt()); + } + + @Test + void rotateSaltsAllSaltsOld() throws Exception { + final Duration[] minAges = { + Duration.ofDays(1), + Duration.ofDays(2), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(10, daysEarlier(10)) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate()); + assertTrue(result.hasSnapshot()); + assertEquals(2, countEntriesWithLastUpdated(result.getSnapshot().getAllRotatingSalts(), result.getSnapshot().getEffective())); + assertEquals(8, countEntriesWithLastUpdated(result.getSnapshot().getAllRotatingSalts(), daysEarlier(10))); + assertEquals(targetDate().asInstant(), result.getSnapshot().getEffective()); + assertEquals(daysLater(7).asInstant(), result.getSnapshot().getExpires()); + verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); + } + + @Test + void rotateSaltsRotateSaltsFromOldestBucketOnly() throws Exception { + final Duration[] minAges = { + Duration.ofDays(5), + Duration.ofDays(4), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(3, daysEarlier(6)) + .entries(5, daysEarlier(5)) + .entries(2, daysEarlier(4)) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate()); + assertTrue(result.hasSnapshot()); + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals(2, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); + assertEquals(1, countEntriesWithLastUpdated(salts, daysEarlier(6))); + assertEquals(5, countEntriesWithLastUpdated(salts, daysEarlier(5))); + assertEquals(2, countEntriesWithLastUpdated(salts, daysEarlier(4))); + assertEquals(targetDate().asInstant(), result.getSnapshot().getEffective()); + assertEquals(daysLater(7).asInstant(), result.getSnapshot().getExpires()); + verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); + } + + @Test + void rotateSaltsRotateSaltsFromNewerBucketOnly() throws Exception { + final Duration[] minAges = { + Duration.ofDays(5), + Duration.ofDays(3), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(3, daysEarlier(4)) + .entries(7, daysEarlier(3)) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate()); + assertTrue(result.hasSnapshot()); + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals(2, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); + assertEquals(1, countEntriesWithLastUpdated(salts, daysEarlier(4))); + assertEquals(7, countEntriesWithLastUpdated(salts, daysEarlier(3))); + assertEquals(targetDate().asInstant(), result.getSnapshot().getEffective()); + assertEquals(daysLater(7).asInstant(), result.getSnapshot().getExpires()); + verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); + } + + @Test + void rotateSaltsRotateSaltsFromMultipleBuckets() throws Exception { + final Duration[] minAges = { + Duration.ofDays(5), + Duration.ofDays(4), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(3, daysEarlier(6)) + .entries(5, daysEarlier(5)) + .entries(2, daysEarlier(4)) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.45, targetDate()); + assertTrue(result.hasSnapshot()); + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals(5, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); + assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(6))); + assertEquals(3, countEntriesWithLastUpdated(salts, daysEarlier(5))); + assertEquals(2, countEntriesWithLastUpdated(salts, daysEarlier(4))); + assertEquals(targetDate().asInstant(), result.getSnapshot().getEffective()); + assertEquals(daysLater(7).asInstant(), result.getSnapshot().getExpires()); + verify(keyGenerator, times(5)).generateRandomKeyString(anyInt()); + } + + @Test + void rotateSaltsRotateSaltsInsufficientOutdatedSalts() throws Exception { + final Duration[] minAges = { + Duration.ofDays(5), + Duration.ofDays(3), + }; + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(1, daysEarlier(5)) + .entries(2, daysEarlier(4)) + .entries(7, daysEarlier(2)) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.45, targetDate()); + assertTrue(result.hasSnapshot()); + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals(3, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); + assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(5))); + assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(4))); + assertEquals(7, countEntriesWithLastUpdated(salts, daysEarlier(2))); + assertEquals(targetDate().asInstant(), result.getSnapshot().getEffective()); + assertEquals(daysLater(7).asInstant(), result.getSnapshot().getExpires()); + verify(keyGenerator, times(3)).generateRandomKeyString(anyInt()); + } + + @ParameterizedTest + @CsvSource({ + "5, 30", // Soon after rotation, use 30 days post rotation + "40, 60", // >30 days after rotation use the next increment of 30 days + "60, 90", // Exactly at multiple of 30 days post rotation, use next increment of 30 days + }) + void testRefreshFromCalculation(int lastRotationDaysAgo, int refreshFromDaysFromRotation) throws Exception { + var lastRotation = daysEarlier(lastRotationDaysAgo); + SaltBuilder saltBuilder = SaltBuilder.start().lastUpdated(lastRotation); + var lastSnapshot = SaltSnapshotBuilder.start() + .entries(saltBuilder) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, new Duration[]{Duration.ofDays(1)}, 0.45, targetDate()); + var actual = result.getSnapshot().getAllRotatingSalts()[0]; + + var expected = lastRotation.plusDays(refreshFromDaysFromRotation).asEpochMs(); + + assertThat(actual.refreshFrom()).isEqualTo(expected); + } + + @Test + void rotateSaltsPopulatePreviousSaltsOnRotation() throws Exception { + final Duration[] minAges = { + Duration.ofDays(90), + Duration.ofDays(60), + Duration.ofDays(30) + }; + + var lessThan90Days = daysEarlier(60); + var exactly90Days = daysEarlier(90); + var over90Days = daysEarlier(120); + var lastSnapshot = SaltSnapshotBuilder.start() + .entries( + SaltBuilder.start().lastUpdated(lessThan90Days).currentSalt("salt1"), + SaltBuilder.start().lastUpdated(exactly90Days).currentSalt("salt2"), + SaltBuilder.start().lastUpdated(over90Days).currentSalt("salt3") + ) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate()); + assertTrue(result.hasSnapshot()); + + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals("salt1", salts[0].previousSalt()); + assertEquals("salt2", salts[1].previousSalt()); + assertEquals("salt3", salts[2].previousSalt()); + } + + @Test + void rotateSaltsPreservePreviousSaltsLessThan90DaysOld() throws Exception { + final Duration[] minAges = { + Duration.ofDays(60), + }; + + var notValidForRotation1 = daysEarlier(40); + var notValidForRotation2 = daysEarlier(50); + var validForRotation = daysEarlier(70); + var lastSnapshot = SaltSnapshotBuilder.start() + .entries( + SaltBuilder.start().lastUpdated(notValidForRotation1).currentSalt("salt1").previousSalt("previousSalt1"), + SaltBuilder.start().lastUpdated(notValidForRotation2).currentSalt("salt2") + ) + .entries(1, validForRotation) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate()); + assertTrue(result.hasSnapshot()); + + var salts = result.getSnapshot().getAllRotatingSalts(); + assertEquals("previousSalt1", salts[0].previousSalt()); + assertNull(salts[1].previousSalt()); + } + + @Test + void rotateSaltsRemovePreviousSaltsOver90DaysOld() throws Exception { + final Duration[] minAges = { + Duration.ofDays(100), + }; + + var exactly90Days = daysEarlier(90); + var over90Days = daysEarlier(100); + var validForRotation = daysEarlier(120); + var lastSnapshot = SaltSnapshotBuilder.start() + .entries( + SaltBuilder.start().lastUpdated(exactly90Days).previousSalt("90DaysOld"), + SaltBuilder.start().lastUpdated(over90Days).previousSalt("over90DaysOld") + ) + .entries(1, validForRotation) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.5, targetDate()); + assertTrue(result.hasSnapshot()); + + var salts = result.getSnapshot().getAllRotatingSalts(); + assertNull(salts[0].previousSalt()); + assertNull(salts[1].previousSalt()); + } + + + @Test + void rotateSaltsRotateWhenRefreshFromIsTargetDate() throws Exception { + JsonObject config = new JsonObject(); + config.put(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, Boolean.TRUE); + saltRotation = new SaltRotation(config, keyGenerator); + + final Duration[] minAges = { + Duration.ofDays(90), + Duration.ofDays(60), + }; + + var validForRotation1 = daysEarlier(120); + var validForRotation2 = daysEarlier(70); + var notValidForRotation = daysEarlier(30); + var refreshNow = targetDate(); + var refreshLater = daysLater(20); + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries( + SaltBuilder.start().lastUpdated(validForRotation1).refreshFrom(refreshNow), + SaltBuilder.start().lastUpdated(notValidForRotation).refreshFrom(refreshNow), + SaltBuilder.start().lastUpdated(validForRotation2).refreshFrom(refreshLater) + ) + .build(); + + var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate()); + assertTrue(result.hasSnapshot()); + + var salts = result.getSnapshot().getAllRotatingSalts(); + + assertEquals(targetDate().asEpochMs(), salts[0].lastUpdated()); + assertEquals(daysLater(30).asEpochMs(), salts[0].refreshFrom()); + + assertEquals(notValidForRotation.asEpochMs(), salts[1].lastUpdated()); + assertEquals(daysLater(30).asEpochMs(), salts[1].refreshFrom()); + + assertEquals(validForRotation2.asEpochMs(), salts[2].lastUpdated()); + assertEquals(refreshLater.asEpochMs(), salts[2].refreshFrom()); + } + + @Test + void logsSaltAgesOnRotation() throws Exception { + JsonObject config = new JsonObject(); + config.put(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, Boolean.TRUE); + saltRotation = new SaltRotation(config, keyGenerator); + + var lastSnapshot = SaltSnapshotBuilder.start() + .entries( + // 5 salts total, 3 refreshable, 2 rotated given 40% fraction + SaltBuilder.start().lastUpdated(daysEarlier(65)).refreshFrom(targetDate()), // Refreshable, old enough, rotated + SaltBuilder.start().lastUpdated(daysEarlier(5)).refreshFrom(targetDate()), // Refreshable, too new + SaltBuilder.start().lastUpdated(daysEarlier(50)).refreshFrom(daysLater(1)), // Not refreshable, old enough + SaltBuilder.start().lastUpdated(daysEarlier(65)).refreshFrom(targetDate()), // Refreshable, old enough, rotated + SaltBuilder.start().lastUpdated(daysEarlier(10)).refreshFrom(daysLater(10)) // Not refreshable, too new + ) + .build(); + + var expected = Set.of( + // Post-rotation ages, we want to look at current state + "[INFO] salt-count-type=total-salts target-date=2025-01-01 age=0 salt-count=2", // The two rotated salts, used to be 65 and 50 days old + "[INFO] salt-count-type=total-salts target-date=2025-01-01 age=5 salt-count=1", + "[INFO] salt-count-type=total-salts target-date=2025-01-01 age=10 salt-count=1", + "[INFO] salt-count-type=total-salts target-date=2025-01-01 age=50 salt-count=1", + + // Pre-rotation ages, we want to see at which ages salts become refreshable, post rotation some will be 0 + "[INFO] salt-count-type=refreshable-salts target-date=2025-01-01 age=5 salt-count=1", + "[INFO] salt-count-type=refreshable-salts target-date=2025-01-01 age=65 salt-count=2", + + // Pre-rotation ages, post rotation they will all have age 0 + "[INFO] salt-count-type=rotated-salts target-date=2025-01-01 age=65 salt-count=2" + ); + + var minAges = new Duration[]{Duration.ofDays(30), Duration.ofDays(60)}; + saltRotation.rotateSalts(lastSnapshot, minAges, 0.4, targetDate()); + + var actual = appender.list.stream().map(Object::toString).collect(Collectors.toSet()); + assertThat(actual).isEqualTo(expected); + } + + private int countEntriesWithLastUpdated(SaltEntry[] entries, TargetDate lastUpdated) { + return countEntriesWithLastUpdated(entries, lastUpdated.asInstant()); + } + + private int countEntriesWithLastUpdated(SaltEntry[] entries, Instant lastUpdated) { + return (int) Arrays.stream(entries).filter(e -> e.lastUpdated() == lastUpdated.toEpochMilli()).count(); + } + +} diff --git a/src/test/java/com/uid2/admin/vertx/SaltServiceTest.java b/src/test/java/com/uid2/admin/salt/SaltServiceTest.java similarity index 56% rename from src/test/java/com/uid2/admin/vertx/SaltServiceTest.java rename to src/test/java/com/uid2/admin/salt/SaltServiceTest.java index c3952fca..7d4755f0 100644 --- a/src/test/java/com/uid2/admin/vertx/SaltServiceTest.java +++ b/src/test/java/com/uid2/admin/salt/SaltServiceTest.java @@ -1,11 +1,10 @@ -package com.uid2.admin.vertx; +package com.uid2.admin.salt; -import com.uid2.admin.secret.SaltRotation; +import com.uid2.admin.salt.helper.SaltSnapshotBuilder; import com.uid2.admin.vertx.service.IService; import com.uid2.admin.vertx.service.SaltService; import com.uid2.admin.vertx.test.ServiceTestBase; import com.uid2.shared.auth.Role; -import com.uid2.shared.model.SaltEntry; import com.uid2.shared.store.salt.RotatingSaltProvider; import io.vertx.core.Vertx; import io.vertx.core.json.JsonObject; @@ -14,16 +13,15 @@ import org.mockito.Mock; import java.time.Instant; -import java.time.LocalDate; -import java.time.ZoneOffset; -import java.time.temporal.ChronoUnit; import java.util.Arrays; +import static com.uid2.admin.salt.helper.TargetDateUtil.*; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; public class SaltServiceTest extends ServiceTestBase { + private final TargetDate utcTomorrow = TargetDate.now().plusDays(1); @Mock RotatingSaltProvider saltProvider; @Mock SaltRotation saltRotation; @@ -32,29 +30,6 @@ protected IService createService() { return new SaltService(auth, writeLock, saltStoreWriter, saltProvider, saltRotation); } - private void checkSnapshotsResponse(RotatingSaltProvider.SaltSnapshot[] expectedSnapshots, Object[] actualSnapshots) { - assertEquals(expectedSnapshots.length, actualSnapshots.length); - for (int i = 0; i < expectedSnapshots.length; ++i) { - RotatingSaltProvider.SaltSnapshot expectedSnapshot = expectedSnapshots[i]; - JsonObject actualSnapshot = (JsonObject) actualSnapshots[i]; - assertEquals(expectedSnapshot.getEffective(), Instant.ofEpochMilli(actualSnapshot.getLong("effective"))); - assertEquals(expectedSnapshot.getExpires(), Instant.ofEpochMilli(actualSnapshot.getLong("expires"))); - assertEquals(expectedSnapshot.getAllRotatingSalts().length, actualSnapshot.getInteger("salts_count")); - } - } - - private void setSnapshots(RotatingSaltProvider.SaltSnapshot... snapshots) { - when(saltProvider.getSnapshots()).thenReturn(Arrays.asList(snapshots)); - } - - private RotatingSaltProvider.SaltSnapshot makeSnapshot(Instant effective, Instant expires, int nsalts) { - SaltEntry[] entries = new SaltEntry[nsalts]; - for (int i = 0; i < entries.length; ++i) { - entries[i] = new SaltEntry(i, "hashed_id", effective.toEpochMilli(), "salt", null, null, null, null); - } - return new RotatingSaltProvider.SaltSnapshot(effective, expires, entries, "test_first_level_salt"); - } - @Test void listSaltSnapshotsNoSnapshots(Vertx vertx, VertxTestContext testContext) { fakeAuth(Role.MAINTAINER); @@ -70,10 +45,10 @@ void listSaltSnapshotsNoSnapshots(Vertx vertx, VertxTestContext testContext) { void listSaltSnapshotsWithSnapshots(Vertx vertx, VertxTestContext testContext) { fakeAuth(Role.MAINTAINER); - final RotatingSaltProvider.SaltSnapshot[] snapshots = { - makeSnapshot(Instant.ofEpochMilli(10001), Instant.ofEpochMilli(20001), 10), - makeSnapshot(Instant.ofEpochMilli(10002), Instant.ofEpochMilli(20002), 10), - makeSnapshot(Instant.ofEpochMilli(10003), Instant.ofEpochMilli(20003), 10), + final SaltSnapshotBuilder[] snapshots = { + SaltSnapshotBuilder.start().effective(daysLater(1)).expires(daysLater(4)).entries(10, daysLater(1)), + SaltSnapshotBuilder.start().effective(daysLater(2)).expires(daysLater(5)).entries(10, daysLater(2)), + SaltSnapshotBuilder.start().effective(daysLater(3)).expires(daysLater(6)).entries(10, daysLater(3)), }; setSnapshots(snapshots); @@ -88,17 +63,19 @@ void listSaltSnapshotsWithSnapshots(Vertx vertx, VertxTestContext testContext) { void rotateSalts(Vertx vertx, VertxTestContext testContext) throws Exception { fakeAuth(Role.SUPER_USER); - final RotatingSaltProvider.SaltSnapshot[] snapshots = { - makeSnapshot(Instant.ofEpochMilli(10001), Instant.ofEpochMilli(20001), 10), - makeSnapshot(Instant.ofEpochMilli(10002), Instant.ofEpochMilli(20002), 10), - makeSnapshot(Instant.ofEpochMilli(10003), Instant.ofEpochMilli(20003), 10), + final SaltSnapshotBuilder[] snapshots = { + SaltSnapshotBuilder.start().effective(daysLater(1)).expires(daysLater(4)).entries(10, daysLater(1)), + SaltSnapshotBuilder.start().effective(daysLater(2)).expires(daysLater(5)).entries(10, daysLater(2)), + SaltSnapshotBuilder.start().effective(daysLater(3)).expires(daysLater(6)).entries(10, daysLater(3)), }; setSnapshots(snapshots); - final RotatingSaltProvider.SaltSnapshot[] addedSnapshots = { - makeSnapshot(Instant.ofEpochMilli(10004), Instant.ofEpochMilli(20004), 10), + final SaltSnapshotBuilder[] addedSnapshots = { + SaltSnapshotBuilder.start().effective(daysLater(7)).expires(daysLater(8)).entries(10, daysLater(7)), }; - when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(LocalDate.now().plusDays(1)))).thenReturn(SaltRotation.Result.fromSnapshot(addedSnapshots[0])); + + var result = SaltRotation.Result.fromSnapshot(addedSnapshots[0].build()); + when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(utcTomorrow))).thenReturn(result); post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2", "", response -> { assertEquals(200, response.statusCode()); @@ -113,14 +90,15 @@ void rotateSalts(Vertx vertx, VertxTestContext testContext) throws Exception { void rotateSaltsNoNewSnapshot(Vertx vertx, VertxTestContext testContext) throws Exception { fakeAuth(Role.SUPER_USER); - final RotatingSaltProvider.SaltSnapshot[] snapshots = { - makeSnapshot(Instant.ofEpochMilli(10001), Instant.ofEpochMilli(20001), 10), - makeSnapshot(Instant.ofEpochMilli(10002), Instant.ofEpochMilli(20002), 10), - makeSnapshot(Instant.ofEpochMilli(10003), Instant.ofEpochMilli(20003), 10), + final SaltSnapshotBuilder[] snapshots = { + SaltSnapshotBuilder.start().effective(daysLater(1)).expires(daysLater(4)).entries(10, daysLater(1)), + SaltSnapshotBuilder.start().effective(daysLater(2)).expires(daysLater(5)).entries(10, daysLater(2)), + SaltSnapshotBuilder.start().effective(daysLater(3)).expires(daysLater(6)).entries(10, daysLater(3)), }; setSnapshots(snapshots); - when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(LocalDate.now().plusDays(1)))).thenReturn(SaltRotation.Result.noSnapshot("test")); + var result = SaltRotation.Result.noSnapshot("test"); + when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(utcTomorrow))).thenReturn(result); post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2", "", response -> { assertEquals(200, response.statusCode()); @@ -135,24 +113,38 @@ void rotateSaltsNoNewSnapshot(Vertx vertx, VertxTestContext testContext) throws @Test void rotateSaltsWitnSpecificTargetDate(Vertx vertx, VertxTestContext testContext) throws Exception { fakeAuth(Role.SUPER_USER); - LocalDate targetDate = LocalDate.of(2025, 5, 8); - Instant targetDateAsInstant = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); - final RotatingSaltProvider.SaltSnapshot[] snapshots = { - makeSnapshot(targetDateAsInstant.minus(5, ChronoUnit.DAYS), targetDateAsInstant.minus(4, ChronoUnit.DAYS), 10), - makeSnapshot(targetDateAsInstant.minus(4, ChronoUnit.DAYS), targetDateAsInstant.minus(3, ChronoUnit.DAYS), 10), - makeSnapshot(targetDateAsInstant.minus(3, ChronoUnit.DAYS), targetDateAsInstant.minus(2, ChronoUnit.DAYS), 10), + final SaltSnapshotBuilder[] snapshots = { + SaltSnapshotBuilder.start().effective(daysEarlier(5)).expires(daysEarlier(4)).entries(10, daysEarlier(5)), + SaltSnapshotBuilder.start().effective(daysEarlier(4)).expires(daysEarlier(3)).entries(10, daysEarlier(4)), + SaltSnapshotBuilder.start().effective(daysEarlier(3)).expires(daysEarlier(2)).entries(10, daysEarlier(3)), }; setSnapshots(snapshots); - final RotatingSaltProvider.SaltSnapshot[] addedSnapshots = { - makeSnapshot(targetDateAsInstant, targetDateAsInstant.plus(1, ChronoUnit.DAYS), 10), + final SaltSnapshotBuilder[] addedSnapshots = { + SaltSnapshotBuilder.start().effective(targetDate()).expires(daysEarlier(1)).entries(10, targetDate()), }; - when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(targetDate))).thenReturn(SaltRotation.Result.fromSnapshot(addedSnapshots[0])); + var result = SaltRotation.Result.fromSnapshot(addedSnapshots[0].build()); + when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(targetDate()))).thenReturn(result); - post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2&target_date=2025-05-08", "", response -> { + post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2&target_date=2025-01-01", "", response -> { assertEquals(200, response.statusCode()); testContext.completeNow(); }); } + + private void checkSnapshotsResponse(SaltSnapshotBuilder[] expectedSnapshots, Object[] actualSnapshots) { + assertEquals(expectedSnapshots.length, actualSnapshots.length); + for (int i = 0; i < expectedSnapshots.length; ++i) { + RotatingSaltProvider.SaltSnapshot expectedSnapshot = expectedSnapshots[i].build(); + JsonObject actualSnapshot = (JsonObject) actualSnapshots[i]; + assertEquals(expectedSnapshot.getEffective(), Instant.ofEpochMilli(actualSnapshot.getLong("effective"))); + assertEquals(expectedSnapshot.getExpires(), Instant.ofEpochMilli(actualSnapshot.getLong("expires"))); + assertEquals(expectedSnapshot.getAllRotatingSalts().length, actualSnapshot.getInteger("salts_count")); + } + } + + private void setSnapshots(SaltSnapshotBuilder... snapshots) { + when(saltProvider.getSnapshots()).thenReturn(Arrays.stream(snapshots).map(SaltSnapshotBuilder::build).toList()); + } } diff --git a/src/test/java/com/uid2/admin/salt/helper/SaltBuilder.java b/src/test/java/com/uid2/admin/salt/helper/SaltBuilder.java new file mode 100644 index 00000000..ff42ed52 --- /dev/null +++ b/src/test/java/com/uid2/admin/salt/helper/SaltBuilder.java @@ -0,0 +1,62 @@ +package com.uid2.admin.salt.helper; + +import com.uid2.admin.salt.TargetDate; +import com.uid2.shared.model.SaltEntry; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicInteger; + +public class SaltBuilder { + private static final AtomicInteger LAST_AUTO_ID = new AtomicInteger(0); + + private int id = LAST_AUTO_ID.incrementAndGet(); + private Instant lastUpdated = Instant.now(); + private Instant refreshFrom = Instant.now(); + private String currentSalt = null; + private String previousSalt = null; + + private SaltBuilder() { + } + + public static SaltBuilder start() { + return new SaltBuilder(); + } + + public SaltBuilder id(int id) { + this.id = id; + return this; + } + + public SaltBuilder lastUpdated(TargetDate lastUpdated) { + this.lastUpdated = lastUpdated.asInstant(); + return this; + } + + public SaltBuilder refreshFrom(TargetDate refreshFrom) { + this.refreshFrom = refreshFrom.asInstant(); + return this; + } + + public SaltBuilder currentSalt(String currentSalt) { + this.currentSalt = currentSalt; + return this; + } + + public SaltBuilder previousSalt(String previousSalt) { + this.previousSalt = previousSalt; + return this; + } + + public SaltEntry build() { + return new SaltEntry( + id, + Integer.toString(id), + lastUpdated.toEpochMilli(), + currentSalt == null ? "salt " + id : currentSalt, + refreshFrom.toEpochMilli(), + previousSalt, + null, + null + ); + } +} diff --git a/src/test/java/com/uid2/admin/salt/helper/SaltSnapshotBuilder.java b/src/test/java/com/uid2/admin/salt/helper/SaltSnapshotBuilder.java new file mode 100644 index 00000000..78f76dae --- /dev/null +++ b/src/test/java/com/uid2/admin/salt/helper/SaltSnapshotBuilder.java @@ -0,0 +1,60 @@ +package com.uid2.admin.salt.helper; + +import com.uid2.admin.salt.SaltRotation; +import com.uid2.admin.salt.TargetDate; +import com.uid2.shared.model.SaltEntry; +import com.uid2.shared.store.salt.RotatingSaltProvider; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static com.uid2.admin.salt.helper.TargetDateUtil.daysEarlier; +import static com.uid2.admin.salt.helper.TargetDateUtil.daysLater; + +public class SaltSnapshotBuilder { + private final List entries = new ArrayList<>(); + + private TargetDate effective = daysEarlier(1); + private TargetDate expires = daysLater(6); + + private SaltSnapshotBuilder() { + } + + public static SaltSnapshotBuilder start() { + return new SaltSnapshotBuilder(); + } + + public SaltSnapshotBuilder entries(int count, TargetDate lastUpdated) { + for (int i = 0; i < count; ++i) { + entries.add(SaltBuilder.start().lastUpdated(lastUpdated).build()); + } + return this; + } + + public SaltSnapshotBuilder entries(SaltBuilder... salts) { + SaltEntry[] builtSalts = Arrays.stream(salts).map(SaltBuilder::build).toArray(SaltEntry[]::new); + Collections.addAll(this.entries, builtSalts); + return this; + } + + public SaltSnapshotBuilder effective(TargetDate effective) { + this.effective = effective; + return this; + } + + public SaltSnapshotBuilder expires(TargetDate expires) { + this.expires = expires; + return this; + } + + public RotatingSaltProvider.SaltSnapshot build() { + return new RotatingSaltProvider.SaltSnapshot( + effective.asInstant(), + expires.asInstant(), + entries.toArray(SaltEntry[]::new), + "test_first_level_salt" + ); + } +} diff --git a/src/test/java/com/uid2/admin/salt/helper/TargetDateUtil.java b/src/test/java/com/uid2/admin/salt/helper/TargetDateUtil.java new file mode 100644 index 00000000..fea21eff --- /dev/null +++ b/src/test/java/com/uid2/admin/salt/helper/TargetDateUtil.java @@ -0,0 +1,19 @@ +package com.uid2.admin.salt.helper; + +import com.uid2.admin.salt.TargetDate; + +public class TargetDateUtil { + private static final TargetDate TARGET_DATE = TargetDate.of(2025, 1, 1); + + public static TargetDate daysEarlier(int days) { + return TARGET_DATE.minusDays(days); + } + + public static TargetDate daysLater(int days) { + return TARGET_DATE.plusDays(days); + } + + public static TargetDate targetDate() { + return TARGET_DATE; + } +} diff --git a/src/test/java/com/uid2/admin/secret/SaltRotationTest.java b/src/test/java/com/uid2/admin/secret/SaltRotationTest.java deleted file mode 100644 index 302cf188..00000000 --- a/src/test/java/com/uid2/admin/secret/SaltRotationTest.java +++ /dev/null @@ -1,365 +0,0 @@ -package com.uid2.admin.secret; - -import com.uid2.admin.AdminConst; -import com.uid2.shared.model.SaltEntry; -import com.uid2.shared.secret.IKeyGenerator; -import com.uid2.shared.store.salt.RotatingSaltProvider; -import io.vertx.core.json.JsonObject; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import java.time.*; -import java.util.*; - -import static java.time.temporal.ChronoUnit.*; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; - -public class SaltRotationTest { - @Mock private IKeyGenerator keyGenerator; - private SaltRotation saltRotation; - - private final LocalDate targetDate = LocalDate.of(2025, 1, 1); - private final Instant targetDateAsInstant = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC); - - private Instant daysEarlier(int days) { - return targetDateAsInstant.minus(days, DAYS); - } - - private Instant daysLater(int days) { - return targetDateAsInstant.plus(days, DAYS); - } - - @BeforeEach - void setup() { - MockitoAnnotations.openMocks(this); - - JsonObject config = new JsonObject(); - - saltRotation = new SaltRotation(config, keyGenerator); - } - - private static class SnapshotBuilder { - private final List entries = new ArrayList<>(); - - private SnapshotBuilder() {} - - public static SnapshotBuilder start() { return new SnapshotBuilder(); } - - public SnapshotBuilder withEntries(int count, Instant lastUpdated) { - for (int i = 0; i < count; ++i) { - entries.add(new SaltEntry(entries.size(), "h", lastUpdated.toEpochMilli(), "salt" + entries.size(), null, null, null, null)); - } - return this; - } - - public SnapshotBuilder withEntries(SaltEntry... salts) { - Collections.addAll(this.entries, salts); - return this; - } - - public RotatingSaltProvider.SaltSnapshot build(Instant effective, Instant expires) { - return new RotatingSaltProvider.SaltSnapshot( - effective, expires, entries.toArray(SaltEntry[]::new), "test_first_level_salt"); - } - } - - private int countEntriesWithLastUpdated(SaltEntry[] entries, Instant lastUpdated) { - return (int)Arrays.stream(entries).filter(e -> e.lastUpdated() == lastUpdated.toEpochMilli()).count(); - } - - @Test - void rotateSaltsLastSnapshotIsUpToDate() throws Exception { - final Duration[] minAges = { - Duration.ofDays(1), - Duration.ofDays(2), - }; - var lastSnapshot = SnapshotBuilder.start() - .withEntries(10, targetDateAsInstant) - .build(targetDateAsInstant, daysLater(7)); - - var result1 = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate); - assertFalse(result1.hasSnapshot()); - var result2 = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate.minusDays(1)); - assertFalse(result2.hasSnapshot()); - } - - @Test - void rotateSaltsAllSaltsUpToDate() throws Exception { - final Duration[] minAges = { - Duration.ofDays(1), - Duration.ofDays(2), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(10, targetDateAsInstant) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate); - assertFalse(result.hasSnapshot()); - verify(keyGenerator, times(0)).generateRandomKeyString(anyInt()); - } - - @Test - void rotateSaltsAllSaltsOld() throws Exception { - final Duration[] minAges = { - Duration.ofDays(1), - Duration.ofDays(2), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(10, daysEarlier(10)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate); - assertTrue(result.hasSnapshot()); - assertEquals(2, countEntriesWithLastUpdated(result.getSnapshot().getAllRotatingSalts(), result.getSnapshot().getEffective())); - assertEquals(8, countEntriesWithLastUpdated(result.getSnapshot().getAllRotatingSalts(), daysEarlier(10))); - assertEquals(targetDateAsInstant, result.getSnapshot().getEffective()); - assertEquals(daysLater(7), result.getSnapshot().getExpires()); - verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); - } - - @Test - void rotateSaltsRotateSaltsFromOldestBucketOnly() throws Exception { - final Duration[] minAges = { - Duration.ofDays(5), - Duration.ofDays(4), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(3, daysEarlier(6)) - .withEntries(5, daysEarlier(5)) - .withEntries(2, daysEarlier(4)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate); - assertTrue(result.hasSnapshot()); - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals(2, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); - assertEquals(1, countEntriesWithLastUpdated(salts, daysEarlier(6))); - assertEquals(5, countEntriesWithLastUpdated(salts, daysEarlier(5))); - assertEquals(2, countEntriesWithLastUpdated(salts, daysEarlier(4))); - assertEquals(targetDateAsInstant, result.getSnapshot().getEffective()); - assertEquals(daysLater(7), result.getSnapshot().getExpires()); - verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); - } - - @Test - void rotateSaltsRotateSaltsFromNewerBucketOnly() throws Exception { - final Duration[] minAges = { - Duration.ofDays(5), - Duration.ofDays(3), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(3, daysEarlier(4)) - .withEntries(7, daysEarlier(3)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.2, targetDate); - assertTrue(result.hasSnapshot()); - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals(2, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); - assertEquals(1, countEntriesWithLastUpdated(salts, daysEarlier(4))); - assertEquals(7, countEntriesWithLastUpdated(salts, daysEarlier(3))); - assertEquals(targetDateAsInstant, result.getSnapshot().getEffective()); - assertEquals(daysLater(7), result.getSnapshot().getExpires()); - verify(keyGenerator, times(2)).generateRandomKeyString(anyInt()); - } - - @Test - void rotateSaltsRotateSaltsFromMultipleBuckets() throws Exception { - final Duration[] minAges = { - Duration.ofDays(5), - Duration.ofDays(4), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(3, daysEarlier(6)) - .withEntries(5, daysEarlier(5)) - .withEntries(2, daysEarlier(4)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.45, targetDate); - assertTrue(result.hasSnapshot()); - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals(5, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); - assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(6))); - assertEquals(3, countEntriesWithLastUpdated(salts, daysEarlier(5))); - assertEquals(2, countEntriesWithLastUpdated(salts, daysEarlier(4))); - assertEquals(targetDateAsInstant, result.getSnapshot().getEffective()); - assertEquals(daysLater(7), result.getSnapshot().getExpires()); - verify(keyGenerator, times(5)).generateRandomKeyString(anyInt()); - } - - @Test - void rotateSaltsRotateSaltsInsufficientOutdatedSalts() throws Exception { - final Duration[] minAges = { - Duration.ofDays(5), - Duration.ofDays(3), - }; - - var lastSnapshot = SnapshotBuilder.start() - .withEntries(1, daysEarlier(5)) - .withEntries(2, daysEarlier(4)) - .withEntries(7, daysEarlier(2)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.45, targetDate); - assertTrue(result.hasSnapshot()); - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals(3, countEntriesWithLastUpdated(salts, result.getSnapshot().getEffective())); - assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(5))); - assertEquals(0, countEntriesWithLastUpdated(salts, daysEarlier(4))); - assertEquals(7, countEntriesWithLastUpdated(salts, daysEarlier(2))); - assertEquals(targetDateAsInstant, result.getSnapshot().getEffective()); - assertEquals(daysLater(7), result.getSnapshot().getExpires()); - verify(keyGenerator, times(3)).generateRandomKeyString(anyInt()); - } - - @ParameterizedTest - @CsvSource({ - "5, 30", // Soon after rotation, use 30 days post rotation - "40, 60", // >30 days after rotation use the next increment of 30 days - "60, 90", // Exactly at multiple of 30 days post rotation, use next increment of 30 days - }) - void testRefreshFromCalculation(int lastRotationDaysAgo, int refreshFromDaysFromRotation) throws Exception { - var lastRotation = daysEarlier(lastRotationDaysAgo); - var lastSnapshot = SnapshotBuilder.start() - .withEntries(new SaltEntry(1, "1", lastRotation.toEpochMilli(), "salt1", 100L, null, null, null)) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, new Duration[]{ Duration.ofDays(1) }, 0.45, targetDate); - var actual = result.getSnapshot().getAllRotatingSalts()[0]; - - var expected = lastRotation.plus(refreshFromDaysFromRotation, DAYS).toEpochMilli(); - - assertThat(actual.refreshFrom()).isEqualTo(expected); - } - - @Test - void rotateSaltsPopulatePreviousSaltsOnRotation() throws Exception { - final Duration[] minAges = { - Duration.ofDays(90), - Duration.ofDays(60), - Duration.ofDays(30) - }; - - var lessThan90Days = daysEarlier(60).toEpochMilli(); - var exactly90Days = daysEarlier(90).toEpochMilli(); - var over90Days = daysEarlier(120).toEpochMilli(); - var lastSnapshot = SnapshotBuilder.start() - .withEntries( - new SaltEntry(1, "1", lessThan90Days, "salt1", null, null, null, null), - new SaltEntry(3, "2", exactly90Days, "salt2", null, null, null, null), - new SaltEntry(5, "3", over90Days, "salt3", null, null, null, null) - ) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate); - assertTrue(result.hasSnapshot()); - - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals("salt1", salts[0].previousSalt()); - assertEquals("salt2", salts[1].previousSalt()); - assertEquals("salt3", salts[2].previousSalt()); - } - - @Test - void rotateSaltsPreservePreviousSaltsLessThan90DaysOld() throws Exception { - final Duration[] minAges = { - Duration.ofDays(60), - }; - - var notValidForRotation1 = daysEarlier(40).toEpochMilli(); - var notValidForRotation2 = daysEarlier(50).toEpochMilli(); - var validForRotation = daysEarlier(70); - var lastSnapshot = SnapshotBuilder.start() - .withEntries( - new SaltEntry(1, "1", notValidForRotation1, "salt1", null, "previousSalt1", null, null), - new SaltEntry(2, "2", notValidForRotation2, "salt2", null, null, null, null) - ) - .withEntries(1, validForRotation) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate); - assertTrue(result.hasSnapshot()); - - var salts = result.getSnapshot().getAllRotatingSalts(); - assertEquals("previousSalt1", salts[0].previousSalt()); - assertNull(salts[1].previousSalt()); - } - - @Test - void rotateSaltsRemovePreviousSaltsOver90DaysOld() throws Exception { - final Duration[] minAges = { - Duration.ofDays(100), - }; - - var exactly90Days = daysEarlier(90).toEpochMilli(); - var over90Days = daysEarlier(100).toEpochMilli(); - var validForRotation = daysEarlier(120); - var lastSnapshot = SnapshotBuilder.start() - .withEntries( - new SaltEntry(1, "1", exactly90Days, "salt1", null, "90DaysOld", null, null), - new SaltEntry(2, "2", over90Days, "salt2", null, "over90DaysOld", null, null) - ) - .withEntries(1, validForRotation) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 0.5, targetDate); - assertTrue(result.hasSnapshot()); - - var salts = result.getSnapshot().getAllRotatingSalts(); - assertNull(salts[0].previousSalt()); - assertNull(salts[1].previousSalt()); - } - - - @Test - void rotateSaltsRotateWhenRefreshFromIsTargetDate() throws Exception { - JsonObject config = new JsonObject(); - config.put(AdminConst.ENABLE_SALT_ROTATION_REFRESH_FROM, Boolean.TRUE); - saltRotation = new SaltRotation(config, keyGenerator); - - final Duration[] minAges = { - Duration.ofDays(90), - Duration.ofDays(60), - }; - - var validForRotation1 = daysEarlier(120).toEpochMilli(); - var validForRotation2 = daysEarlier(70).toEpochMilli(); - var notValidForRotation = daysEarlier(30).toEpochMilli(); - var refreshNow = targetDateAsInstant.toEpochMilli(); - var refreshLater = daysLater(20).toEpochMilli(); - - var lastSnapshot = SnapshotBuilder.start() - .withEntries( - new SaltEntry(1, "1", validForRotation1, "salt", refreshNow, null, null, null), - new SaltEntry(2, "2", notValidForRotation, "salt", refreshNow, null, null, null), - new SaltEntry(3, "3", validForRotation2, "salt", refreshLater, null, null, null) - ) - .build(daysEarlier(1), daysLater(6)); - - var result = saltRotation.rotateSalts(lastSnapshot, minAges, 1, targetDate); - assertTrue(result.hasSnapshot()); - - var salts = result.getSnapshot().getAllRotatingSalts(); - - assertEquals(targetDateAsInstant.toEpochMilli(), salts[0].lastUpdated()); - assertEquals(daysLater(30).toEpochMilli(), salts[0].refreshFrom()); - - assertEquals(notValidForRotation, salts[1].lastUpdated()); - assertEquals(daysLater(30).toEpochMilli(), salts[1].refreshFrom()); - - assertEquals(validForRotation2, salts[2].lastUpdated()); - assertEquals(refreshLater, salts[2].refreshFrom()); - } -}