Skip to content

Commit de865be

Browse files
t-burchrrayst
andauthored
Add tenant ID validation for JWT authentication (#2435)
* Add tenant ID validation for JWT authentication Enhance JWT authentication by introducing tenant ID (`tid`) validation. Update `JwtAuthInterceptor` and `RequireAuth` classes to support `tid`. Add `TidValidator` to handle validation logic. Extend tests for tenant ID scenarios. * fixed logic --------- Co-authored-by: Tobias Polley <[email protected]>
1 parent 9fb2657 commit de865be

File tree

5 files changed

+98
-3
lines changed

5 files changed

+98
-3
lines changed

core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public static String ERROR_JWT_VALUE_NOT_PRESENT(String key) {
5353
JwtRetriever jwtRetriever;
5454
Jwks jwks;
5555
String expectedAud;
56+
String expectedTid;
5657

5758
// should be used read only after init
5859
// Hashmap done on purpose as only here the read only thread safety is guaranteed
@@ -163,6 +164,10 @@ private JwtConsumer createValidator(RsaJsonWebKey key) {
163164
.setExpectedAudience(expectedAud);
164165
}
165166

167+
if (expectedTid != null && !expectedTid.isEmpty())
168+
jwtConsumerBuilder
169+
.registerValidator(new TidValidator(expectedTid));
170+
166171
return jwtConsumerBuilder.build();
167172
}
168173

@@ -192,6 +197,10 @@ public String getExpectedAud() {
192197
return expectedAud;
193198
}
194199

200+
public String getExpectedTid() {
201+
return expectedTid;
202+
}
203+
195204
/**
196205
* @description
197206
* <p>Expected audience ('aud') value of the token.</p>
@@ -203,6 +212,18 @@ public JwtAuthInterceptor setExpectedAud(String expectedAud) {
203212
return this;
204213
}
205214

215+
/**
216+
* @description
217+
* <p>Expected tenant ID ('tid') value of the token.</p>
218+
* @default not set
219+
* @example 67c869d3-0cd4-4a99-86db-088bed1a9601
220+
*/
221+
@MCAttribute
222+
public JwtAuthInterceptor setExpectedTid(String expectedTid) {
223+
this.expectedTid = expectedTid;
224+
return this;
225+
}
226+
206227
@Override
207228
public String getShortDescription() {
208229
return "Checks for a valid JWT.";
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.predic8.membrane.core.interceptor.jwt;
2+
3+
import org.jose4j.jwt.JwtClaims;
4+
import org.jose4j.jwt.MalformedClaimException;
5+
import org.jose4j.jwt.consumer.ErrorCodeValidator;
6+
import org.jose4j.jwt.consumer.JwtContext;
7+
8+
public class TidValidator implements ErrorCodeValidator {
9+
private static final Error MISSING_TID = new Error(-1, "No Tenant ID (tid) claim present.");
10+
11+
private final String acceptableTenantId;
12+
13+
public TidValidator(String acceptableTenantId)
14+
{
15+
this.acceptableTenantId = acceptableTenantId;
16+
}
17+
18+
@Override
19+
public Error validate(JwtContext jwtContext) throws MalformedClaimException
20+
{
21+
final JwtClaims jwtClaims = jwtContext.getJwtClaims();
22+
23+
if (!jwtClaims.hasClaim("tid"))
24+
return MISSING_TID;
25+
26+
String tid = jwtClaims.getClaimValue("tid", String.class);
27+
28+
if (acceptableTenantId.equals(tid)) {
29+
return null;
30+
} else {
31+
StringBuilder sb = new StringBuilder();
32+
sb.append("Tenant ID (tid) claim '").append(tid).append("' doesn't match the expected value '");
33+
sb.append(acceptableTenantId).append("' .");
34+
return new Error(-1, sb.toString());
35+
}
36+
}
37+
}

core/src/main/java/com/predic8/membrane/core/interceptor/oauth2client/OAuth2Resource2Interceptor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public class OAuth2Resource2Interceptor extends AbstractInterceptorWithSession {
5656
private static final Logger log = LoggerFactory.getLogger(OAuth2Resource2Interceptor.class.getName());
5757
public static final String ERROR_STATUS = "oauth2-error-status";
5858
public static final String EXPECTED_AUDIENCE = "oauth2-expected-audience";
59+
public static final String EXPECTED_TENANT_ID = "oauth2-expected-tenant-id";
5960
public static final String WANTED_SCOPE = "oauth2-wanted-scope";
6061

6162
private AuthorizationService auth;

core/src/main/java/com/predic8/membrane/core/interceptor/oauth2client/RequireAuth.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public class RequireAuth extends AbstractInterceptor {
3232
private static final Logger log = LoggerFactory.getLogger(RequireAuth.class.getName());
3333

3434
private String expectedAud;
35+
private String expectedTid;
3536
private OAuth2Resource2Interceptor oauth2;
3637
private JwtAuthInterceptor jwtAuth;
3738
private boolean required = true;
@@ -53,6 +54,7 @@ public void init() {
5354
jwtAuth = new JwtAuthInterceptor();
5455
jwtAuth.setJwks(jwks);
5556
jwtAuth.setExpectedAud(expectedAud);
57+
jwtAuth.setExpectedTid(expectedTid);
5658

5759
jwtAuth.init(router);
5860
}
@@ -63,6 +65,7 @@ public Outcome handleRequest(Exchange exc) {
6365
if (errorStatus != null)
6466
exc.setProperty(ERROR_STATUS, errorStatus);
6567
exc.setProperty(EXPECTED_AUDIENCE, expectedAud);
68+
exc.setProperty(EXPECTED_TENANT_ID, expectedTid);
6669
exc.setProperty(WANTED_SCOPE, scope);
6770
var outcome = oauth2.handleRequest(exc);
6871
if (outcome != Outcome.CONTINUE) {
@@ -84,6 +87,10 @@ public String getExpectedAud() {
8487
return expectedAud;
8588
}
8689

90+
public String getExpectedTid() {
91+
return expectedTid;
92+
}
93+
8794
@Required
8895
@MCAttribute
8996
public void setExpectedAud(String expectedAud) {
@@ -93,6 +100,14 @@ public void setExpectedAud(String expectedAud) {
93100
}
94101
}
95102

103+
@MCAttribute
104+
public void setExpectedTid(String expectedTid) {
105+
this.expectedTid = expectedTid;
106+
if (jwtAuth != null) {
107+
jwtAuth.setExpectedTid(expectedTid);
108+
}
109+
}
110+
96111
public OAuth2Resource2Interceptor getOauth2() {
97112
return oauth2;
98113
}

core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@ public class JwtAuthInterceptorTest{
4444
public static final String KID = "membrane";
4545
public static final String SUB_CLAIM_CONTENT = "Till, der fleissige Programmierer";
4646
private static final String AUDIENCE = "AusgestelltFuer";
47+
private static final String TENANT_ID = "Tenant12345";
4748

4849
public static Stream<Named<TestData>> data() throws Exception {
4950
return Stream.of(happyPath(),
5051
wrongAudience(),
52+
wrongTenantId(),
5153
manipulatedSignature(),
5254
unknownKey(),
5355
wrongKId(),
@@ -188,6 +190,20 @@ private static TestData wrongAudience() {
188190
);
189191
}
190192

193+
private static TestData wrongTenantId() {
194+
return new TestData(
195+
"wrongTenantId",
196+
(RsaJsonWebKey privateKey) -> new Request.Builder()
197+
.get("")
198+
.header("Authorization", "Bearer " + getSignedJwt(privateKey, getClaimsWithWrongTenantId()))
199+
.buildExchange(),
200+
(Exchange exc) -> {
201+
assertTrue(exc.getResponse().isUserError());
202+
assertNull(exc.getProperties().get("jwt"));
203+
assertEquals(JwtAuthInterceptor.ERROR_VALIDATION_FAILED, unpackBody(exc).get("detail"));
204+
}
205+
);
206+
}
191207

192208
private static TestData happyPath() {
193209
return new TestData(
@@ -262,11 +278,12 @@ private JwtAuthInterceptor createInterceptor(RsaJsonWebKey publicOnly) {
262278
jwks.getJwks().add(jwk);
263279
interceptor.setJwks(jwks);
264280
interceptor.setExpectedAud(AUDIENCE);
281+
interceptor.setExpectedTid(TENANT_ID);
265282
return interceptor;
266283
}
267284

268285
private static String getSignedJwt(RsaJsonWebKey privateKey) throws JoseException {
269-
return getSignedJwt(privateKey,createClaims(AUDIENCE));
286+
return getSignedJwt(privateKey,createClaims(AUDIENCE, TENANT_ID));
270287
}
271288

272289
private static String getSignedJwt(RsaJsonWebKey privateKey, JwtClaims claims) throws JoseException {
@@ -281,19 +298,23 @@ private static String getSignedJwt(RsaJsonWebKey privateKey, JwtClaims claims) t
281298
return jws.getCompactSerialization();
282299
}
283300

284-
private static JwtClaims createClaims(String audience){
301+
private static JwtClaims createClaims(String audience, String tenantId){
285302
JwtClaims claims = new JwtClaims();
286303
claims.setExpirationTimeMinutesInTheFuture(10);
287304
claims.setIssuedAtToNow();
288305
claims.setNotBeforeMinutesInThePast(30);
289306
claims.setSubject(SUB_CLAIM_CONTENT);
290307
claims.setAudience(audience);
308+
claims.setClaim("tid", tenantId);
291309

292310
return claims;
293311
}
294312

295313
private static JwtClaims getClaimsWithWrongAudience() {
296-
return createClaims(AUDIENCE + "1");
314+
return createClaims(AUDIENCE + "1", TENANT_ID);
297315
}
298316

317+
private static JwtClaims getClaimsWithWrongTenantId() {
318+
return createClaims(AUDIENCE, TENANT_ID + "1");
319+
}
299320
}

0 commit comments

Comments
 (0)