|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +package com.microsoft.azure.msalwebsample; |
| 5 | + |
| 6 | +import java.io.IOException; |
| 7 | +import java.io.UnsupportedEncodingException; |
| 8 | +import java.net.MalformedURLException; |
| 9 | +import java.net.URI; |
| 10 | +import java.net.URLEncoder; |
| 11 | +import java.text.ParseException; |
| 12 | +import java.util.*; |
| 13 | +import java.util.concurrent.*; |
| 14 | + |
| 15 | +import javax.naming.ServiceUnavailableException; |
| 16 | +import javax.servlet.Filter; |
| 17 | +import javax.servlet.FilterChain; |
| 18 | +import javax.servlet.FilterConfig; |
| 19 | +import javax.servlet.ServletException; |
| 20 | +import javax.servlet.ServletRequest; |
| 21 | +import javax.servlet.ServletResponse; |
| 22 | +import javax.servlet.http.HttpServletRequest; |
| 23 | +import javax.servlet.http.HttpServletResponse; |
| 24 | +import javax.servlet.http.HttpSession; |
| 25 | + |
| 26 | +import com.microsoft.aad.msal4j.*; |
| 27 | +import com.nimbusds.jwt.JWTParser; |
| 28 | +import com.nimbusds.oauth2.sdk.AuthorizationCode; |
| 29 | +import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse; |
| 30 | +import com.nimbusds.openid.connect.sdk.AuthenticationResponse; |
| 31 | +import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser; |
| 32 | +import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse; |
| 33 | +import org.apache.commons.lang3.StringUtils; |
| 34 | +import org.springframework.beans.factory.annotation.Autowired; |
| 35 | +import org.springframework.stereotype.Component; |
| 36 | + |
| 37 | +@Component |
| 38 | +public class AuthFilter implements Filter { |
| 39 | + |
| 40 | + private static final String STATES = "states"; |
| 41 | + private static final String STATE = "state"; |
| 42 | + private static final Integer STATE_TTL = 3600; |
| 43 | + private static final String FAILED_TO_VALIDATE_MESSAGE = "Failed to validate data received from Authorization service - "; |
| 44 | + |
| 45 | + private List<String> excludedUrls = Arrays.asList("/", "/msal4jsample/"); |
| 46 | + |
| 47 | + @Autowired |
| 48 | + AuthHelper authHelper; |
| 49 | + |
| 50 | + @Override |
| 51 | + public void doFilter(ServletRequest request, ServletResponse response, |
| 52 | + FilterChain chain) throws IOException, ServletException { |
| 53 | + if (request instanceof HttpServletRequest) { |
| 54 | + HttpServletRequest httpRequest = (HttpServletRequest) request; |
| 55 | + HttpServletResponse httpResponse = (HttpServletResponse) response; |
| 56 | + try { |
| 57 | + String currentUri = httpRequest.getRequestURL().toString(); |
| 58 | + String path = httpRequest.getServletPath(); |
| 59 | + String queryStr = httpRequest.getQueryString(); |
| 60 | + String fullUrl = currentUri + (queryStr != null ? "?" + queryStr : ""); |
| 61 | + |
| 62 | + // exclude home page |
| 63 | + if(excludedUrls.contains(path)){ |
| 64 | + chain.doFilter(request, response); |
| 65 | + return; |
| 66 | + } |
| 67 | + // check if user has a AuthData in the session |
| 68 | + if (!AuthHelper.isAuthenticated(httpRequest)) { |
| 69 | + if(AuthHelper.containsAuthenticationCode(httpRequest)){ |
| 70 | + // response should have authentication code, which will be used to acquire access token |
| 71 | + processAuthenticationCodeRedirect(httpRequest, currentUri, fullUrl); |
| 72 | + } else { |
| 73 | + // not authenticated, redirecting to login.microsoft.com so user can authenticate |
| 74 | + sendAuthRedirect(authHelper.configuration.signUpSignInAuthority, httpRequest, httpResponse); |
| 75 | + return; |
| 76 | + } |
| 77 | + } |
| 78 | + if (isAccessTokenExpired(httpRequest)) { |
| 79 | + authHelper.updateAuthDataUsingSilentFlow(httpRequest); |
| 80 | + } |
| 81 | + } catch (MsalException authException) { |
| 82 | + // something went wrong (like expiration or revocation of token) |
| 83 | + // we should invalidate AuthData stored in session and redirect to Authorization server |
| 84 | + authHelper.removePrincipalFromSession(httpRequest); |
| 85 | + sendAuthRedirect(authHelper.configuration.signUpSignInAuthority, httpRequest, httpResponse); |
| 86 | + return; |
| 87 | + } catch (Throwable exc) { |
| 88 | + httpResponse.setStatus(500); |
| 89 | + request.setAttribute("error", exc.getMessage()); |
| 90 | + request.getRequestDispatcher("/error").forward(request, response); |
| 91 | + return; |
| 92 | + } |
| 93 | + } |
| 94 | + chain.doFilter(request, response); |
| 95 | + } |
| 96 | + |
| 97 | + private boolean isAccessTokenExpired(HttpServletRequest httpRequest) { |
| 98 | + IAuthenticationResult result = AuthHelper.getAuthSessionObject(httpRequest); |
| 99 | + return result.expiresOnDate().before(new Date()); |
| 100 | + } |
| 101 | + |
| 102 | + private void processAuthenticationCodeRedirect(HttpServletRequest httpRequest, String currentUri, String fullUrl) |
| 103 | + throws Throwable { |
| 104 | + |
| 105 | + Map<String, List<String>> params = new HashMap<>(); |
| 106 | + for (String key : httpRequest.getParameterMap().keySet()) { |
| 107 | + params.put(key, Collections.singletonList(httpRequest.getParameterMap().get(key)[0])); |
| 108 | + } |
| 109 | + // validate that state in response equals to state in request |
| 110 | + StateData stateData = validateState(httpRequest.getSession(), params.get(STATE).get(0)); |
| 111 | + |
| 112 | + AuthenticationResponse authResponse = AuthenticationResponseParser.parse(new URI(fullUrl), params); |
| 113 | + if (AuthHelper.isAuthenticationSuccessful(authResponse)) { |
| 114 | + AuthenticationSuccessResponse oidcResponse = (AuthenticationSuccessResponse) authResponse; |
| 115 | + // validate that OIDC Auth Response matches Code Flow (contains only requested artifacts) |
| 116 | + validateAuthRespMatchesAuthCodeFlow(oidcResponse); |
| 117 | + |
| 118 | + IAuthenticationResult result = authHelper.getAuthResultByAuthCode( |
| 119 | + httpRequest, |
| 120 | + oidcResponse.getAuthorizationCode(), |
| 121 | + currentUri, |
| 122 | + Collections.singleton(authHelper.configuration.apiScope)); |
| 123 | + |
| 124 | + // validate nonce to prevent reply attacks (code maybe substituted to one with broader access) |
| 125 | + validateNonce(stateData, getNonceClaimValueFromIdToken(result.idToken())); |
| 126 | + authHelper.setSessionPrincipal(httpRequest, result); |
| 127 | + } else { |
| 128 | + AuthenticationErrorResponse oidcResponse = (AuthenticationErrorResponse) authResponse; |
| 129 | + throw new Exception(String.format("Request for auth code failed: %s - %s", |
| 130 | + oidcResponse.getErrorObject().getCode(), |
| 131 | + oidcResponse.getErrorObject().getDescription())); |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + void sendAuthRedirect(String authoriy, HttpServletRequest httpRequest, HttpServletResponse httpResponse) throws IOException { |
| 136 | + // state parameter to validate response from Authorization server and nonce parameter to validate idToken |
| 137 | + String state = UUID.randomUUID().toString(); |
| 138 | + String nonce = UUID.randomUUID().toString(); |
| 139 | + storeStateInSession(httpRequest.getSession(), state, nonce); |
| 140 | + |
| 141 | + httpResponse.setStatus(302); |
| 142 | + String redirectUrl = getRedirectUrl(authoriy, httpRequest.getParameter("claims"), state, nonce); |
| 143 | + httpResponse.sendRedirect(redirectUrl); |
| 144 | + } |
| 145 | + |
| 146 | + private void validateNonce(StateData stateData, String nonce) throws Exception { |
| 147 | + if (StringUtils.isEmpty(nonce) || !nonce.equals(stateData.getNonce())) { |
| 148 | + throw new Exception(FAILED_TO_VALIDATE_MESSAGE + "could not validate nonce"); |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + private String getNonceClaimValueFromIdToken(String idToken) throws ParseException { |
| 153 | + return (String) JWTParser.parse(idToken).getJWTClaimsSet().getClaim("nonce"); |
| 154 | + } |
| 155 | + |
| 156 | + private StateData validateState(HttpSession session, String state) throws Exception { |
| 157 | + if (StringUtils.isNotEmpty(state)) { |
| 158 | + StateData stateDataInSession = removeStateFromSession(session, state); |
| 159 | + if (stateDataInSession != null) { |
| 160 | + return stateDataInSession; |
| 161 | + } |
| 162 | + } |
| 163 | + throw new Exception(FAILED_TO_VALIDATE_MESSAGE + "could not validate state"); |
| 164 | + } |
| 165 | + |
| 166 | + private void validateAuthRespMatchesAuthCodeFlow(AuthenticationSuccessResponse oidcResponse) throws Exception { |
| 167 | + if (oidcResponse.getIDToken() != null || oidcResponse.getAccessToken() != null || |
| 168 | + oidcResponse.getAuthorizationCode() == null) { |
| 169 | + throw new Exception(FAILED_TO_VALIDATE_MESSAGE + "unexpected set of artifacts received"); |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + private void storeStateInSession(HttpSession session, String state, String nonce) { |
| 174 | + if (session.getAttribute(STATES) == null) { |
| 175 | + session.setAttribute(STATES, new HashMap<String, StateData>()); |
| 176 | + } |
| 177 | + ((Map<String, StateData>) session.getAttribute(STATES)).put(state, new StateData(nonce, new Date())); |
| 178 | + } |
| 179 | + |
| 180 | + private StateData removeStateFromSession(HttpSession session, String state) { |
| 181 | + Map<String, StateData> states = (Map<String, StateData>) session.getAttribute(STATES); |
| 182 | + if (states != null) { |
| 183 | + eliminateExpiredStates(states); |
| 184 | + StateData stateData = states.get(state); |
| 185 | + if (stateData != null) { |
| 186 | + states.remove(state); |
| 187 | + return stateData; |
| 188 | + } |
| 189 | + } |
| 190 | + return null; |
| 191 | + } |
| 192 | + |
| 193 | + private void eliminateExpiredStates(Map<String, StateData> map) { |
| 194 | + Iterator<Map.Entry<String, StateData>> it = map.entrySet().iterator(); |
| 195 | + |
| 196 | + Date currTime = new Date(); |
| 197 | + while (it.hasNext()) { |
| 198 | + Map.Entry<String, StateData> entry = it.next(); |
| 199 | + long diffInSeconds = TimeUnit.MILLISECONDS. |
| 200 | + toSeconds(currTime.getTime() - entry.getValue().getExpirationDate().getTime()); |
| 201 | + |
| 202 | + if (diffInSeconds > STATE_TTL) { |
| 203 | + it.remove(); |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + private String getRedirectUrl(String authority, String claims, String state, String nonce) |
| 209 | + throws UnsupportedEncodingException { |
| 210 | + |
| 211 | + String redirectUrl = authority.replace("/tfp", "") + "oauth2/v2.0/authorize?" + |
| 212 | + "response_type=code&" + |
| 213 | + "response_mode=form_post&" + |
| 214 | + "redirect_uri=" + URLEncoder.encode(authHelper.configuration.redirectUri, "UTF-8") + |
| 215 | + "&client_id=" + authHelper.configuration.clientId + |
| 216 | + "&scope=" + URLEncoder.encode("openid offline_access profile " + |
| 217 | + authHelper.configuration.apiScope, "UTF-8") + |
| 218 | + (StringUtils.isEmpty(claims) ? "" : "&claims=" + claims) + |
| 219 | + "&prompt=select_account" + |
| 220 | + "&state=" + state |
| 221 | + + "&nonce=" + nonce; |
| 222 | + |
| 223 | + return redirectUrl; |
| 224 | + } |
| 225 | +} |
0 commit comments