Skip to content

Commit e6250ef

Browse files
committed
feat: PKCE support refresh_token.
1 parent 84727b2 commit e6250ef

File tree

3 files changed

+119
-57
lines changed

3 files changed

+119
-57
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.springframework.security.core.AuthenticationException;
3030
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3131
import org.springframework.security.oauth2.core.ClaimAccessor;
32-
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
3332
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3433
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3534
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -142,17 +141,34 @@ public Authentication authenticate(Authentication authentication) throws Authent
142141
}
143142

144143
if (!authorizationCode.isActive()) {
144+
if (authorizationCode.isInvalidated()) {
145+
OAuth2Authorization.Token<? extends OAuth2Token> token = (authorization.getRefreshToken() != null)
146+
? authorization.getRefreshToken() : authorization.getAccessToken();
147+
if (token != null) {
148+
// Invalidate the access (and refresh) token as the client is
149+
// attempting to use the authorization code more than once
150+
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
151+
this.authorizationService.save(authorization);
152+
if (this.logger.isWarnEnabled()) {
153+
this.logger.warn(LogMessage.format(
154+
"Invalidated authorization token(s) previously issued to registered client '%s'",
155+
registeredClient.getId()));
156+
}
157+
}
158+
}
145159
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
146160
}
147161

148162
if (this.logger.isTraceEnabled()) {
149163
this.logger.trace("Validated token request parameters");
150164
}
151165

166+
Authentication principal = authorization.getAttribute(Principal.class.getName());
167+
152168
// @formatter:off
153169
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
154170
.registeredClient(registeredClient)
155-
.principal(authorization.getAttribute(Principal.class.getName()))
171+
.principal(principal)
156172
.authorizationServerContext(AuthorizationServerContextHolder.getContext())
157173
.authorization(authorization)
158174
.authorizedScopes(authorization.getAuthorizedScopes())
@@ -181,30 +197,31 @@ public Authentication authenticate(Authentication authentication) throws Authent
181197
if (generatedAccessToken instanceof ClaimAccessor) {
182198
authorizationBuilder.token(accessToken, (metadata) ->
183199
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((ClaimAccessor) generatedAccessToken).getClaims()));
184-
} else {
200+
}
201+
else {
185202
authorizationBuilder.accessToken(accessToken);
186203
}
187204

188205
// ----- Refresh token -----
189206
OAuth2RefreshToken refreshToken = null;
190-
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
191-
// Do not issue refresh token to public client
192-
!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
193-
207+
// Do not issue refresh token to public client
208+
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
194209
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
195210
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
196-
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
197-
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
198-
"The token generator failed to generate the refresh token.", ERROR_URI);
199-
throw new OAuth2AuthenticationException(error);
200-
}
211+
if (generatedRefreshToken != null) {
212+
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
213+
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
214+
"The token generator failed to generate a valid refresh token.", ERROR_URI);
215+
throw new OAuth2AuthenticationException(error);
216+
}
201217

202-
if (this.logger.isTraceEnabled()) {
203-
this.logger.trace("Generated refresh token");
204-
}
218+
if (this.logger.isTraceEnabled()) {
219+
this.logger.trace("Generated refresh token");
220+
}
205221

206-
refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
207-
authorizationBuilder.refreshToken(refreshToken);
222+
refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
223+
authorizationBuilder.refreshToken(refreshToken);
224+
}
208225
}
209226

210227
// ----- ID token -----
@@ -231,7 +248,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
231248
generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims());
232249
authorizationBuilder.token(idToken, (metadata) ->
233250
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
234-
} else {
251+
}
252+
else {
235253
idToken = null;
236254
}
237255

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2RefreshTokenGenerator.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2022 the original author or authors.
2+
* Copyright 2020-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,8 +21,11 @@
2121
import org.springframework.lang.Nullable;
2222
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
2323
import org.springframework.security.crypto.keygen.StringKeyGenerator;
24+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
25+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
2426
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
2527
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
28+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
2629

