Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/com/uid2/admin/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public void run() {
WriteLock writeLock = new WriteLock();
KeyHasher keyHasher = new KeyHasher();
IKeypairGenerator keypairGenerator = new SecureKeypairGenerator();
ISaltRotation saltRotation = new SaltRotation(config, keyGenerator);
ISaltRotation saltRotation = new SaltRotation(keyGenerator);
EncryptionKeyService encryptionKeyService = new EncryptionKeyService(
config, auth, writeLock, encryptionKeyStoreWriter, keysetKeyStoreWriter, keyProvider, keysetKeysProvider, adminKeysetProvider, adminKeysetStoreWriter, keyGenerator, clock);
KeysetManager keysetManager = new KeysetManager(
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/uid2/admin/secret/ISaltRotation.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import com.uid2.shared.store.salt.RotatingSaltProvider;

import java.time.Duration;
import java.time.LocalDate;

public interface ISaltRotation {
Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
Duration[] minAges,
double fraction) throws Exception;
double fraction,
LocalDate nextEffective) throws Exception;

class Result {
private RotatingSaltProvider.SaltSnapshot snapshot; // can be null if new snapshot is not needed
Expand Down
132 changes: 75 additions & 57 deletions src/main/java/com/uid2/admin/secret/SaltRotation.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,108 +3,126 @@
import com.uid2.shared.model.SaltEntry;
import com.uid2.shared.secret.IKeyGenerator;
import com.uid2.shared.store.salt.RotatingSaltProvider;
import io.vertx.core.json.JsonObject;

import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot;

import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toList;

public class SaltRotation implements ISaltRotation {
private static final String SNAPSHOT_ACTIVATES_IN_SECONDS = "salt_snapshot_activates_in_seconds";
private static final String SNAPSHOT_EXPIRES_AFTER_SECONDS = "salt_snapshot_expires_after_seconds";

private final IKeyGenerator keyGenerator;
private final Duration snapshotActivatesIn;
private final Duration snapshotExpiresAfter;

public static Duration getSnapshotActivatesIn(JsonObject config) {
return Duration.ofSeconds(config.getInteger(SNAPSHOT_ACTIVATES_IN_SECONDS));
}
public static Duration getSnapshotExpiresAfter(JsonObject config) {
return Duration.ofSeconds(config.getInteger(SNAPSHOT_EXPIRES_AFTER_SECONDS));
}

public SaltRotation(JsonObject config, IKeyGenerator keyGenerator) {
public SaltRotation(IKeyGenerator keyGenerator) {
this.keyGenerator = keyGenerator;
}

snapshotActivatesIn = getSnapshotActivatesIn(config);
snapshotExpiresAfter = getSnapshotExpiresAfter(config);
@Override
public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
Duration[] minAges,
double fraction,
LocalDate targetDate) throws Exception {

final Instant nextEffective = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC);
final Instant nextExpires = nextEffective.plus(7, ChronoUnit.DAYS);
if (nextEffective.equals(lastSnapshot.getEffective()) || nextEffective.isBefore(lastSnapshot.getEffective())) {
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot");
}

if (snapshotActivatesIn.compareTo(snapshotExpiresAfter) >= 0) {
throw new IllegalStateException(SNAPSHOT_EXPIRES_AFTER_SECONDS + " must be greater than " + SNAPSHOT_ACTIVATES_IN_SECONDS);
List<Integer> saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction);
if (saltIndexesToRotate.isEmpty()) {
return Result.noSnapshot("all salts are below min rotation age");
}

var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli());

SaltSnapshot nextSnapshot = new SaltSnapshot(
nextEffective,
nextExpires,
updatedSalts,
lastSnapshot.getFirstLevelSalt());
return Result.fromSnapshot(nextSnapshot);
}

@Override
public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
Duration[] minAges,
double fraction) throws Exception {
final Instant now = Instant.now();
final Instant nextEffective = now.plusSeconds(snapshotActivatesIn.getSeconds());
final Instant nextExpires = nextEffective.plusSeconds(snapshotExpiresAfter.getSeconds());
if (!nextEffective.isAfter(lastSnapshot.getEffective())) {
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp prior to that of an existing snapshot");
private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List<Integer> saltIndexesToRotate, long nextEffective) throws Exception {
var updatedSalts = new SaltEntry[oldSalts.length];

for (int i = 0; i < oldSalts.length; i++) {
var shouldRotate = saltIndexesToRotate.contains(i);
updatedSalts[i] = updateSalt(oldSalts[i], shouldRotate, nextEffective);
}
return updatedSalts;
}

private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextEffective) throws Exception {
var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt();
var lastUpdated = shouldRotate ? nextEffective : oldSalt.lastUpdated();

return new SaltEntry(
oldSalt.id(),
oldSalt.hashedId(),
lastUpdated,
currentSalt,
null,
null,
null,
null
);
}

private List<Integer> pickSaltIndexesToRotate(
SaltSnapshot lastSnapshot,
Instant nextEffective,
Duration[] minAges,
double fraction) {
final Instant[] thresholds = Arrays.stream(minAges)
.map(a -> now.minusSeconds(a.getSeconds()))
.map(age -> nextEffective.minusSeconds(age.getSeconds()))
.sorted()
.toArray(Instant[]::new);
final int maxSalts = (int)Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction);
final List<Integer> entryIndexes = new ArrayList<>();
final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction);
final List<Integer> indexesToRotate = new ArrayList<>();

