Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 68 additions & 31 deletions src/main/java/com/uid2/admin/secret/SaltRotation.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import com.uid2.shared.secret.IKeyGenerator;
import com.uid2.shared.store.salt.RotatingSaltProvider;

import java.time.*;
import com.uid2.shared.store.salt.RotatingSaltProvider.SaltSnapshot;

import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.stream.IntStream;
Expand All @@ -30,62 +35,94 @@ public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
return Result.noSnapshot("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot");
}

List<Integer> saltIndexesToRotate = pickSaltIndexesToRotate(lastSnapshot, nextEffective, minAges, fraction);
if (saltIndexesToRotate.isEmpty()) {
return Result.noSnapshot("all salts are below min rotation age");
}

var updatedSalts = updateSalts(lastSnapshot.getAllRotatingSalts(), saltIndexesToRotate, nextEffective.toEpochMilli());

SaltSnapshot nextSnapshot = new SaltSnapshot(
nextEffective,
nextExpires,
updatedSalts,
lastSnapshot.getFirstLevelSalt());
return Result.fromSnapshot(nextSnapshot);
}

private SaltEntry[] updateSalts(SaltEntry[] oldSalts, List<Integer> saltIndexesToRotate, long nextEffective) throws Exception {
var updatedSalts = new SaltEntry[oldSalts.length];

for (int i = 0; i < oldSalts.length; i++) {
var shouldRotate = saltIndexesToRotate.contains(i);
updatedSalts[i] = updateSalt(oldSalts[i], shouldRotate, nextEffective);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just skip the function call if !shouldRotate and set updatedSalts[i] = oldSalts[i]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, but the whole point of this refactoring is to prepare for updating refreshFrom. We're planning to always calculate refreshFrom so we'd need to cycle through every salt.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we'll also need to update the previous from salt in here too

}
return updatedSalts;
}

private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextEffective) throws Exception {
var currentSalt = shouldRotate ? this.keyGenerator.generateRandomKeyString(32) : oldSalt.currentSalt();
var lastUpdated = shouldRotate ? nextEffective : oldSalt.lastUpdated();

return new SaltEntry(
oldSalt.id(),
oldSalt.hashedId(),
lastUpdated,
currentSalt,
null,
null,
null,
null
);
}

private List<Integer> pickSaltIndexesToRotate(
SaltSnapshot lastSnapshot,
Instant nextEffective,
Duration[] minAges,
double fraction) {
final Instant[] thresholds = Arrays.stream(minAges)
.map(a -> nextEffective.minusSeconds(a.getSeconds()))
.map(age -> nextEffective.minusSeconds(age.getSeconds()))
.sorted()
.toArray(Instant[]::new);
final int maxSalts = (int)Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction);
final List<Integer> entryIndexes = new ArrayList<>();
final int maxSalts = (int) Math.ceil(lastSnapshot.getAllRotatingSalts().length * fraction);
final List<Integer> indexesToRotate = new ArrayList<>();

Instant minLastUpdated = Instant.ofEpochMilli(0);
for (Instant threshold : thresholds) {
if (entryIndexes.size() >= maxSalts) break;
addIndexesToRotate(entryIndexes, lastSnapshot,
minLastUpdated.toEpochMilli(), threshold.toEpochMilli(),
maxSalts - entryIndexes.size());
if (indexesToRotate.size() >= maxSalts) break;
addIndexesToRotate(
indexesToRotate,
lastSnapshot,
minLastUpdated.toEpochMilli(),
threshold.toEpochMilli(),
maxSalts - indexesToRotate.size()
);
minLastUpdated = threshold;
}

if (entryIndexes.isEmpty()) return Result.noSnapshot("all salts are below min rotation age");

return Result.fromSnapshot(createRotatedSnapshot(lastSnapshot, nextEffective, nextExpires, entryIndexes));
return indexesToRotate;
}

private void addIndexesToRotate(List<Integer> entryIndexes,
RotatingSaltProvider.SaltSnapshot lastSnapshot,
SaltSnapshot lastSnapshot,
long minLastUpdated,
long maxLastUpdated,
int maxIndexes) {
final SaltEntry[] entries = lastSnapshot.getAllRotatingSalts();
final List<Integer> candidateIndexes = IntStream.range(0, entries.length)
.filter(i -> isBetween(entries[i].lastUpdated(), minLastUpdated, maxLastUpdated))
.boxed().collect(toList());
.boxed()
.collect(toList());
if (candidateIndexes.size() <= maxIndexes) {
entryIndexes.addAll(candidateIndexes);
return;
}
Collections.shuffle(candidateIndexes);
candidateIndexes.stream().limit(maxIndexes).forEachOrdered(i -> entryIndexes.add(i));
candidateIndexes.stream().limit(maxIndexes).forEachOrdered(entryIndexes::add);
}

private static boolean isBetween(long t, long minInclusive, long maxExclusive) {
return minInclusive <= t && t < maxExclusive;
}

private RotatingSaltProvider.SaltSnapshot createRotatedSnapshot(RotatingSaltProvider.SaltSnapshot lastSnapshot,
Instant nextEffective,
Instant nextExpires,
List<Integer> entryIndexes) throws Exception {
final long lastUpdated = nextEffective.toEpochMilli();
final RotatingSaltProvider.SaltSnapshot nextSnapshot = new RotatingSaltProvider.SaltSnapshot(
nextEffective, nextExpires,
Arrays.copyOf(lastSnapshot.getAllRotatingSalts(), lastSnapshot.getAllRotatingSalts().length),
lastSnapshot.getFirstLevelSalt());
for (Integer i : entryIndexes) {
final SaltEntry oldSalt = nextSnapshot.getAllRotatingSalts()[i];
final String secret = this.keyGenerator.generateRandomKeyString(32);
nextSnapshot.getAllRotatingSalts()[i] = new SaltEntry(oldSalt.id(), oldSalt.hashedId(), lastUpdated, secret, null, null, null, null);
}
return nextSnapshot;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import java.time.*;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
Expand Down
8 changes: 3 additions & 5 deletions src/test/java/com/uid2/admin/secret/SaltRotationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@
import static org.mockito.Mockito.*;

public class SaltRotationTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can put this class-level annotation:
@ExtendWith(MockitoExtension.java)

Then you won't need to do MockitoAnnotations.openMocks(this); in setup()

If the test fails because the stubs are too lenient, you can additionally add this class-level annotation:
@MockitoSettings(strictness = Strictness.LENIENT)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require importing some new libraries. I suggest we keep this for a separate PR.

private AutoCloseable mocks;

@Mock private IKeyGenerator keyGenerator;
private SaltRotation saltRotation;

private final LocalDate targetDate = LocalDate.of(2025, 1, 1);
private final Instant targetDateAsInstant = targetDate.atStartOfDay().toInstant(ZoneOffset.UTC);

@BeforeEach
void setup() throws Exception {
mocks = MockitoAnnotations.openMocks(this);
void setup() {
MockitoAnnotations.openMocks(this);

saltRotation = new SaltRotation(keyGenerator);
}
Expand All @@ -51,7 +49,7 @@ public SnapshotBuilder withEntries(int count, Instant lastUpdated) {

public RotatingSaltProvider.SaltSnapshot build(Instant effective, Instant expires) {
return new RotatingSaltProvider.SaltSnapshot(
effective, expires, entries.stream().toArray(SaltEntry[]::new), "test_first_level_salt");
effective, expires, entries.toArray(SaltEntry[]::new), "test_first_level_salt");
}
}

Expand Down