2730
/**
2831
* An {@link OAuth2TokenGenerator} that generates an {@link OAuth2RefreshToken}.
@@ -33,18 +36,35 @@
3336
* @see OAuth2RefreshToken
3437
*/
3538
public final class OAuth2RefreshTokenGenerator implements OAuth2TokenGenerator<OAuth2RefreshToken> {
36-
private final StringKeyGenerator refreshTokenGenerator =
37-
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
39+
40+
private final StringKeyGenerator refreshTokenGenerator = new Base64StringKeyGenerator(
41+
Base64.getUrlEncoder().withoutPadding(), 96);
3842

3943
@Nullable
4044
@Override
4145
public OAuth2RefreshToken generate(OAuth2TokenContext context) {
4246
if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
4347
return null;
4448
}
49+
if (isPublicClientForAuthorizationCodeGrant(context)) {
50+
// Do not issue refresh token to public client
51+
return null;
52+
}
53+
4554
Instant issuedAt = Instant.now();
4655
Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
4756
return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
4857
}
4958

59+
private static boolean isPublicClientForAuthorizationCodeGrant(OAuth2TokenContext context) {
60+
// @formatter:off
61+
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
62+
(context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken)) {
63+
return ((OAuth2ClientAuthenticationToken) context.getAuthorizationGrant()
64+
.getPrincipal()).getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE);
65+
}
66+
// @formatter:on
67+
return false;
68+
}
69+
5070
}

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.security.core.Authentication;
3434
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3535
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
36+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3637
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3738
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3839
import org.springframework.security.oauth2.core.OAuth2Token;
@@ -74,6 +75,8 @@
7475
import static org.assertj.core.api.Assertions.entry;
7576
import static org.mockito.ArgumentMatchers.any;
7677
import static org.mockito.ArgumentMatchers.eq;
78+
import static org.mockito.BDDMockito.given;
79+
import static org.mockito.BDDMockito.willAnswer;
7780
import static org.mockito.Mockito.doAnswer;
7881
import static org.mockito.Mockito.mock;
7982
import static org.mockito.Mockito.spy;
@@ -118,7 +121,8 @@ public OAuth2Token generate(OAuth2TokenContext context) {
118121
});
119122
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
120123
this.authorizationService, this.tokenGenerator);
121-
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build();
124+
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder()
125+
.issuer("https://provider.com").build();
122126
AuthorizationServerContextHolder.setContext(new TestAuthorizationServerContext(authorizationServerSettings, null));
123127
}
124128

@@ -302,7 +306,8 @@ public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2Authentication
302306
OAuth2TokenContext context = answer.getArgument(0);
303307
if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
304308
return null;
305-
} else {
309+
}
310+
else {
306311
return answer.callRealMethod();
307312
}
308313
}).when(this.tokenGenerator).generate(any());
@@ -317,36 +322,39 @@ public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2Authentication
317322
}
318323

319324
@Test
320-
public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
325+
public void authenticateWhenInvalidRefreshTokenGeneratedThenThrowOAuth2AuthenticationException() {
321326
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
322327
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
323-
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
324-
.thenReturn(authorization);
328+
given(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
329+
.willReturn(authorization);
325330

326-
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
327-
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
328-
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
329-
OAuth2AuthorizationRequest.class.getName());
330-
OAuth2AuthorizationCodeAuthenticationToken authentication =
331-
new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
331+
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,
332+
ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
333+
OAuth2AuthorizationRequest authorizationRequest = authorization
334+
.getAttribute(OAuth2AuthorizationRequest.class.getName());
335+
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
336+
AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
332337

333-
when(this.jwtEncoder.encode(any())).thenReturn(createJwt());
338+
given(this.jwtEncoder.encode(any())).willReturn(createJwt());
334339

335-
doAnswer(answer -> {
340+
willAnswer((answer) -> {
336341
OAuth2TokenContext context = answer.getArgument(0);
337342
if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
338-
return null;
339-
} else {
343+
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(),
344+
Instant.now().plusSeconds(300));
345+
}
346+
else {
340347
return answer.callRealMethod();
341348
}
342-
}).when(this.tokenGenerator).generate(any());
349+
}).given(this.tokenGenerator).generate(any());
343350