Instant minLastUpdated = Instant.ofEpochMilli(0);
for (Instant threshold : thresholds) {
if (entryIndexes.size() >= maxSalts) break;
addIndexesToRotate(entryIndexes, lastSnapshot,
minLastUpdated.toEpochMilli(), threshold.toEpochMilli(),
maxSalts - entryIndexes.size());
if (indexesToRotate.size() >= maxSalts) break;
addIndexesToRotate(
indexesToRotate,
lastSnapshot,
minLastUpdated.toEpochMilli(),
threshold.toEpochMilli(),
maxSalts - indexesToRotate.size()
);
minLastUpdated = threshold;
}

if (entryIndexes.isEmpty()) return Result.noSnapshot("all salts are below min rotation age");

return Result.fromSnapshot(createRotatedSnapshot(lastSnapshot, nextEffective, nextExpires, entryIndexes));
return indexesToRotate;
}

private void addIndexesToRotate(List<Integer> entryIndexes,
RotatingSaltProvider.SaltSnapshot lastSnapshot,
SaltSnapshot lastSnapshot,
long minLastUpdated,
long maxLastUpdated,
int maxIndexes) {
final SaltEntry[] entries = lastSnapshot.getAllRotatingSalts();
final List<Integer> candidateIndexes = IntStream.range(0, entries.length)
.filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated))
.boxed().collect(toList());
.boxed()
.collect(toList());
if (candidateIndexes.size() <= maxIndexes) {
entryIndexes.addAll(candidateIndexes);
return;
}
Collections.shuffle(candidateIndexes);
candidateIndexes.stream().limit(maxIndexes).forEachOrdered(i -> entryIndexes.add(i));
candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add);
}

private static boolean isBetween(long t, long minInclusive, long maxExclusive) {
return minInclusive <= t && t < maxExclusive;
}

private RotatingSaltProvider.SaltSnapshot createRotatedSnapshot(RotatingSaltProvider.SaltSnapshot lastSnapshot,
Instant nextEffective,
Instant nextExpires,
List<Integer> entryIndexes) throws Exception {
final long lastUpdated = nextEffective.toEpochMilli();
final RotatingSaltProvider.SaltSnapshot nextSnapshot = new RotatingSaltProvider.SaltSnapshot(
nextEffective, nextExpires,
Arrays.copyOf(lastSnapshot.getAllRotatingSalts(), lastSnapshot.getAllRotatingSalts().length),
lastSnapshot.getFirstLevelSalt());
for (Integer i : entryIndexes) {
final SaltEntry oldSalt = nextSnapshot.getAllRotatingSalts()[i];
final String secret = this.keyGenerator.generateRandomKeyString(32);
nextSnapshot.getAllRotatingSalts()[i] = new SaltEntry(oldSalt.id(), oldSalt.hashedId(), lastUpdated, secret, null, null, null, null);
}
return nextSnapshot;
}
}
16 changes: 15 additions & 1 deletion src/main/java/com/uid2/admin/vertx/RequestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import com.uid2.shared.store.ISiteStore;
import io.vertx.ext.web.RoutingContext;

import java.time.Duration;
import java.time.*;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -229,4 +230,17 @@ public static Optional<Double> getDouble(RoutingContext rc, String paramName) {
return Optional.empty();
}
}

public static Optional<LocalDate> getDate(RoutingContext rc, String paramName, DateTimeFormatter formatter) {
final List<String> values = rc.queryParam(paramName);
if (values.isEmpty()) {
return Optional.empty();
}
try {
return Optional.of(LocalDate.parse(values.get(0), formatter));
} catch (Exception ex) {
ResponseUtil.error(rc, 400, "failed to parse " + paramName + ": " + ex.getMessage());
return Optional.empty();
}
}
}
8 changes: 6 additions & 2 deletions src/main/java/com/uid2/admin/vertx/service/SaltService.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;

import java.time.Duration;
import java.time.*;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -76,6 +77,8 @@ private void handleSaltRotate(RoutingContext rc) {
if (!fraction.isPresent()) return;
final Duration[] minAges = RequestUtil.getDurations(rc, "min_ages_in_seconds");
if (minAges == null) return;
final LocalDate targetDate = RequestUtil.getDate(rc, "target_date", DateTimeFormatter.ISO_LOCAL_DATE)
.orElse(LocalDate.now(Clock.systemUTC()).plusDays(1));

// force refresh
this.saltProvider.loadContent();
Expand All @@ -85,8 +88,9 @@ private void handleSaltRotate(RoutingContext rc) {

final List<RotatingSaltProvider.SaltSnapshot> snapshots = this.saltProvider.getSnapshots();
final RotatingSaltProvider.SaltSnapshot lastSnapshot = snapshots.get(snapshots.size() - 1);

final ISaltRotation.Result result = saltRotation.rotateSalts(
lastSnapshot, minAges, fraction.get());
lastSnapshot, minAges, fraction.get(), targetDate);
if (!result.hasSnapshot()) {
ResponseUtil.error(rc, 200, result.getReason());
return;
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/localstack/s3/core/salts/metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"salts" : [
{
"effective" : 1670796729291,
"expires" : 1745907348982,
"expires" : 1766125493000,
"location" : "salts/salts.txt.1670796729291",
"size" : 2
},{
Expand Down
Loading