Skip to content

Commit c0e1704

Browse files
michaelqi793michaelqi793
andauthored
Kv jca cache token (Azure#23847)
* Cache the access token * Added unit test for cache token * Added one more test for expired cache token Co-authored-by: michaelqi793 <[email protected]>
1 parent c7cc461 commit c0e1704

File tree

6 files changed

+152
-56
lines changed

6 files changed

+152
-56
lines changed

sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/KeyVaultClient.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
import com.azure.security.keyvault.jca.implementation.model.CertificateListResult;
1212
import com.azure.security.keyvault.jca.implementation.model.CertificatePolicy;
1313
import com.azure.security.keyvault.jca.implementation.model.KeyProperties;
14+
import com.azure.security.keyvault.jca.implementation.model.AccessToken;
1415
import com.azure.security.keyvault.jca.implementation.model.SecretBundle;
16+
import com.azure.security.keyvault.jca.implementation.model.SignResult;
1517
import com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil;
1618
import com.azure.security.keyvault.jca.implementation.utils.HttpUtil;
1719
import com.azure.security.keyvault.jca.implementation.utils.JsonConverterUtil;
18-
import com.azure.security.keyvault.jca.implementation.model.SignResult;
1920

2021
import java.io.BufferedReader;
2122
import java.io.ByteArrayInputStream;
@@ -116,6 +117,11 @@ public static String getAADLoginURIByKeyVaultBaseUri(String keyVaultBaseUri) {
116117
*/
117118
private String managedIdentity;
118119

120+
/**
121+
* Stores the token.
122+
*/
123+
private AccessToken accessToken;
124+
119125
/**
120126
* Constructor for authentication with user-assigned managed identity.
121127
*
@@ -183,8 +189,21 @@ public static KeyVaultClient createKeyVaultClientBySystemProperty() {
183189
* @return the access token.
184190
*/
185191
private String getAccessToken() {
192+
if (accessToken != null && !accessToken.isExpired()) {
193+
return accessToken.getAccessToken();
194+
}
195+
accessToken = getAccessTokenByHttpRequest();
196+
return accessToken.getAccessToken();
197+
}
198+
199+
/**
200+
* Get the access token.
201+
*
202+
* @return the access token.
203+
*/
204+
private AccessToken getAccessTokenByHttpRequest() {
186205
LOGGER.entering("KeyVaultClient", "getAccessToken");
187-
String accessToken = null;
206+
AccessToken accessToken = null;
188207
try {
189208
String resource = URLEncoder.encode(keyVaultBaseUri, "UTF-8");
190209
if (managedIdentity != null) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
package com.azure.security.keyvault.jca.implementation.model;
4+
5+
import com.fasterxml.jackson.annotation.JsonProperty;
6+
7+
import java.time.OffsetDateTime;
8+
9+
/**
10+
* An OAuth2 token.
11+
*/
12+
public class AccessToken {
13+
14+
/**
15+
* Stores the access token.
16+
*/
17+
@JsonProperty("access_token")
18+
private String accessToken;
19+
20+
/**
21+
* Stores the life duration of the access token.
22+
*/
23+
@JsonProperty("expires_in")
24+
private long expiresIn;
25+
26+
/**
27+
*
28+
* @return the life duration of the access token in seconds
29+
*/
30+
public long getExpiresIn() {
31+
return expiresIn;
32+
}
33+
34+
/**
35+
* Set the life duration of the access token in seconds
36+
*
37+
* @param expiresIn
38+
*/
39+
public void setExpiresIn(long expiresIn) {
40+
this.expiresIn = expiresIn;
41+
}
42+
43+
/**
44+
* Stores the time when the token is retrieved for the first time.
45+
*/
46+
private final OffsetDateTime creationDate = OffsetDateTime.now();
47+
48+
/**
49+
* Get the access token.
50+
*
51+
* @return the access token.
52+
*/
53+
public String getAccessToken() {
54+
return accessToken;
55+
}
56+
57+
/**
58+
* Set the access token.
59+
*
60+
* @param accessToken the access token.
61+
*/
62+
public void setAccessToken(String accessToken) {
63+
this.accessToken = accessToken;
64+
}
65+
66+
/**
67+
* Reserve 60 seconds, in case that the time the token is used it is valid but when the token gets to the server side, it expires.
68+
* @return boolean, whether the token is expired.
69+
*
70+
*/
71+
public boolean isExpired() {
72+
return OffsetDateTime.now().isAfter(creationDate.plusSeconds(expiresIn - 60));
73+
}
74+
}

sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/model/OAuthToken.java

Lines changed: 0 additions & 35 deletions
This file was deleted.

sdk/keyvault/azure-security-keyvault-jca/src/main/java/com/azure/security/keyvault/jca/implementation/utils/AccessTokenUtil.java

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import static java.util.logging.Level.FINER;
66
import static java.util.logging.Level.INFO;
77

8-
import com.azure.security.keyvault.jca.implementation.model.OAuthToken;
8+
import com.azure.security.keyvault.jca.implementation.model.AccessToken;
99
import java.util.HashMap;
1010
import java.util.logging.Logger;
1111

@@ -62,8 +62,8 @@ public final class AccessTokenUtil {
6262
* @param identity the user-assigned identity (null if system-assigned)
6363
* @return the authorization token.
6464
*/
65-
public static String getAccessToken(String resource, String identity) {
66-
String result;
65+
public static AccessToken getAccessToken(String resource, String identity) {
66+
AccessToken result;
6767

6868
if (System.getenv("WEBSITE_SITE_NAME") != null
6969
&& !System.getenv("WEBSITE_SITE_NAME").isEmpty()) {
@@ -84,13 +84,13 @@ public static String getAccessToken(String resource, String identity) {
8484
* @param clientSecret the client secret.
8585
* @return the authorization token.
8686
*/
87-
public static String getAccessToken(String resource, String aadAuthenticationUrl,
88-
String tenantId, String clientId, String clientSecret) {
87+
public static AccessToken getAccessToken(String resource, String aadAuthenticationUrl,
88+
String tenantId, String clientId, String clientSecret) {
8989

9090
LOGGER.entering("AccessTokenUtil", "getAccessToken", new Object[]{
9191
resource, tenantId, clientId, clientSecret});
9292
LOGGER.info("Getting access token using client ID / client secret");
93-
String result = null;
93+
AccessToken result = null;
9494

9595
StringBuilder oauth2Url = new StringBuilder();
9696
oauth2Url.append(aadAuthenticationUrl == null ? OAUTH2_TOKEN_BASE_URL : aadAuthenticationUrl)
@@ -106,8 +106,7 @@ public static String getAccessToken(String resource, String aadAuthenticationUrl
106106
String body = HttpUtil
107107
.post(oauth2Url.toString(), requestBody.toString(), "application/x-www-form-urlencoded");
108108
if (body != null) {
109-
OAuthToken token = (OAuthToken) JsonConverterUtil.fromJson(body, OAuthToken.class);
110-
result = token.getAccessToken();
109+
result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class);
111110
}
112111
LOGGER.log(FINER, "Access token: {0}", result);
113112
return result;
@@ -120,11 +119,10 @@ public static String getAccessToken(String resource, String aadAuthenticationUrl
120119
* @param clientId the user-assigned managed identity (null if system-assigned).
121120
* @return the authorization token.
122121
*/
123-
private static String getAccessTokenOnAppService(String resource, String clientId) {
122+
private static AccessToken getAccessTokenOnAppService(String resource, String clientId) {
124123
LOGGER.entering("AccessTokenUtil", "getAccessTokenOnAppService", resource);
125124
LOGGER.info("Getting access token using managed identity based on MSI_SECRET");
126-
String result = null;
127-
125+
AccessToken result = null;
128126
StringBuilder url = new StringBuilder();
129127
url.append(System.getenv("MSI_ENDPOINT"))
130128
.append("?api-version=2017-09-01")
@@ -140,8 +138,7 @@ private static String getAccessTokenOnAppService(String resource, String clientI
140138
String body = HttpUtil.get(url.toString(), headers);
141139

142140
if (body != null) {
143-
OAuthToken token = (OAuthToken) JsonConverterUtil.fromJson(body, OAuthToken.class);
144-
result = token.getAccessToken();
141+
result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class);
145142
}
146143
LOGGER.exiting("AccessTokenUtil", "getAccessTokenOnAppService", result);
147144
return result;
@@ -154,13 +151,13 @@ private static String getAccessTokenOnAppService(String resource, String clientI
154151
* @param identity the user-assigned identity (null if system-assigned).
155152
* @return the authorization token.
156153
*/
157-
private static String getAccessTokenOnOthers(String resource, String identity) {
154+
private static AccessToken getAccessTokenOnOthers(String resource, String identity) {
158155
LOGGER.entering("AccessTokenUtil", "getAccessTokenOnOthers", resource);
159156
LOGGER.info("Getting access token using managed identity");
160157
if (identity != null) {
161158
LOGGER.log(INFO, "Using managed identity with object ID: {0}", identity);
162159
}
163-
String result = null;
160+
AccessToken result = null;
164161

165162
StringBuilder url = new StringBuilder();
166163
url.append(OAUTH2_MANAGED_IDENTITY_TOKEN_URL)
@@ -174,8 +171,7 @@ private static String getAccessTokenOnOthers(String resource, String identity) {
174171
String body = HttpUtil.get(url.toString(), headers);
175172

176173
if (body != null) {
177-
OAuthToken token = (OAuthToken) JsonConverterUtil.fromJson(body, OAuthToken.class);
178-
result = token.getAccessToken();
174+
result = (AccessToken) JsonConverterUtil.fromJson(body, AccessToken.class);
179175
}
180176
LOGGER.exiting("AccessTokenUtil", "getAccessTokenOnOthers", result);
181177
return result;

sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/AccessTokenUtilTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
package com.azure.security.keyvault.jca;
55

6+
import com.azure.security.keyvault.jca.implementation.model.AccessToken;
67
import com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil;
78
import org.junit.jupiter.api.Test;
89
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -27,7 +28,7 @@ public void testGetAuthorizationToken() throws Exception {
2728
String tenantId = System.getenv("AZURE_KEYVAULT_TENANT_ID");
2829
String clientId = System.getenv("AZURE_KEYVAULT_CLIENT_ID");
2930
String clientSecret = System.getenv("AZURE_KEYVAULT_CLIENT_SECRET");
30-
String result = AccessTokenUtil.getAccessToken(
31+
AccessToken result = AccessTokenUtil.getAccessToken(
3132
"https://management.azure.com/",
3233
null,
3334
tenantId,

sdk/keyvault/azure-security-keyvault-jca/src/test/java/com/azure/security/keyvault/jca/implementation/KeyVaultClientTest.java

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
package com.azure.security.keyvault.jca.implementation;
55

6+
import com.azure.security.keyvault.jca.implementation.model.AccessToken;
67
import com.azure.security.keyvault.jca.implementation.model.CertificateItem;
78
import com.azure.security.keyvault.jca.implementation.model.CertificateListResult;
9+
import com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil;
810
import com.azure.security.keyvault.jca.implementation.utils.HttpUtil;
911
import com.azure.security.keyvault.jca.implementation.utils.JsonConverterUtil;
1012
import org.junit.jupiter.api.Assertions;
@@ -27,7 +29,7 @@
2729
import static com.azure.security.keyvault.jca.implementation.KeyVaultClient.KEY_VAULT_BASE_URI_US;
2830
import static org.junit.jupiter.api.Assertions.*;
2931
import static org.mockito.ArgumentMatchers.*;
30-
import static org.mockito.Mockito.mock;
32+
import static org.mockito.Mockito.*;
3133

3234
public class KeyVaultClientTest {
3335

@@ -155,4 +157,43 @@ private KeyVaultClient getKeyVaultClient() {
155157
String clientSecret = System.getProperty("azure.keyvault.client-secret");
156158
return new KeyVaultClient(keyVaultUri, tenantId, clientId, clientSecret);
157159
}
160+
161+
162+
@Test
163+
public void testCacheToken() {
164+
try (MockedStatic<AccessTokenUtil> tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); MockedStatic<HttpUtil> httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) {
165+
AccessToken cacheToken = new AccessToken();
166+
cacheToken.setExpiresIn(300); // 300 seconds.
167+
tokenUtilMockedStatic.when(() -> AccessTokenUtil.getAccessToken(anyString(), anyString())).thenReturn(cacheToken);
168+
CertificateItem fakeCertificateItem = new CertificateItem();
169+
fakeCertificateItem.setId("certificates/fakeCertificateItem");
170+
CertificateListResult certificateListResult = new CertificateListResult();
171+
certificateListResult.setValue(Arrays.asList(fakeCertificateItem));
172+
String certificateListResultString = JsonConverterUtil.toJson(certificateListResult);
173+
httpUtilMockedStatic.when(() -> HttpUtil.get(anyString(), anyMap())).thenReturn(certificateListResultString);
174+
KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, "");
175+
keyVaultClient.getAliases();
176+
keyVaultClient.getAliases(); // get aliases the second time.
177+
tokenUtilMockedStatic.verify(() -> AccessTokenUtil.getAccessToken(anyString(), anyString()), times(1));
178+
}
179+
}
180+
181+
@Test
182+
public void testCacheTokenExpired() {
183+
try (MockedStatic<AccessTokenUtil> tokenUtilMockedStatic = Mockito.mockStatic(AccessTokenUtil.class); MockedStatic<HttpUtil> httpUtilMockedStatic = Mockito.mockStatic(HttpUtil.class)) {
184+
AccessToken cacheToken = new AccessToken();
185+
cacheToken.setExpiresIn(50); // 50 seconds.
186+
tokenUtilMockedStatic.when(() -> AccessTokenUtil.getAccessToken(anyString(), anyString())).thenReturn(cacheToken);
187+
CertificateItem fakeCertificateItem = new CertificateItem();
188+
fakeCertificateItem.setId("certificates/fakeCertificateItem");
189+
CertificateListResult certificateListResult = new CertificateListResult();
190+
certificateListResult.setValue(Arrays.asList(fakeCertificateItem));
191+
String certificateListResultString = JsonConverterUtil.toJson(certificateListResult);
192+
httpUtilMockedStatic.when(() -> HttpUtil.get(anyString(), anyMap())).thenReturn(certificateListResultString);
193+
KeyVaultClient keyVaultClient = new KeyVaultClient(KEY_VAULT_TEST_URI_GLOBAL, "");
194+
keyVaultClient.getAliases();
195+
keyVaultClient.getAliases(); // get aliases the second time.
196+
tokenUtilMockedStatic.verify(() -> AccessTokenUtil.getAccessToken(anyString(), anyString()), times(2));
197+
}
198+
}
158199
}

0 commit comments

Comments
 (0)