Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/main/java/com/uid2/admin/auth/AdminAuthMiddleware.java
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
package com.uid2.admin.auth;


import com.okta.jwt.*;
import com.uid2.admin.AdminConst;
import com.uid2.shared.auth.Role;
import io.vertx.core.Handler;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.uid2.shared.audit.AuditParams;
import com.uid2.shared.audit.Audit;

import java.util.*;

public class AdminAuthMiddleware {
private static final Logger LOGGER = LoggerFactory.getLogger(AdminAuthMiddleware.class);
private final Map<Role, List<OktaGroup>> roleToOktaGroups = new EnumMap<>(Role.class);
private final AuthProvider authProvider;
private final String environment;
private final boolean isAuthDisabled;
private final Audit audit;

final Map<Role, List<OktaGroup>> roleToOktaGroups = new EnumMap<>(Role.class);
public AdminAuthMiddleware(AuthProvider authProvider, JsonObject config) {
this.authProvider = authProvider;
this.environment = config.getString("environment", "local");
Expand Down Expand Up @@ -59,7 +55,6 @@ public Handler<RoutingContext> handle(Handler<RoutingContext> handler, Role... r
return this.handle(handler, new AuditParams(), roles);
}


private Handler<RoutingContext> logAndHandle(Handler<RoutingContext> handler, AuditParams params) {
return ctx -> {
ctx.addBodyEndHandler(v -> this.audit.log(ctx, params));
Expand All @@ -73,6 +68,7 @@ private static class AdminAuthHandler {
private final Set<Role> allowedRoles;
private final Map<Role, List<OktaGroup>> roleToOktaGroups;
private final AuthProvider authProvider;

private AdminAuthHandler(Handler<RoutingContext> handler, AuthProvider authProvider, Set<Role> allowedRoles,
String environment, Map<Role, List<OktaGroup>> roleToOktaGroups) {
this.environment = environment;
Expand All @@ -96,6 +92,7 @@ public static String extractBearerToken(String headerValue) {
}
}
}

private boolean isAuthorizedUser(List<String> userAssignedGroups) {
for (Role role : allowedRoles) {
if (roleToOktaGroups.containsKey(role)) {
Expand All @@ -109,6 +106,7 @@ private boolean isAuthorizedUser(List<String> userAssignedGroups) {
}
return false;
}

private boolean isAuthorizedService(List<String> scopes) {
for (String scope : scopes) {
if (allowedRoles.contains(OktaCustomScope.fromName(scope).getRole())) {
Expand All @@ -117,21 +115,22 @@ private boolean isAuthorizedService(List<String> scopes) {
}
return false;
}

public void handle(RoutingContext rc) {
// human user
String idToken = null;
if(rc.user() != null && rc.user().principal() != null) {
if (rc.user() != null && rc.user().principal() != null) {
idToken = rc.user().principal().getString("id_token");
}
if(idToken != null) {
if (idToken != null) {
validateIdToken(rc, idToken);
return;
}

// machine user
String authHeaderValue = rc.request().getHeader("Authorization");
String accessToken = extractBearerToken(authHeaderValue);
if(accessToken == null) {
if (accessToken == null) {
rc.response().putHeader("REQUIRES_AUTH", "1").setStatusCode(401).end();
return;
}
Expand All @@ -146,7 +145,7 @@ private void validateAccessToken(RoutingContext rc, String accessToken) {
rc.response().setStatusCode(401).end();
return;
}
if(jwt.getClaims().get("environment") == null || !jwt.getClaims().get("environment").toString().equals(environment)) {
if (jwt.getClaims().get("environment") == null || !jwt.getClaims().get("environment").toString().equals(environment)) {
rc.response().setStatusCode(401).end();
return;
}
Expand All @@ -155,7 +154,7 @@ private void validateAccessToken(RoutingContext rc, String accessToken) {
serviceAccountDetails.put("scope", scopes);
serviceAccountDetails.put("client_id", jwt.getClaims().get("client_id"));
rc.put("user_details", serviceAccountDetails);
if(isAuthorizedService(scopes)) {
if (isAuthorizedService(scopes)) {
innerHandler.handle(rc);
} else {
rc.response().setStatusCode(401).end();
Expand All @@ -171,7 +170,7 @@ private void validateIdToken(RoutingContext rc, String idToken) {
rc.response().putHeader("REQUIRES_AUTH", "1").setStatusCode(401).end();
return;
}
if(jwt.getClaims().get("environment") == null || !jwt.getClaims().get("environment").toString().equals(environment)) {
if (jwt.getClaims().get("environment") == null || !jwt.getClaims().get("environment").toString().equals(environment)) {
rc.response().setStatusCode(401).end();
return;
}
Expand All @@ -181,7 +180,7 @@ private void validateIdToken(RoutingContext rc, String idToken) {
userDetails.put("email", jwt.getClaims().get("email"));
userDetails.put("sub", jwt.getClaims().get("sub"));
rc.put("user_details", userDetails);
if(isAuthorizedUser(groups)) {
if (isAuthorizedUser(groups)) {
innerHandler.handle(rc);
} else {
rc.response().setStatusCode(401).end();
Expand Down
14 changes: 6 additions & 8 deletions src/main/java/com/uid2/admin/salt/SaltRotation.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
public class SaltRotation {
private static final long THIRTY_DAYS_IN_MS = Duration.ofDays(30).toMillis();
private static final double MAX_SALT_PERCENTAGE = 0.8;
private final boolean ENABLE_V4_RAW_UID;
private final boolean enableV4RawUid;

private final IKeyGenerator keyGenerator;

private static final Logger LOGGER = LoggerFactory.getLogger(SaltRotation.class);

public SaltRotation(IKeyGenerator keyGenerator, JsonObject config) {
this.keyGenerator = keyGenerator;
this.ENABLE_V4_RAW_UID = config.getBoolean(AdminConst.ENABLE_V4_RAW_UID, false);
this.enableV4RawUid = config.getBoolean(AdminConst.ENABLE_V4_RAW_UID, false);
}

public Result rotateSalts(
Expand Down Expand Up @@ -143,10 +143,9 @@ private long calculateRefreshFrom(SaltEntry bucket, TargetDate targetDate) {

private String calculateCurrentSalt(SaltEntry bucket, boolean shouldRotate) throws Exception {
if (shouldRotate) {
if (ENABLE_V4_RAW_UID) {
if (enableV4RawUid) {
return null;
}
else {
} else {
return this.keyGenerator.generateRandomKeyString(32);
}
}
Expand All @@ -165,10 +164,10 @@ private String calculatePreviousSalt(SaltEntry bucket, boolean shouldRotate, Tar

private SaltEntry.KeyMaterial calculateCurrentKeySalt(SaltEntry bucket, boolean shouldRotate, KeyIdGenerator keyIdGenerator) throws Exception {
if (shouldRotate) {
if (ENABLE_V4_RAW_UID) {
if (enableV4RawUid) {
return new SaltEntry.KeyMaterial(
keyIdGenerator.getNextKeyId(),
this.keyGenerator.generateRandomKeyString(32),
this.keyGenerator.generateRandomKeyString(24),
this.keyGenerator.generateRandomKeyString(32)
);
} else {
Expand Down Expand Up @@ -253,7 +252,6 @@ private void logSaltAges(String saltCountType, TargetDate targetDate, Collection
}
}


/** Logging to monitor migration of buckets from salts (old format - v2/v3) to encryption keys (new format - v4) **/
private void logBucketFormatCount(TargetDate targetDate, SaltEntry[] postRotationBuckets) {
int totalKeys = 0, totalSalts = 0, totalPreviousKeys = 0, totalPreviousSalts = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/uid2/admin/vertx/service/SaltService.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ public void setupRoutes(Router router) {
}
}, new AuditParams(List.of(), Collections.emptyList()), Role.MAINTAINER));

router.post(API_SALT_ROTATE.toString()).blockingHandler(auth.handle((ctx) -> {
router.post(API_SALT_ROTATE.toString()).blockingHandler(auth.handle(ctx -> {
synchronized (writeLock) {
this.handleSaltRotate(ctx);
}
}, new AuditParams(List.of("fraction", "min_ages_in_seconds", "target_date"), Collections.emptyList()), Role.SUPER_USER, Role.SECRET_ROTATION));
}, new AuditParams(List.of("fraction", "target_date"), Collections.emptyList()), Role.SUPER_USER, Role.SECRET_ROTATION));
}

private void handleSaltSnapshots(RoutingContext rc) {
Expand Down
29 changes: 16 additions & 13 deletions src/test/java/com/uid2/admin/salt/SaltServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

public class SaltServiceTest extends ServiceTestBase {
class SaltServiceTest extends ServiceTestBase {
private final TargetDate utcTomorrow = TargetDate.now().plusDays(1);
@Mock RotatingSaltProvider saltProvider;
@Mock SaltRotation saltRotation;

@Mock
private RotatingSaltProvider saltProvider;
@Mock
private SaltRotation saltRotation;

@Override
protected IService createService() {
Expand Down Expand Up @@ -77,7 +80,7 @@ void rotateSalts(Vertx vertx, VertxTestContext testContext) throws Exception {
var result = SaltRotation.Result.fromSnapshot(addedSnapshots[0].build());
when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(utcTomorrow))).thenReturn(result);

post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2", "", response -> {
post(vertx, testContext, "api/salt/rotate?fraction=0.2", "", response -> {
assertEquals(200, response.statusCode());
checkSnapshotsResponse(addedSnapshots, new Object[]{response.bodyAsJsonObject()});
verify(saltStoreWriter).upload(any());
Expand All @@ -100,7 +103,7 @@ void rotateSaltsNoNewSnapshot(Vertx vertx, VertxTestContext testContext) throws
var result = SaltRotation.Result.noSnapshot("test");
when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(utcTomorrow))).thenReturn(result);

post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2", "", response -> {
post(vertx, testContext, "api/salt/rotate?fraction=0.2", "", response -> {
assertEquals(200, response.statusCode());
JsonObject jo = response.bodyAsJsonObject();
assertFalse(jo.containsKey("effective"));
Expand All @@ -127,31 +130,31 @@ void rotateSaltsWithSpecificTargetDate(Vertx vertx, VertxTestContext testContext
var result = SaltRotation.Result.fromSnapshot(addedSnapshots[0].build());
when(saltRotation.rotateSalts(any(), any(), eq(0.2), eq(targetDate()))).thenReturn(result);

post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2&target_date=2025-01-01", "", response -> {
post(vertx, testContext, "api/salt/rotate?fraction=0.2&target_date=2025-01-01", "", response -> {
assertEquals(200, response.statusCode());
testContext.completeNow();
});
}

@Test
void rotateSaltsWithDefaultAgeThresholds(Vertx vertx, VertxTestContext testContext) throws Exception {
fakeAuth(Role.SUPER_USER);
fakeAuth(Role.SUPER_USER);

final SaltSnapshotBuilder lastSnapshot = SaltSnapshotBuilder.start().effective(daysEarlier(1)).expires(daysLater(6)).entries(1, daysEarlier(1));
setSnapshots(lastSnapshot);

var result = SaltRotation.Result.fromSnapshot(SaltSnapshotBuilder.start().effective(targetDate()).expires(daysEarlier(7)).entries(1, targetDate()).build());

Duration[] expectedDefaultAgeThresholds = new Duration[]{
Duration.ofDays(30), Duration.ofDays(60), Duration.ofDays(90), Duration.ofDays(120),
Duration.ofDays(150), Duration.ofDays(180), Duration.ofDays(210), Duration.ofDays(240),
Duration.ofDays(270), Duration.ofDays(300), Duration.ofDays(330), Duration.ofDays(360),
Duration.ofDays(390)
Duration.ofDays(30), Duration.ofDays(60), Duration.ofDays(90), Duration.ofDays(120),
Duration.ofDays(150), Duration.ofDays(180), Duration.ofDays(210), Duration.ofDays(240),
Duration.ofDays(270), Duration.ofDays(300), Duration.ofDays(330), Duration.ofDays(360),
Duration.ofDays(390)
};

when(saltRotation.rotateSalts(any(), eq(expectedDefaultAgeThresholds), eq(0.2), eq(utcTomorrow))).thenReturn(result);

post(vertx, testContext, "api/salt/rotate?min_ages_in_seconds=50,60,70&fraction=0.2", "", response -> {
post(vertx, testContext, "api/salt/rotate?fraction=0.2", "", response -> {
verify(saltRotation).rotateSalts(any(), eq(expectedDefaultAgeThresholds), eq(0.2), eq(utcTomorrow));
assertEquals(200, response.statusCode());
testContext.completeNow();
Expand Down
14 changes: 2 additions & 12 deletions webroot/adm/salt.html
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ <h1>UID2 Env - Salt Management</h1>
defaultValue: defaultTargetDate
};

const minAgesMultilineInput = {
name: 'minAges',
label: 'Min Ages (seconds)',
required: true,
defaultValue: '2592000,5184000,7776000,10368000,12960000,15552000,18144000,20736000,23328000,25920000,28512000,31104000,33696000',
type: 'multi-line'
};

const operationConfig = {
read: [
{
Expand Down Expand Up @@ -81,16 +73,14 @@ <h1>UID2 Env - Salt Management</h1>
role: 'superuser',
inputs: [
fractionInput,
targetDateInput,
minAgesMultilineInput
targetDateInput
],
apiCall: {
method: 'POST',
getUrl: (inputs) => {
const minAges = encodeURIComponent(inputs.minAges);
const fraction = encodeURIComponent(inputs.fraction);
const targetDate = encodeURIComponent(inputs.targetDate);
return `/api/salt/rotate?min_ages_in_seconds=${minAges}&fraction=${fraction}&target_date=${targetDate}`;
return `/api/salt/rotate?fraction=${fraction}&target_date=${targetDate}`;
}
}
}
Expand Down