44import com .uid2 .shared .secret .IKeyGenerator ;
55import com .uid2 .shared .store .salt .RotatingSaltProvider ;
66
7- import java .time .*;
7+ import com .uid2 .shared .store .salt .RotatingSaltProvider .SaltSnapshot ;
8+
9+ import java .time .Duration ;
10+ import java .time .Instant ;
11+ import java .time .LocalDate ;
12+ import java .time .ZoneOffset ;
813import java .time .temporal .ChronoUnit ;
914import java .util .*;
1015import java .util .stream .IntStream ;
@@ -30,62 +35,94 @@ public Result rotateSalts(RotatingSaltProvider.SaltSnapshot lastSnapshot,
3035 return Result .noSnapshot ("cannot create a new salt snapshot with effective timestamp equal or prior to that of an existing snapshot" );
3136 }
3237
38+ List <Integer > saltIndexesToRotate = pickSaltIndexesToRotate (lastSnapshot , nextEffective , minAges , fraction );
39+ if (saltIndexesToRotate .isEmpty ()) {
40+ return Result .noSnapshot ("all salts are below min rotation age" );
41+ }
42+
43+ var updatedSalts = updateSalts (lastSnapshot .getAllRotatingSalts (), saltIndexesToRotate , nextEffective .toEpochMilli ());
44+
45+ SaltSnapshot nextSnapshot = new SaltSnapshot (
46+ nextEffective ,
47+ nextExpires ,
48+ updatedSalts ,
49+ lastSnapshot .getFirstLevelSalt ());
50+ return Result .fromSnapshot (nextSnapshot );
51+ }
52+
53+ private SaltEntry [] updateSalts (SaltEntry [] oldSalts , List <Integer > saltIndexesToRotate , long nextEffective ) throws Exception {
54+ var updatedSalts = new SaltEntry [oldSalts .length ];
55+
56+ for (int i = 0 ; i < oldSalts .length ; i ++) {
57+ var shouldRotate = saltIndexesToRotate .contains (i );
58+ updatedSalts [i ] = updateSalt (oldSalts [i ], shouldRotate , nextEffective );
59+ }
60+ return updatedSalts ;
61+ }
62+
63+ private SaltEntry updateSalt (SaltEntry oldSalt , boolean shouldRotate , long nextEffective ) throws Exception {
64+ var currentSalt = shouldRotate ? this .keyGenerator .generateRandomKeyString (32 ) : oldSalt .currentSalt ();
65+ var lastUpdated = shouldRotate ? nextEffective : oldSalt .lastUpdated ();
66+
67+ return new SaltEntry (
68+ oldSalt .id (),
69+ oldSalt .hashedId (),
70+ lastUpdated ,
71+ currentSalt ,
72+ null ,
73+ null ,
74+ null ,
75+ null
76+ );
77+ }
78+
79+ private List <Integer > pickSaltIndexesToRotate (
80+ SaltSnapshot lastSnapshot ,
81+ Instant nextEffective ,
82+ Duration [] minAges ,
83+ double fraction ) {
3384 final Instant [] thresholds = Arrays .stream (minAges )
34- .map (a -> nextEffective .minusSeconds (a .getSeconds ()))
85+ .map (age -> nextEffective .minusSeconds (age .getSeconds ()))
3586 .sorted ()
3687 .toArray (Instant []::new );
37- final int maxSalts = (int )Math .ceil (lastSnapshot .getAllRotatingSalts ().length * fraction );
38- final List <Integer > entryIndexes = new ArrayList <>();
88+ final int maxSalts = (int ) Math .ceil (lastSnapshot .getAllRotatingSalts ().length * fraction );
89+ final List <Integer > indexesToRotate = new ArrayList <>();
3990
4091 Instant minLastUpdated = Instant .ofEpochMilli (0 );
4192 for (Instant threshold : thresholds ) {
42- if (entryIndexes .size () >= maxSalts ) break ;
43- addIndexesToRotate (entryIndexes , lastSnapshot ,
44- minLastUpdated .toEpochMilli (), threshold .toEpochMilli (),
45- maxSalts - entryIndexes .size ());
93+ if (indexesToRotate .size () >= maxSalts ) break ;
94+ addIndexesToRotate (
95+ indexesToRotate ,
96+ lastSnapshot ,
97+ minLastUpdated .toEpochMilli (),
98+ threshold .toEpochMilli (),
99+ maxSalts - indexesToRotate .size ()
100+ );
46101 minLastUpdated = threshold ;
47102 }
48-
49- if (entryIndexes .isEmpty ()) return Result .noSnapshot ("all salts are below min rotation age" );
50-
51- return Result .fromSnapshot (createRotatedSnapshot (lastSnapshot , nextEffective , nextExpires , entryIndexes ));
103+ return indexesToRotate ;
52104 }
53105
54106 private void addIndexesToRotate (List <Integer > entryIndexes ,
55- RotatingSaltProvider . SaltSnapshot lastSnapshot ,
107+ SaltSnapshot lastSnapshot ,
56108 long minLastUpdated ,
57109 long maxLastUpdated ,
58110 int maxIndexes ) {
59111 final SaltEntry [] entries = lastSnapshot .getAllRotatingSalts ();
60112 final List <Integer > candidateIndexes = IntStream .range (0 , entries .length )
61113 .filter (i -> isBetween (entries [i ].lastUpdated (), minLastUpdated , maxLastUpdated ))
62- .boxed ().collect (toList ());
114+ .boxed ()
115+ .collect (toList ());
63116 if (candidateIndexes .size () <= maxIndexes ) {
64117 entryIndexes .addAll (candidateIndexes );
65118 return ;
66119 }
67120 Collections .shuffle (candidateIndexes );
68- candidateIndexes .stream ().limit (maxIndexes ).forEachOrdered (i -> entryIndexes . add ( i ) );
121+ candidateIndexes .stream ().limit (maxIndexes ).forEachOrdered (entryIndexes :: add );
69122 }
70123
71124 private static boolean isBetween (long t , long minInclusive , long maxExclusive ) {
72125 return minInclusive <= t && t < maxExclusive ;
73126 }
74127
75- private RotatingSaltProvider .SaltSnapshot createRotatedSnapshot (RotatingSaltProvider .SaltSnapshot lastSnapshot ,
76- Instant nextEffective ,
77- Instant nextExpires ,
78- List <Integer > entryIndexes ) throws Exception {
79- final long lastUpdated = nextEffective .toEpochMilli ();
80- final RotatingSaltProvider .SaltSnapshot nextSnapshot = new RotatingSaltProvider .SaltSnapshot (
81- nextEffective , nextExpires ,
82- Arrays .copyOf (lastSnapshot .getAllRotatingSalts (), lastSnapshot .getAllRotatingSalts ().length ),
83- lastSnapshot .getFirstLevelSalt ());
84- for (Integer i : entryIndexes ) {
85- final SaltEntry oldSalt = nextSnapshot .getAllRotatingSalts ()[i ];
86- final String secret = this .keyGenerator .generateRandomKeyString (32 );
87- nextSnapshot .getAllRotatingSalts ()[i ] = new SaltEntry (oldSalt .id (), oldSalt .hashedId (), lastUpdated , secret , null , null , null , null );
88- }
89- return nextSnapshot ;
90- }
91128}
0 commit comments