1717import java .time .format .DateTimeFormatter ;
1818import java .time .temporal .ChronoUnit ;
1919import java .util .*;
20+ import java .util .stream .Collectors ;
2021
2122public class SaltRotation {
2223 private final static long THIRTY_DAYS_IN_MS = Duration .ofDays (30 ).toMillis ();
23- private final static long DAY_IN_MS = Duration .ofDays (1 ).toMillis ();
2424
2525 private final IKeyGenerator keyGenerator ;
2626 private final boolean isRefreshFromEnabled ;
@@ -35,38 +35,34 @@ public Result rotateSalts(
3535 SaltSnapshot lastSnapshot ,
3636 Duration [] minAges ,
3737 double fraction ,
38- LocalDate targetDate
38+ TargetDate targetDate
3939 ) throws Exception {
4040 var preRotationSalts = lastSnapshot .getAllRotatingSalts ();
41- var nextEffective = targetDate .atStartOfDay (). toInstant ( ZoneOffset . UTC );
41+ var nextEffective = targetDate .asInstant ( );
4242 var nextExpires = nextEffective .plus (7 , ChronoUnit .DAYS );
4343 if (nextEffective .equals (lastSnapshot .getEffective ()) || nextEffective .isBefore (lastSnapshot .getEffective ())) {
4444 return Result .noSnapshot ("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot" );
4545 }
4646
47- var rotatableSaltIndexes = findRotatableSaltIndexes (preRotationSalts , nextEffective .toEpochMilli ());
48- var saltIndexesToRotate = pickSaltIndexesToRotate (
49- nextEffective ,
47+ // Salts that can be rotated based on their refreshFrom being at target date
48+ var rotatableSalts = findRotatableSalts (preRotationSalts , targetDate );
49+
50+ var saltsToRotate = pickSaltsToRotate (
51+ rotatableSalts ,
52+ targetDate ,
5053 minAges ,
51- fraction ,
52- preRotationSalts ,
53- rotatableSaltIndexes
54+ getNumSaltsToRotate (preRotationSalts , fraction )
5455 );
5556
56- if (saltIndexesToRotate .isEmpty ()) {
57+ if (saltsToRotate .isEmpty ()) {
5758 return Result .noSnapshot ("all rotatable salts are below min rotation age" );
5859 }
5960
60- var postRotationSalts = updateSalts (preRotationSalts , saltIndexesToRotate , nextEffective . toEpochMilli () );
61+ var postRotationSalts = rotateSalts (preRotationSalts , saltsToRotate , targetDate );
6162
62- logSaltCounts (
63- targetDate ,
64- nextEffective ,
65- preRotationSalts ,
66- postRotationSalts ,
67- rotatableSaltIndexes ,
68- saltIndexesToRotate
69- );
63+ logSaltAgeCounts ("rotatable-salts" , targetDate , rotatableSalts );
64+ logSaltAgeCounts ("rotated-salts" , targetDate , saltsToRotate );
65+ logSaltAgeCounts ("total-salts" , targetDate , Arrays .asList (postRotationSalts ));
7066
7167 var nextSnapshot = new SaltSnapshot (
7268 nextEffective ,
@@ -76,31 +72,38 @@ public Result rotateSalts(
7672 return Result .fromSnapshot (nextSnapshot );
7773 }
7874
79- private List <Integer > findRotatableSaltIndexes (SaltEntry [] preRotationSalts , long nextEffective ) {
80- var rotatableSalts = new ArrayList <Integer >();
81- for (int i = 0 ; i < preRotationSalts .length ; i ++) {
82- if (isRotatable (nextEffective , preRotationSalts [i ])) {
83- rotatableSalts .add (i );
84- }
75+ private static int getNumSaltsToRotate (SaltEntry [] preRotationSalts , double fraction ) {
76+ return (int ) Math .ceil (preRotationSalts .length * fraction );
77+ }
78+
79+ private Set <SaltEntry > findRotatableSalts (SaltEntry [] preRotationSalts , TargetDate targetDate ) {
80+ return Arrays .stream (preRotationSalts ).filter (s -> isRotatable (targetDate , s )).collect (Collectors .toSet ());
81+ }
82+
83+ private boolean isRotatable (TargetDate targetDate , SaltEntry salt ) {
84+ if (this .isRefreshFromEnabled ) {
85+ return salt .refreshFrom ().equals (targetDate .asEpochMs ());
8586 }
86- return rotatableSalts ;
87+
88+ return true ;
8789 }
8890
89- private SaltEntry [] updateSalts (SaltEntry [] oldSalts , List <Integer > saltIndexesToRotate , long nextEffective ) throws Exception {
90- var updatedSalts = new SaltEntry [ oldSalts . length ] ;
91+ private SaltEntry [] rotateSalts (SaltEntry [] oldSalts , List <SaltEntry > saltsToRotate , TargetDate targetDate ) throws Exception {
92+ var saltIdsToRotate = saltsToRotate . stream (). map ( SaltEntry :: id ). collect ( Collectors . toSet ()) ;
9193
94+ var updatedSalts = new SaltEntry [oldSalts .length ];
9295 for (int i = 0 ; i < oldSalts .length ; i ++) {
93- var shouldRotate = saltIndexesToRotate .contains (i );
94- updatedSalts [i ] = updateSalt (oldSalts [i ], shouldRotate , nextEffective );
96+ var shouldRotate = saltIdsToRotate .contains (oldSalts [ i ]. id () );
97+ updatedSalts [i ] = updateSalt (oldSalts [i ], targetDate , shouldRotate );
9598 }
9699 return updatedSalts ;
97100 }
98101
99- private SaltEntry updateSalt (SaltEntry oldSalt , boolean shouldRotate , long nextEffective ) throws Exception {
102+ private SaltEntry updateSalt (SaltEntry oldSalt , TargetDate targetDate , boolean shouldRotate ) throws Exception {
100103 var currentSalt = shouldRotate ? this .keyGenerator .generateRandomKeyString (32 ) : oldSalt .currentSalt ();
101- var lastUpdated = shouldRotate ? nextEffective : oldSalt .lastUpdated ();
102- var refreshFrom = calculateRefreshFrom (oldSalt . lastUpdated (), nextEffective );
103- var previousSalt = calculatePreviousSalt (oldSalt , shouldRotate , nextEffective );
104+ var lastUpdated = shouldRotate ? targetDate . asEpochMs () : oldSalt .lastUpdated ();
105+ var refreshFrom = calculateRefreshFrom (oldSalt , targetDate );
106+ var previousSalt = calculatePreviousSalt (oldSalt , shouldRotate , targetDate );
104107
105108 return new SaltEntry (
106109 oldSalt .id (),
@@ -114,132 +117,133 @@ private SaltEntry updateSalt(SaltEntry oldSalt, boolean shouldRotate, long nextE
114117 );
115118 }
116119
117- private long calculateRefreshFrom (long lastUpdated , long nextEffective ) {
118- long age = nextEffective - lastUpdated ;
119- long multiplier = age / THIRTY_DAYS_IN_MS + 1 ;
120- return lastUpdated + (multiplier * THIRTY_DAYS_IN_MS );
120+ private long calculateRefreshFrom (SaltEntry salt , TargetDate targetDate ) {
121+ long multiplier = targetDate .ageOfSaltInMs (salt ) / THIRTY_DAYS_IN_MS + 1 ;
122+ return salt .lastUpdated () + (multiplier * THIRTY_DAYS_IN_MS );
121123 }
122124
123- private String calculatePreviousSalt (SaltEntry salt , boolean shouldRotate , long nextEffective ) throws Exception {
125+ private String calculatePreviousSalt (SaltEntry salt , boolean shouldRotate , TargetDate targetDate ) {
124126 if (shouldRotate ) {
125127 return salt .currentSalt ();
126128 }
127- long age = nextEffective - salt .lastUpdated ();
128- if (age / DAY_IN_MS < 90 ) {
129+ if (targetDate .ageOfSaltInDays (salt ) < 90 ) {
129130 return salt .previousSalt ();
130131 }
131132 return null ;
132133 }
133134
134- private List <Integer > pickSaltIndexesToRotate (
135- Instant nextEffective ,
135+ private List <SaltEntry > pickSaltsToRotate (
136+ Set <SaltEntry > rotatableSalts ,
137+ TargetDate targetDate ,
136138 Duration [] minAges ,
137- double fraction ,
138- SaltEntry [] preRotationSalts ,
139- List <Integer > rotatableSaltIndexes
139+ int numSaltsToRotate
140140 ) {
141141 var thresholds = Arrays .stream (minAges )
142- .map (age -> nextEffective . minusSeconds (age .getSeconds ()))
142+ .map (minAge -> targetDate . asInstant (). minusSeconds (minAge .getSeconds ()))
143143 .sorted ()
144144 .toArray (Instant []::new );
145- var maxSalts = (int ) Math .ceil (preRotationSalts .length * fraction );
146- var indexesToRotate = new ArrayList <Integer >();
145+ var indexesToRotate = new ArrayList <SaltEntry >();
147146
148147 var minLastUpdated = Instant .ofEpochMilli (0 );
149148 for (var maxLastUpdated : thresholds ) {
150- if (indexesToRotate .size () >= maxSalts ) break ;
149+ if (indexesToRotate .size () >= numSaltsToRotate ) break ;
151150
152- var maxIndexes = maxSalts - indexesToRotate .size ();
153- var saltsToRotate = selectIndexesToRotate (
154- preRotationSalts ,
155- minLastUpdated .toEpochMilli (),
156- maxLastUpdated .toEpochMilli (),
151+ var maxIndexes = numSaltsToRotate - indexesToRotate .size ();
152+ var saltsToRotate = pickSaltsToRotateInTimeWindow (
153+ rotatableSalts ,
157154 maxIndexes ,
158- rotatableSaltIndexes
155+ minLastUpdated .toEpochMilli (),
156+ maxLastUpdated .toEpochMilli ()
159157 );
160158 indexesToRotate .addAll (saltsToRotate );
161159 minLastUpdated = maxLastUpdated ;
162160 }
163161 return indexesToRotate ;
164162 }
165163
166- private List <Integer > selectIndexesToRotate (
167- SaltEntry [] salts ,
168- long minLastUpdated ,
169- long maxLastUpdated ,
164+ private List <SaltEntry > pickSaltsToRotateInTimeWindow (
165+ Set <SaltEntry > rotatableSalts ,
170166 int maxIndexes ,
171- List <Integer > rotatableSaltIndexes
172- ) {
173- var candidateIndexes = indexesForRotation (salts , minLastUpdated , maxLastUpdated , rotatableSaltIndexes );
174-
175- if (candidateIndexes .size () <= maxIndexes ) {
176- return candidateIndexes ;
177- }
178- Collections .shuffle (candidateIndexes );
179- return candidateIndexes .subList (0 , Math .min (maxIndexes , candidateIndexes .size ()));
180- }
181-
182- private List <Integer > indexesForRotation (
183- SaltEntry [] salts ,
184167 long minLastUpdated ,
185- long maxLastUpdated ,
186- List <Integer > rotatableSaltIndexes
168+ long maxLastUpdated
187169 ) {
188- var candidateIndexes = new ArrayList <Integer >();
189- for (int i = 0 ; i < salts .length ; i ++) {
190- var salt = salts [i ];
170+ var candidateSalts = new ArrayList <SaltEntry >();
171+ for (SaltEntry salt : rotatableSalts ) {
191172 var lastUpdated = salt .lastUpdated ();
192173 var isInTimeWindow = minLastUpdated <= lastUpdated && lastUpdated < maxLastUpdated ;
193- var isRotatable = rotatableSaltIndexes .contains (i );
194174
195- if (isInTimeWindow && isRotatable ) {
196- candidateIndexes .add (i );
175+ if (isInTimeWindow ) {
176+ candidateSalts .add (salt );
197177 }
198178 }
199- return candidateIndexes ;
200- }
201179
202- private boolean isRotatable (long nextEffective , SaltEntry salt ) {
203- if (this .isRefreshFromEnabled ) {
204- if (salt .refreshFrom () == null ) { // TODO: remove once refreshFrom is no longer optional
205- return true ;
206- }
207- return salt .refreshFrom () == nextEffective ;
180+ if (candidateSalts .size () <= maxIndexes ) {
181+ return candidateSalts ;
208182 }
209183
210- return true ;
184+ Collections .shuffle (candidateSalts );
185+ return candidateSalts .subList (0 , Math .min (maxIndexes , candidateSalts .size ()));
211186 }
212187
213- private void logSaltAgeCounts (String logEvent , LocalDate targetDate , Instant nextEffective , SaltEntry [] salts ) {
214- var formattedDate = DateTimeFormatter .ofPattern ("yyyy-MM-dd" ).format (targetDate );
215-
188+ private void logSaltAgeCounts (String saltCountType , TargetDate targetDate , Collection <SaltEntry > salts ) {
216189 var ages = new HashMap <Long , Long >(); // salt age to count
217190 for (var salt : salts ) {
218- long age = ( nextEffective . toEpochMilli () - salt . lastUpdated ()) / DAY_IN_MS ;
219- ages .put (age , ages .getOrDefault (age , 0L ) + 1 );
191+ long ageInDays = targetDate . ageOfSaltInDays ( salt ) ;
192+ ages .put (ageInDays , ages .getOrDefault (ageInDays , 0L ) + 1 );
220193 }
221194
222195 for (var entry : ages .entrySet ()) {
223- LOGGER .info ("{} target-date={} age={} salts ={}" , logEvent , formattedDate , entry .getKey (), entry .getValue ());
196+ LOGGER .info ("salt-count-type= {} target-date={} age={} salt-count ={}" , saltCountType , targetDate , entry .getKey (), entry .getValue ());
224197 }
225198 }
226199
227- private static SaltEntry [] onlySaltsAtIndexes (SaltEntry [] salts , List <Integer > saltIndexes ) {
228- SaltEntry [] selected = new SaltEntry [saltIndexes .size ()];
229- for (int i = 0 ; i < saltIndexes .size (); i ++) {
230- selected [i ] = salts [saltIndexes .get (i )];
200+ public static class TargetDate {
201+ private final static long DAY_IN_MS = Duration .ofDays (1 ).toMillis ();
202+
203+ private final LocalDate date ;
204+ private final long epochMs ;
205+ private final Instant instant ;
206+ private final String formatted ;
207+
208+ public TargetDate (LocalDate date ) {
209+ this .instant = date .atStartOfDay ().toInstant (ZoneOffset .UTC );
210+ this .date = date ;
211+ this .epochMs = instant .toEpochMilli ();
212+ this .formatted = date .format (DateTimeFormatter .ofPattern ("yyyy-MM-dd" ));
231213 }
232- return selected ;
233- }
234214
235- private void logSaltCounts ( LocalDate targetDate , Instant nextEffective , SaltEntry [] preRotationSalts , SaltEntry [] postRotationSalts , List < Integer > rotatableSaltIndexes , List < Integer > rotatedSaltIndexes ) {
236- var rotatableSalts = onlySaltsAtIndexes ( preRotationSalts , rotatableSaltIndexes );
237- logSaltAgeCounts ( "rotatable-salts" , targetDate , nextEffective , rotatableSalts );
215+ public static TargetDate of ( int year , int month , int day ) {
216+ return new TargetDate ( LocalDate . of ( year , month , day ) );
217+ }
238218
239- var rotatedSalts = onlySaltsAtIndexes (preRotationSalts , rotatedSaltIndexes );
240- logSaltAgeCounts ("rotated-salts" , targetDate , nextEffective , rotatedSalts );
219+ public long asEpochMs () {
220+ return epochMs ;
221+ }
241222
242- logSaltAgeCounts ("total-salts" , targetDate , nextEffective , postRotationSalts );
223+ public Instant asInstant () {
224+ return instant ;
225+ }
226+
227+ public long ageOfSaltInMs (SaltEntry salt ) {
228+ return this .asEpochMs () - salt .lastUpdated ();
229+ }
230+
231+ public long ageOfSaltInDays (SaltEntry salt ) {
232+ return ageOfSaltInMs (salt ) / DAY_IN_MS ;
233+ }
234+
235+ public TargetDate plusDays (int days ) {
236+ return new TargetDate (date .plusDays (days ));
237+ }
238+
239+ public TargetDate minusDays (int days ) {
240+ return new TargetDate (date .minusDays (days ));
241+ }
242+
243+ @ Override
244+ public String toString () {
245+ return formatted ;
246+ }
243247 }
244248
245249 @ Getter
0 commit comments