1818import com .fasterxml .jackson .databind .JsonNode ;
1919import com .fasterxml .jackson .databind .ObjectMapper ;
2020import com .marklogic .client .DatabaseClientFactory .MarkLogicCloudAuthContext ;
21- import okhttp3 .*;
21+ import okhttp3 .Call ;
22+ import okhttp3 .FormBody ;
23+ import okhttp3 .HttpUrl ;
24+ import okhttp3 .Interceptor ;
25+ import okhttp3 .OkHttpClient ;
26+ import okhttp3 .Request ;
27+ import okhttp3 .Response ;
2228import org .slf4j .Logger ;
2329import org .slf4j .LoggerFactory ;
2430
2531import java .io .IOException ;
2632
2733public class MarkLogicCloudAuthenticationConfigurer implements AuthenticationConfigurer <MarkLogicCloudAuthContext > {
2834
29- private final static Logger logger = LoggerFactory .getLogger (MarkLogicCloudAuthenticationConfigurer .class );
30-
3135 private String host ;
3236
3337 public MarkLogicCloudAuthenticationConfigurer (String host ) {
@@ -40,78 +44,155 @@ public void configureAuthentication(OkHttpClient.Builder clientBuilder, MarkLogi
4044 if (apiKey == null || apiKey .trim ().length () < 1 ) {
4145 throw new IllegalArgumentException ("No API key provided" );
4246 }
47+ TokenGenerator tokenGenerator = new DefaultTokenGenerator (this .host , securityContext );
48+ clientBuilder .addInterceptor (new TokenAuthenticationInterceptor (tokenGenerator ));
49+ }
4350
44- final Response response = callTokenEndpoint (securityContext );
45- final String accessToken = getAccessTokenFromResponse (response );
46- if (logger .isInfoEnabled ()) {
47- logger .info ("Successfully obtained authentication token" );
48- }
49- clientBuilder
50- .addInterceptor (chain -> {
51- Request authenticatedRequest = chain .request ().newBuilder ()
52- .header ("Authorization" , "Bearer " + accessToken )
53- .build ();
54- return chain .proceed (authenticatedRequest );
55- });
51+ /**
52+ * Exists solely for mocking in unit tests.
53+ */
54+ public interface TokenGenerator {
55+ String generateToken ();
5656 }
5757
58- private Response callTokenEndpoint (MarkLogicCloudAuthContext securityContext ) {
59- final HttpUrl tokenUrl = buildTokenUrl (securityContext );
60- OkHttpClient .Builder clientBuilder = OkHttpUtil .newClientBuilder ();
61- // Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable
62- // for connecting to MarkLogic Cloud's "/token" endpoint.
63- OkHttpUtil .configureSocketFactory (clientBuilder , securityContext .getSSLContext (), securityContext .getTrustManager ());
64- OkHttpUtil .configureHostnameVerifier (clientBuilder , securityContext .getSSLHostnameVerifier ());
58+ /**
59+ * Knows how to call the "/token" endpoint in MarkLogic Cloud to generate a new token based on the
60+ * user-provided API key.
61+ */
62+ static class DefaultTokenGenerator implements TokenGenerator {
63+
64+ private final static Logger logger = LoggerFactory .getLogger (DefaultTokenGenerator .class );
65+ private String host ;
66+ private MarkLogicCloudAuthContext securityContext ;
67+
68+ public DefaultTokenGenerator (String host , MarkLogicCloudAuthContext securityContext ) {
69+ this .host = host ;
70+ this .securityContext = securityContext ;
71+ }
6572
66- if (logger .isInfoEnabled ()) {
67- logger .info ("Calling token endpoint at: " + tokenUrl );
73+ public String generateToken () {
74+ final Response tokenResponse = callTokenEndpoint ();
75+ String token = getAccessTokenFromResponse (tokenResponse );
76+ if (logger .isInfoEnabled ()) {
77+ logger .info ("Successfully obtained authentication token" );
78+ }
79+ return token ;
6880 }
6981
70- final Call call = clientBuilder
71- .build ()
72- .newCall (new Request .Builder ()
73- .url (tokenUrl )
74- .post (newFormBody (securityContext ))
75- .build ()
82+ private Response callTokenEndpoint () {
83+ final HttpUrl tokenUrl = buildTokenUrl ();
84+ OkHttpClient .Builder clientBuilder = OkHttpUtil .newClientBuilder ();
85+ // Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable
86+ // for connecting to MarkLogic Cloud's "/token" endpoint.
87+ OkHttpUtil .configureSocketFactory (clientBuilder , securityContext .getSSLContext (), securityContext .getTrustManager ());
88+ OkHttpUtil .configureHostnameVerifier (clientBuilder , securityContext .getSSLHostnameVerifier ());
89+
90+ if (logger .isInfoEnabled ()) {
91+ logger .info ("Calling token endpoint at: " + tokenUrl );
92+ }
93+
94+ final Call call = clientBuilder .build ().newCall (
95+ new Request .Builder ()
96+ .url (tokenUrl )
97+ .post (newFormBody ())
98+ .build ()
7699 );
77100
78- try {
79- return call .execute ();
80- } catch (IOException e ) {
81- throw new RuntimeException (String .format ("Unable to call token endpoint at %s; cause: %s" ,
82- tokenUrl , e .getMessage (), e ));
101+ try {
102+ return call .execute ();
103+ } catch (IOException e ) {
104+ throw new RuntimeException (String .format ("Unable to call token endpoint at %s; cause: %s" ,
105+ tokenUrl , e .getMessage (), e ));
106+ }
83107 }
84- }
85108
86- protected HttpUrl buildTokenUrl (MarkLogicCloudAuthContext securityContext ) {
87- // For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud,
88- // so providing the ability to customize this would be misleading.
89- return new HttpUrl .Builder ()
90- .scheme ("https" )
91- .host (host )
92- .port (443 )
93- .build ()
94- .resolve (securityContext .getTokenEndpoint ()).newBuilder ().build ();
95- }
109+ protected HttpUrl buildTokenUrl () {
110+ // For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud,
111+ // so providing the ability to customize this would be misleading.
112+ return new HttpUrl .Builder ()
113+ .scheme ("https" )
114+ .host (host )
115+ .port (443 )
116+ .build ()
117+ .resolve (securityContext .getTokenEndpoint ()).newBuilder ().build ();
118+ }
96119
97- protected FormBody newFormBody (MarkLogicCloudAuthContext securityContext ) {
98- return new FormBody .Builder ()
99- .add ("grant_type" , securityContext .getGrantType ())
100- .add ("key" , securityContext .getApiKey ()).build ();
120+ protected FormBody newFormBody () {
121+ return new FormBody .Builder ()
122+ .add ("grant_type" , securityContext .getGrantType ())
123+ .add ("key" , securityContext .getApiKey ()).build ();
124+ }
125+
126+ private String getAccessTokenFromResponse (Response response ) {
127+ String responseBody = null ;
128+ JsonNode payload ;
129+ try {
130+ responseBody = response .body ().string ();
131+ payload = new ObjectMapper ().readTree (responseBody );
132+ } catch (IOException ex ) {
133+ throw new RuntimeException ("Unable to get access token; response: " + responseBody , ex );
134+ }
135+ if (!payload .has ("access_token" )) {
136+ throw new RuntimeException ("Unable to get access token; unexpected JSON response: " + payload );
137+ }
138+ return payload .get ("access_token" ).asText ();
139+ }
101140 }
102141
103- private String getAccessTokenFromResponse (Response response ) {
104- String responseBody = null ;
105- JsonNode payload ;
106- try {
107- responseBody = response .body ().string ();
108- payload = new ObjectMapper ().readTree (responseBody );
109- } catch (IOException ex ) {
110- throw new RuntimeException ("Unable to get access token; response: " + responseBody , ex );
142+ /**
143+ * OkHttp interceptor that handles adding a token to an HTTP request and renewing it when necessary.
144+ */
145+ static class TokenAuthenticationInterceptor implements Interceptor {
146+
147+ private final static Logger logger = LoggerFactory .getLogger (TokenAuthenticationInterceptor .class );
148+
149+ private TokenGenerator tokenGenerator ;
150+ private String token ;
151+
152+ public TokenAuthenticationInterceptor (TokenGenerator tokenGenerator ) {
153+ this .tokenGenerator = tokenGenerator ;
154+ this .token = tokenGenerator .generateToken ();
111155 }
112- if (!payload .has ("access_token" )) {
113- throw new RuntimeException ("Unable to get access token; unexpected JSON response: " + payload );
156+
157+ @ Override
158+ public Response intercept (Chain chain ) throws IOException {
159+ Response response = chain .proceed (addTokenToRequest (chain ));
160+ if (response .code () == 403 ) {
161+ logger .info ("Received 403; will generate new token if necessary and retry request" );
162+ response .close ();
163+ final String currentToken = this .token ;
164+ generateNewTokenIfNecessary (currentToken );
165+ response = chain .proceed (addTokenToRequest (chain ));
166+ }
167+ return response ;
168+ }
169+
170+ /**
171+ * In the case of N threads using the same DatabaseClient - e.g. when using DMSDK - all of them
172+ * may make a request at the same time and get a 403 back. Functionally, it should be fine if all
173+ * make their own requests to renew the token, with the last thread being the one whose token
174+ * value is retained on this class. But to simplify matters, this block is synchronized so only one
175+ * thread can be in here. And if that thread finds that this.token is different from currentToken,
176+ * then some other thread already renewed the token - so this thread doesn't need to do anything and
177+ * can just try again.
178+ *
179+ * @param currentToken the value of this instance's token right before calling this method; in the event that
180+ * another thread using this instance got here first, then this value will differ from the
181+ * instance's token field
182+ */
183+ private synchronized void generateNewTokenIfNecessary (String currentToken ) {
184+ if (currentToken .equals (this .token )) {
185+ logger .info ("Generating new token based on receiving 403" );
186+ this .token = tokenGenerator .generateToken ();
187+ } else if (logger .isDebugEnabled ()) {
188+ logger .debug ("This instance's token has already been updated, presumably by another thread" );
189+ }
190+ }
191+
192+ private Request addTokenToRequest (Chain chain ) {
193+ return chain .request ().newBuilder ()
194+ .header ("Authorization" , "Bearer " + token )
195+ .build ();
114196 }
115- return payload .get ("access_token" ).asText ();
116197 }
117198}
0 commit comments