diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index a0c04ff2d952..f0653bf91522 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -10,6 +10,9 @@ import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import com.azure.identity.implementation.IdentityClientOptions; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyConfiguration; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyHttpClient; +import com.azure.identity.implementation.customtokenproxy.ProxyConfig; import com.azure.identity.implementation.util.LoggingUtil; import com.azure.identity.implementation.util.ValidationUtil; import reactor.core.publisher.Mono; @@ -89,6 +92,15 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; + if (identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + } + ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); + identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); + } + if (!(CoreUtils.isNullOrEmpty(tenantIdInput) || CoreUtils.isNullOrEmpty(federatedTokenFilePathInput) || CoreUtils.isNullOrEmpty(clientIdInput) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 95f16fd9628d..5520ad122bf4 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -47,6 +47,7 @@ public class WorkloadIdentityCredentialBuilder extends AadCredentialBuilderBase { private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityCredentialBuilder.class); private String tokenFilePath; + private boolean enableTokenProxy = false; /** * Creates an instance of a WorkloadIdentityCredentialBuilder. @@ -66,6 +67,11 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } + public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy(boolean enable) { + this.enableTokenProxy = enable; + return this; + } + /** * Creates new {@link WorkloadIdentityCredential} with the configured options set. * @@ -88,6 +94,8 @@ public WorkloadIdentityCredential build() { ValidationUtil.validate(this.getClass().getSimpleName(), LOGGER, "Client ID", clientIdInput, "Tenant ID", tenantIdInput, "Service Token File Path", federatedTokenFilePathInput); + identityClientOptions.setEnableKubernetesTokenProxy(this.enableTokenProxy); + return new WorkloadIdentityCredential(tenantIdInput, clientIdInput, federatedTokenFilePathInput, identityClientOptions.clone()); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index d934c23c1d67..3c2d21d08c56 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -71,6 +71,7 @@ public final class IdentityClientOptions implements Cloneable { private List perRetryPolicies; private boolean instanceDiscovery; private String dacEnvConfiguredCredential; + private boolean enableKubernetesTokenProxy; private Duration credentialProcessTimeout = Duration.ofSeconds(10); @@ -833,6 +834,15 @@ public String getDACEnvConfiguredCredential() { return dacEnvConfiguredCredential; } + public boolean isKubernetesTokenProxyEnabled() { + return enableKubernetesTokenProxy; + } + + public IdentityClientOptions setEnableKubernetesTokenProxy(boolean enableTokenProxy) { + this.enableKubernetesTokenProxy = enableTokenProxy; + return this; + } + public IdentityClientOptions clone() { IdentityClientOptions clone = new IdentityClientOptions().setAdditionallyAllowedTenants(this.additionallyAllowedTenants) @@ -863,7 +873,8 @@ public IdentityClientOptions clone() { .setPerRetryPolicies(this.perRetryPolicies) .setBrowserCustomizationOptions(this.browserCustomizationOptions) .setChained(this.isChained) - .subscription(this.subscription); + .subscription(this.subscription) + .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java new file mode 100644 index 000000000000..4670457b423b --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -0,0 +1,118 @@ +package com.azure.identity.implementation.customtokenproxy; + +import com.azure.core.util.logging.ClientLogger; + +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.net.URISyntaxException; + +import com.azure.core.util.Configuration; +import com.azure.core.util.CoreUtils; + +public class CustomTokenProxyConfiguration { + + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); + + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; + public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; + + private CustomTokenProxyConfiguration() { + } + + public static boolean isConfigured(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + return !CoreUtils.isNullOrEmpty(tokenProxyUrl); + } + + public static ProxyConfig parseAndValidate(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE); + String caData = configuration.get(AZURE_KUBERNETES_CA_DATA); + String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); + + if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) + || !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); + } + return null; + } + + if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set.")); + } + + URL proxyUrl = validateProxyUrl(tokenProxyUrl); + + byte[] caCertBytes = null; + if (!CoreUtils.isNullOrEmpty(caData)) { + try { + caCertBytes = caData.getBytes(StandardCharsets.UTF_8); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e)); + } + } + + return new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + } + + private static URL validateProxyUrl(String endpoint) { + if (CoreUtils.isNullOrEmpty(endpoint)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty")); + } + + try { + URI tokenProxy = new URI(endpoint); + + if (!"https".equals(tokenProxy.getScheme())) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme())); + } + + if (tokenProxy.getRawUserInfo() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain user info: " + endpoint)); + } + + if (tokenProxy.getRawQuery() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a query: " + endpoint)); + } + + if (tokenProxy.getRawFragment() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a fragment: " + endpoint)); + } + + if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), tokenProxy.getPort(), "/", + null, null); + } + + return tokenProxy.toURL(); + + } catch (URISyntaxException | IllegalArgumentException e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e)); + } catch (Exception e) { + throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e); + } + } + + // public static String getTokenProxyUrl() { + // String tokenProxyUrl = System.getenv(AZURE_KUBERNETES_TOKEN_PROXY); + // if (tokenProxyUrl == null || tokenProxyUrl.isEmpty()) { + // throw LOGGER.logExceptionAsError(new IllegalStateException( + // String.format("Environment variable '%s' is not set or is empty. It must be set to the URL of the" + // + " token proxy.", AZURE_KUBERNETES_TOKEN_PROXY))); + // } + // return tokenProxyUrl; + // } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java new file mode 100644 index 000000000000..9794dc4a5609 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -0,0 +1,220 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.Context; +import com.azure.core.util.CoreUtils; +import com.azure.identity.implementation.util.IdentitySslUtil; + +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpClient implements HttpClient { + + private final ProxyConfig proxyConfig; + private volatile SSLContext cachedSSLContext; + private volatile byte[] cachedFileContent; + + public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { + this.proxyConfig = proxyConfig; + } + + @Override + public Mono send(HttpRequest request) { + return Mono.fromCallable(() -> sendSync(request, Context.NONE)); + } + + @Override + public HttpResponse sendSync(HttpRequest request, Context context) { + try { + HttpURLConnection connection = createConnection(request); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw new RuntimeException("Failed to create connection to token proxy", e); + } + } + + private HttpURLConnection createConnection(HttpRequest request) throws IOException { + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if (!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + } + connection.setSSLSocketFactory(sslSocketFactory); + } catch (Exception e) { + throw new RuntimeException("Failed to set up SSL context for token proxy", e); + } + + connection.setRequestMethod(request.getHttpMethod().toString()); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); + connection.setDoOutput(true); + + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); + + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + connection.getOutputStream().write(bytes); + } + } + + return connection; + } + + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException { + try { + String originalPath = originalUrl.getPath(); + String originalQuery = originalUrl.getQuery(); + + String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); + if (!tokenProxyBase.endsWith("/")) + tokenProxyBase += "/"; + + URI combined = URI.create(tokenProxyBase) + .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + + String combinedStr = combined.toString(); + if (originalQuery != null && !originalQuery.isEmpty()) { + combinedStr += "?" + originalQuery; + } + + return new URL(combinedStr); + + } catch (Exception e) { + throw new RuntimeException("Failed to rewrite token request for proxy", e); + } + } + + private SSLContext getSSLContext() { + try { + // If no CA override provided, use default + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) + && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = SSLContext.getDefault(); + } + } + return cachedSSLContext; + } + + // If CA data provided, use it + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); + } + } + return cachedSSLContext; + } + + // If CA file provided, read it (and re-read if it changes) + Path path = Paths.get(proxyConfig.getCaFile()); + if (!Files.exists(path)) { + throw new IOException("CA file not found: " + proxyConfig.getCaFile()); + } + + byte[] currentContent = Files.readAllBytes(path); + + synchronized (this) { + if (currentContent.length == 0) { + throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); + } + + if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) { + cachedSSLContext = createSslContextFromBytes(currentContent); + cachedFileContent = currentContent; + } + } + + return cachedSSLContext; + + } catch (Exception e) { + throw new RuntimeException("Failed to initialize SSLContext for proxy", e); + } + } + + // Create SSLContext from byte array containing PEM certificate data + private SSLContext createSslContextFromBytes(byte[] certificateData) { + try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + + List certificates = new ArrayList<>(); + // while(inputStream.available() > 0) { + // X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + // certificates.add(cert); + // } + while (true) { + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + certificates.add(cert); + } catch (CertificateException e) { + break; // end of stream + } + } + + if (certificates.isEmpty()) { + throw new RuntimeException("No valid certificates found"); + } + + // X509Certificate caCert = certificates.get(0); + return createSslContext(certificates); + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext from bytes", e); + } + } + + // Create SSLContext from a single X509Certificate + private SSLContext createSslContext(List certificates) { + try { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + int index = 1; + for (X509Certificate caCert : certificates) { + keystore.setCertificateEntry("ca-cert-" + index, caCert); + index++; + } + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keystore); + + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, tmf.getTrustManagers(), null); + return context; + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext", e); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java new file mode 100644 index 000000000000..f9fe17265891 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -0,0 +1,135 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpResponse extends HttpResponse { + + // private final HttpRequest request; + private final int statusCode; + private final HttpHeaders headers; + private final HttpURLConnection connection; + private byte[] cachedResponseBodyBytes; + + public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { + super(request); + this.connection = connection; + this.statusCode = extractStatusCode(connection); + this.headers = extractHeaders(connection); + } + + private HttpHeaders extractHeaders(HttpURLConnection connection) { + HttpHeaders headers = new HttpHeaders(); + for (Map.Entry> entry : connection.getHeaderFields().entrySet()) { + String headerName = entry.getKey(); + if (headerName != null) { + for (String headerValue : entry.getValue()) { + headers.add(headerName, headerValue); + } + } + } + return headers; + } + + private int extractStatusCode(HttpURLConnection connection) { + try { + return connection.getResponseCode(); + } catch (IOException e) { + throw new RuntimeException("Failed to get status code from token proxy response", e); + } + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getHeaderValue(String name) { + return headers.getValue(HttpHeaderName.fromString(name)); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + // @Override + // public Mono getBodyAsByteArray() { + // return Mono.fromCallable(() -> { + // try (InputStream inputStream = connection.getInputStream()) { + // return inputStream.readAllBytes(); + // } catch (IOException e) { + // throw new RuntimeException("Failed to read body from token proxy response", e); + // } + // }); + // } + + @Override + public Mono getBodyAsByteArray() { + return Mono.fromCallable(() -> { + if (cachedResponseBodyBytes != null) { + return cachedResponseBodyBytes; + } + try (InputStream stream = getResponseStream()) { + if (stream == null) { + cachedResponseBodyBytes = new byte[0]; + return cachedResponseBodyBytes; + } + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] tmp = new byte[4096]; + int n; + while ((n = stream.read(tmp)) != -1) { + buffer.write(tmp, 0, n); + } + cachedResponseBodyBytes = buffer.toByteArray(); + return cachedResponseBodyBytes; + } + }); + } + + @Override + public Flux getBody() { + return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsString(StandardCharsets.UTF_8); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + + @Override + public void close() { + connection.disconnect(); + } + + private InputStream getResponseStream() throws IOException { + try { + return connection.getInputStream(); + } catch (IOException e) { + // On non-2xx responses, getInputStream() throws, use error stream instead + return connection.getErrorStream(); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java new file mode 100644 index 000000000000..a104a8ab7053 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -0,0 +1,33 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.net.URL; + +public class ProxyConfig { + private final URL tokenProxyUrl; + private final String sniName; + private final String caFile; + private final byte[] caData; + + public ProxyConfig(URL tokenProxyUrl, String sniName, String caFile, byte[] caData) { + this.tokenProxyUrl = tokenProxyUrl; + this.sniName = sniName; + this.caFile = caFile; + this.caData = caData; + } + + public URL getTokenProxyUrl() { + return tokenProxyUrl; + } + + public String getSniName() { + return sniName; + } + + public String getCaFile() { + return caFile; + } + + public byte[] getCaData() { + return caData; + } +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index ba932d5e5a55..01d0e7edf894 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -3,15 +3,22 @@ package com.azure.identity.implementation.util; +import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SNIHostName; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; + +import java.io.IOException; +import java.net.Socket; import java.security.KeyManagementException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -19,6 +26,9 @@ import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; public final class IdentitySslUtil { public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER; @@ -125,4 +135,71 @@ private static String extractCertificateThumbprint(Certificate certificate, Clie throw logger.logExceptionAsError(new RuntimeException(e)); } } + + public static final class SniSslSocketFactory extends SSLSocketFactory { + private final SSLSocketFactory sslSocketFactory; + private final String sniName; + + public SniSslSocketFactory(SSLSocketFactory sslSocketFactory, String sniName) { + this.sslSocketFactory = sslSocketFactory; + this.sniName = sniName; + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { + Socket sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, host, port, autoClose); + configureSni(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) + throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, + int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public String[] getDefaultCipherSuites() { + return sslSocketFactory.getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + return sslSocketFactory.getSupportedCipherSuites(); + } + + private void configureSni(Socket socket) { + if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { + SSLSocket sslSocket = (SSLSocket) socket; + SSLParameters sslParameters = sslSocket.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(new SNIHostName(sniName))); + sslSocket.setSSLParameters(sslParameters); + } + } + + } }