344351
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
345352
.isInstanceOf(OAuth2AuthenticationException.class)
346-
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
347-
.satisfies(error -> {
353+
.extracting((ex) -> ((OAuth2AuthenticationException) ex).getError())
354+
.satisfies((error) -> {
348355
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
349-
assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
356+
assertThat(error.getDescription())
357+
.contains("The token generator failed to generate a valid refresh token.");
350358
});
351359
}
352360

@@ -370,7 +378,8 @@ public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationExce
370378
OAuth2TokenContext context = answer.getArgument(0);
371379
if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
372380
return null;
373-
} else {
381+
}
382+
else {
374383
return answer.callRealMethod();
375384
}
376385
}).when(this.tokenGenerator).generate(any());
@@ -428,12 +437,16 @@ public void authenticateWhenValidCodeThenReturnAccessToken() {
428437
verify(this.authorizationService).save(authorizationCaptor.capture());
429438
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
430439

431-
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
440+
assertThat(accessTokenAuthentication.getRegisteredClient()
441+
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
432442
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
433-
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
434-
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(authorization.getAuthorizedScopes());
443+
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
444+
.getToken());
445+
assertThat(accessTokenAuthentication.getAccessToken()
446+
.getScopes()).isEqualTo(authorization.getAuthorizedScopes());
435447
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
436-
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
448+
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
449+
.getToken());
437450
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
438451
assertThat(authorizationCode.isInvalidated()).isTrue();
439452
}
@@ -443,7 +456,8 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken()
443456
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
444457
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
445458
"code", Instant.now(), Instant.now().plusSeconds(120));
446-
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode).build();
459+
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode)
460+
.build();
447461
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
448462
.thenReturn(authorization);
449463

@@ -490,19 +504,22 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken()
490504
assertThat(idTokenContext.getJwsHeader()).isNotNull();
491505
assertThat(idTokenContext.getClaims()).isNotNull();
492506

493-
verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token
507+
verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token
494508

495509
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
496510
verify(this.authorizationService).save(authorizationCaptor.capture());
497511
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
498512

499-
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
513+
assertThat(accessTokenAuthentication.getRegisteredClient()
514+
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
500515
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
501-
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
516+
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
517+
.getToken());
502518
Set<String> accessTokenScopes = new HashSet<>(updatedAuthorization.getAuthorizedScopes());
503519
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(accessTokenScopes);
504520
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
505-
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
521+
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
522+
.getToken());
506523
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCodeToken = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
507524
assertThat(authorizationCodeToken.isInvalidated()).isTrue();
508525
OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
@@ -558,10 +575,13 @@ public void authenticateWhenPublicClientThenRefreshTokenNotIssued() {
558575
verify(this.authorizationService).save(authorizationCaptor.capture());
559576
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
560577

561-
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
578+
assertThat(accessTokenAuthentication.getRegisteredClient()
579+
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
562580
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
563-
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
564-
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(authorization.getAuthorizedScopes());
581+
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
582+
.getToken());
583+
assertThat(accessTokenAuthentication.getAccessToken()
584+
.getScopes()).isEqualTo(authorization.getAuthorizedScopes());
565585
assertThat(accessTokenAuthentication.getRefreshToken()).isNull();
566586
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
567587
assertThat(authorizationCode.isInvalidated()).isTrue();
@@ -600,13 +620,17 @@ public void authenticateWhenTokenTimeToLiveConfiguredThenTokenExpirySet() {
600620
verify(this.authorizationService).save(authorizationCaptor.capture());
601621
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
602622

603-
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
604-
Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt().plus(accessTokenTTL);
623+
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
624+
.getToken());
625+
Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt()
626+
.plus(accessTokenTTL);
605627
assertThat(accessTokenAuthentication.getAccessToken().getExpiresAt()).isBetween(
606628
expectedAccessTokenExpiresAt.minusSeconds(1), expectedAccessTokenExpiresAt.plusSeconds(1));
607629

608-
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
609-
Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL);
630+
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
631+
.getToken());
632+
Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt()
633+
.plus(refreshTokenTTL);
610634
assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween(
611635
expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1));
612636
}

0 commit comments

Comments
 (0)