| 
 | 1 | +package info.unterrainer.websocketserver;  | 
 | 2 | + | 
 | 3 | +import java.io.IOException;  | 
 | 4 | +import java.math.BigInteger;  | 
 | 5 | +import java.net.URI;  | 
 | 6 | +import java.net.http.HttpClient;  | 
 | 7 | +import java.net.http.HttpRequest;  | 
 | 8 | +import java.net.http.HttpResponse;  | 
 | 9 | +import java.security.KeyFactory;  | 
 | 10 | +import java.security.PublicKey;  | 
 | 11 | +import java.security.spec.RSAPublicKeySpec;  | 
 | 12 | +import java.util.Base64;  | 
 | 13 | + | 
 | 14 | +import org.keycloak.TokenVerifier;  | 
 | 15 | +import org.keycloak.common.VerificationException;  | 
 | 16 | +import org.keycloak.representations.AccessToken;  | 
 | 17 | + | 
 | 18 | +import com.fasterxml.jackson.databind.JsonNode;  | 
 | 19 | +import com.fasterxml.jackson.databind.ObjectMapper;  | 
 | 20 | + | 
 | 21 | +import info.unterrainer.commons.httpserver.exceptions.ForbiddenException;  | 
 | 22 | +import info.unterrainer.commons.httpserver.exceptions.UnauthorizedException;  | 
 | 23 | +import lombok.RequiredArgsConstructor;  | 
 | 24 | +import lombok.extern.slf4j.Slf4j;  | 
 | 25 | + | 
 | 26 | +@Slf4j  | 
 | 27 | +@RequiredArgsConstructor  | 
 | 28 | +public class JwtTokenHandler {  | 
 | 29 | + | 
 | 30 | +	private final String host;  | 
 | 31 | +	private final String realm;  | 
 | 32 | + | 
 | 33 | +	private String authUrl;  | 
 | 34 | +	private PublicKey publicKey = null;  | 
 | 35 | + | 
 | 36 | +	public void initPublicKey() {  | 
 | 37 | +		String correctedHost = host;  | 
 | 38 | +		String correctedRealm = realm;  | 
 | 39 | + | 
 | 40 | +		if (publicKey != null)  | 
 | 41 | +			return;  | 
 | 42 | +		if (!correctedHost.endsWith("/"))  | 
 | 43 | +			correctedHost += "/";  | 
 | 44 | +		if (!correctedRealm.startsWith("/"))  | 
 | 45 | +			correctedRealm = "/" + correctedRealm;  | 
 | 46 | + | 
 | 47 | +		authUrl = correctedHost + "realms" + correctedRealm + "/protocol/openid-connect/certs";  | 
 | 48 | +		try {  | 
 | 49 | +			log.info("Getting public key from: [{}]", authUrl);  | 
 | 50 | +			publicKey = fetchPublicKey(authUrl);  | 
 | 51 | +		} catch (Exception e) {  | 
 | 52 | +			log.error("There was an error fetching the PublicKey from the openIdConnect-server [{}].", authUrl);  | 
 | 53 | +			throw new IllegalStateException(e);  | 
 | 54 | +		}  | 
 | 55 | +	}  | 
 | 56 | + | 
 | 57 | +	private PublicKey fetchPublicKey(String jwksUrl) throws Exception {  | 
 | 58 | +		ObjectMapper objectMapper = new ObjectMapper();  | 
 | 59 | +		HttpClient client = HttpClient.newHttpClient();  | 
 | 60 | +		HttpRequest request = HttpRequest.newBuilder().uri(URI.create(jwksUrl)).GET().build();  | 
 | 61 | + | 
 | 62 | +		HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());  | 
 | 63 | + | 
 | 64 | +		if (response.statusCode() >= 300) {  | 
 | 65 | +			throw new IOException("Failed to fetch JWKS: HTTP " + response.statusCode());  | 
 | 66 | +		}  | 
 | 67 | + | 
 | 68 | +		JsonNode jwks = objectMapper.readTree(response.body());  | 
 | 69 | +		// Just take the first key for now.  | 
 | 70 | +		JsonNode key = jwks.get("keys").get(0);  | 
 | 71 | + | 
 | 72 | +		String modulusBase64 = key.get("n").asText();  | 
 | 73 | +		String exponentBase64 = key.get("e").asText();  | 
 | 74 | + | 
 | 75 | +		byte[] modulusBytes = Base64.getUrlDecoder().decode(modulusBase64);  | 
 | 76 | +		byte[] exponentBytes = Base64.getUrlDecoder().decode(exponentBase64);  | 
 | 77 | + | 
 | 78 | +		BigInteger modulus = new BigInteger(1, modulusBytes);  | 
 | 79 | +		BigInteger exponent = new BigInteger(1, exponentBytes);  | 
 | 80 | + | 
 | 81 | +		RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);  | 
 | 82 | +		KeyFactory factory = KeyFactory.getInstance("RSA");  | 
 | 83 | +		return factory.generatePublic(spec);  | 
 | 84 | +	}  | 
 | 85 | + | 
 | 86 | +	public void checkAccess(String authorizationHeader) {  | 
 | 87 | +		try {  | 
 | 88 | +			TokenVerifier<AccessToken> tokenVerifier = persistUserInfoInContext(authorizationHeader);  | 
 | 89 | +			if (tokenVerifier == null)  | 
 | 90 | +				throw new UnauthorizedException();  | 
 | 91 | + | 
 | 92 | +			initPublicKey();  | 
 | 93 | +			tokenVerifier.publicKey(publicKey);  | 
 | 94 | +			try {  | 
 | 95 | +				tokenVerifier.verifySignature();  | 
 | 96 | +			} catch (VerificationException e) {  | 
 | 97 | +				throw new UnauthorizedException(  | 
 | 98 | +						"Error verifying token from user with publicKey obtained from keycloak.", e);  | 
 | 99 | +			}  | 
 | 100 | + | 
 | 101 | +			try {  | 
 | 102 | +				tokenVerifier.verify();  | 
 | 103 | +				throw new ForbiddenException();  | 
 | 104 | +			} catch (VerificationException e) {  | 
 | 105 | +				throw new ForbiddenException();  | 
 | 106 | +			}  | 
 | 107 | +		} catch (Exception e) {  | 
 | 108 | +			log.error("Error checking token.", e);  | 
 | 109 | +			throw e;  | 
 | 110 | +		}  | 
 | 111 | +	}  | 
 | 112 | + | 
 | 113 | +	private TokenVerifier<AccessToken> persistUserInfoInContext(String authorizationHeader) {  | 
 | 114 | +		if (authorizationHeader == null || authorizationHeader.isBlank())  | 
 | 115 | +			return null;  | 
 | 116 | + | 
 | 117 | +		try {  | 
 | 118 | +			TokenVerifier<AccessToken> tokenVerifier = TokenVerifier.create(authorizationHeader, AccessToken.class);  | 
 | 119 | +			AccessToken token = tokenVerifier.getToken();  | 
 | 120 | +			if (!token.isActive()) {  | 
 | 121 | +				log.warn("Token is inactive.");  | 
 | 122 | +				return null;  | 
 | 123 | +			}  | 
 | 124 | +			// Disabled to enable getting token from side-channels like 'localhost'.  | 
 | 125 | +			/*  | 
 | 126 | +			 * if (!token.getIssuer().equalsIgnoreCase(authUrl)) {  | 
 | 127 | +			 * setTokenRejectionReason(ctx, "Token has wrong real-url."); return null; }  | 
 | 128 | +			 */  | 
 | 129 | +			return tokenVerifier;  | 
 | 130 | + | 
 | 131 | +		} catch (VerificationException e) {  | 
 | 132 | +			log.warn("Token was checked and deemed invalid.", e);  | 
 | 133 | +			return null;  | 
 | 134 | +		}  | 
 | 135 | +	}  | 
 | 136 | +}  | 
0 commit comments