From b7ca0dcda6560a5d799022fa403672f7aa963aef Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sat, 18 Oct 2025 00:39:13 -0700 Subject: [PATCH 1/3] changes to add proxy in wic --- .../CustomTokenProxyConfiguration.java | 119 ++++++++++ .../CustomTokenProxyHttpClient.java | 218 ++++++++++++++++++ .../CustomTokenProxyHttpResponse.java | 154 +++++++++++++ .../customtokenproxy/ProxyConfig.java | 33 +++ .../implementation/util/IdentitySslUtil.java | 76 ++++++ 5 files changed, 600 insertions(+) create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java new file mode 100644 index 000000000000..b0849d29bf56 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -0,0 +1,119 @@ +package com.azure.identity.implementation.customtokenproxy; + +import com.azure.core.util.logging.ClientLogger; +import com.azure.identity.implementation.WorkloadIdentityTokenProxyPolicy; + +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.net.URISyntaxException; + +import com.azure.core.util.Configuration; +import com.azure.core.util.CoreUtils; + +public class CustomTokenProxyConfiguration { + + private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityTokenProxyPolicy.class); + + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; + public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; + + private CustomTokenProxyConfiguration() {} + + public static boolean isConfigured(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + return !CoreUtils.isNullOrEmpty(tokenProxyUrl); + } + + public static ProxyConfig parseAndValidate(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE); + String caData = configuration.get(AZURE_KUBERNETES_CA_DATA); + String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); + + if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) + || !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); + } + throw LOGGER.logExceptionAsError(new IllegalStateException( + "AZURE_KUBERNETES_TOKEN_PROXY must be set to enable custom token proxy.")); + } + + if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set.")); + } + + URL proxyUrl = validateProxyUrl(tokenProxyUrl); + + byte[] caCertBytes = null; + if(!CoreUtils.isNullOrEmpty(caData)) { + try { + caCertBytes = caData.getBytes(StandardCharsets.UTF_8); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e)); + } + } + + return new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + } + + private static URL validateProxyUrl(String endpoint) { + if (CoreUtils.isNullOrEmpty(endpoint)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty")); + } + + try { + URI tokenProxy = new URI(endpoint); + + if (!"https".equals(tokenProxy.getScheme())) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme())); + } + + if (tokenProxy.getRawUserInfo() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain user info: " + endpoint)); + } + + if (tokenProxy.getRawQuery() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain a query: " + endpoint)); + } + + if (tokenProxy.getRawFragment() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain a fragment: " + endpoint)); + } + + if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), + tokenProxy.getPort(), "/", null, null); + } + + return tokenProxy.toURL(); + + } catch (URISyntaxException | IllegalArgumentException e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e)); + } catch (Exception e) { + throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e); + } + } + + // public static String getTokenProxyUrl() { + // String tokenProxyUrl = System.getenv(AZURE_KUBERNETES_TOKEN_PROXY); + // if (tokenProxyUrl == null || tokenProxyUrl.isEmpty()) { + // throw LOGGER.logExceptionAsError(new IllegalStateException( + // String.format("Environment variable '%s' is not set or is empty. It must be set to the URL of the" + // + " token proxy.", AZURE_KUBERNETES_TOKEN_PROXY))); + // } + // return tokenProxyUrl; + // } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java new file mode 100644 index 000000000000..194d306eef5b --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -0,0 +1,218 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.MalformedParametersException; +import java.lang.reflect.Proxy; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Arrays; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.Context; +import com.azure.core.util.CoreUtils; +import com.azure.identity.implementation.util.IdentitySslUtil; + +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpClient implements HttpClient { + + private final ProxyConfig proxyConfig; + private volatile SSLContext cachedSSLContext; + private volatile byte[] cachedFileContent; + + public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { + this.proxyConfig = proxyConfig; + } + + @Override + public Mono send(HttpRequest request) { + return Mono.fromCallable(() -> sendSync(request, Context.NONE)); + } + + @Override + public HttpResponse sendSync(HttpRequest request, Context context) { + try { + HttpURLConnection connection = createConnection(request); + connection.connect(); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw new RuntimeException("Failed to create connection to token proxy", e); + } + } + + // private HttpURLConnection createConnection(HttpRequest request) throws IOException { + // URL updateProxyRequest = rewriteTokenRequestForProxy(request.getUrl()); + // HttpsURLConnection connection = (HttpsURLConnection) updateProxyRequest.openConnection(); + // try { + // SSLContext sslContext = getSSLContext(); + // connection.setSSLSocketFactory(sslContext.getSocketFactory()); + + // if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + // SSLParameters sslParameters = connection.getSSLParameters(); + // } + // } + // return connection; + + // // connection.setRequestMethod(request.getMethod().toString()); + // // connection.setDoOutput(true); + // // request.getHeaders().forEach((key, values) -> { + // // values.forEach(value -> connection.addRequestProperty(key, value)); + // // }); + // // return connection; + // } + + private HttpURLConnection createConnection(HttpRequest request) throws IOException { + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + } + connection.setSSLSocketFactory(sslSocketFactory); + } catch (Exception e) { + throw new RuntimeException("Failed to set up SSL context for token proxy", e); + } + + connection.setRequestMethod(request.getHttpMethod().toString()); + // connection.setConnectTimeout(10_000); + // connection.setReadTimeout(20_000); + connection.setDoOutput(true); + + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); + + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + connection.getOutputStream().write(bytes); + } + } + + return connection; + } + + + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedParametersException{ + try { + String originalPath = originalUrl.getPath(); + String originalQuery = originalUrl.getQuery(); + + String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); + if(!tokenProxyBase.endsWith("/")) tokenProxyBase += "/"; + + URI combined = URI.create(tokenProxyBase).resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + + String combinedStr = combined.toString(); + if (originalQuery != null && !originalQuery.isEmpty()) { + combinedStr += "?" + originalQuery; + } + + return new URL(combinedStr); + + } catch (Exception e) { + throw new RuntimeException("Failed to rewrite token request for proxy", e); + } + } + + private SSLContext getSSLContext() { + try { + // If no CA override provide, use default + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) + && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = SSLContext.getDefault(); + } + } + return cachedSSLContext; + } + + // If CA data provided, use it + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); + } + } + return cachedSSLContext; + } + + // If CA file provided, read it (and re-read if it changes) + Path path = Paths.get(proxyConfig.getCaFile()); + if (!Files.exists(path)) { + throw new IOException("CA file not found: " + proxyConfig.getCaFile()); + } + + byte[] currentContent; + + synchronized (this) { + currentContent = Files.readAllBytes(path); + if (currentContent.length == 0) { + throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); + } + + if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) { + cachedSSLContext = createSslContextFromBytes(currentContent); + cachedFileContent = currentContent; + } + } + + return cachedSSLContext; + + } catch (Exception e) { + throw new RuntimeException("Failed to create default SSLContext", e); + } + } + + // Create SSLContext from byte array containing PEM certificate data + private SSLContext createSslContextFromBytes(byte[] certificateData) { + try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate caCert = (X509Certificate) cf.generateCertificate(inputStream); + return createSslContext(caCert); + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext from bytes", e); + } + } + + // Create SSLContext from a single X509Certificate + private SSLContext createSslContext(X509Certificate caCert) { + try { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + keystore.setCertificateEntry("ca-cert", caCert); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keystore); + + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, tmf.getTrustManagers(), null); + return context; + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext", e); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java new file mode 100644 index 000000000000..33d148c86927 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -0,0 +1,154 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpResponse extends HttpResponse { + + // private final HttpRequest request; + private final int statusCode; + private final HttpHeaders headers; + private final HttpURLConnection connection; + private byte[] cachedRequestBodyBytes; + + public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { + super(request); + this.connection = connection; + this.statusCode = extractStatusCode(connection); + this.headers = extractHeaders(connection); + } + + private HttpHeaders extractHeaders(HttpURLConnection connection) { + HttpHeaders headers = new HttpHeaders(); + for (Map.Entry> entry : connection.getHeaderFields().entrySet()) { + String headerName = entry.getKey(); + if (headerName != null) { + for (String headerValue : entry.getValue()) { + headers.add(headerName, headerValue); + } + } + } + return headers; + } + + public int extractStatusCode(HttpURLConnection connection) { + try { + return connection.getResponseCode(); + } catch (IOException e) { + throw new RuntimeException("Failed to get status code from token proxy response", e); + } + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getHeaderValue(String name) { + return headers.getValue(HttpHeaderName.fromString(name)); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + // @Override + // public Mono getBodyAsByteArray() { + // return Mono.fromCallable(() -> { + // try (InputStream inputStream = connection.getInputStream()) { + // return inputStream.readAllBytes(); + // } catch (IOException e) { + // throw new RuntimeException("Failed to read body from token proxy response", e); + // } + // }); + // } + + @Override + public Mono getBodyAsByteArray() { + return Mono.fromCallable(() -> { + if (cachedRequestBodyBytes != null) { + return cachedRequestBodyBytes; + } + try (InputStream stream = getResponseStream()) { + if (stream == null) { + cachedRequestBodyBytes = new byte[0]; + return cachedRequestBodyBytes; + } + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] tmp = new byte[4096]; + int n; + while ((n = stream.read(tmp)) != -1) { + buffer.write(tmp, 0, n); + } + cachedRequestBodyBytes = buffer.toByteArray(); + return cachedRequestBodyBytes; + } + }); + } + + @Override + public Flux getBody() { + return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsString(StandardCharsets.UTF_8); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + + + @Override + public void close() { + connection.disconnect(); + } + + + + // @Override + // public Flux getBody() { + // // TODO Auto-generated method stub + // throw new UnsupportedOperationException("Unimplemented method 'getBody'"); + // } + + // @Override + // public Mono getBodyAsString() { + // return getBodyAsByteArray().map(bytes -> new String(bytes, StandardCharsets.UTF_8)); + // } + + // @Override + // public Mono getBodyAsString(Charset charset) { + // return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + // } + + private InputStream getResponseStream() throws IOException { + try { + return connection.getInputStream(); + } catch (IOException e) { + // On non-2xx responses, getInputStream() throws, use error stream instead + return connection.getErrorStream(); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java new file mode 100644 index 000000000000..a104a8ab7053 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -0,0 +1,33 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.net.URL; + +public class ProxyConfig { + private final URL tokenProxyUrl; + private final String sniName; + private final String caFile; + private final byte[] caData; + + public ProxyConfig(URL tokenProxyUrl, String sniName, String caFile, byte[] caData) { + this.tokenProxyUrl = tokenProxyUrl; + this.sniName = sniName; + this.caFile = caFile; + this.caData = caData; + } + + public URL getTokenProxyUrl() { + return tokenProxyUrl; + } + + public String getSniName() { + return sniName; + } + + public String getCaFile() { + return caFile; + } + + public byte[] getCaData() { + return caData; + } +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index ba932d5e5a55..604dba0b59bf 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -3,15 +3,23 @@ package com.azure.identity.implementation.util; +import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; + +import java.io.IOException; +import java.net.Socket; import java.security.KeyManagementException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -19,6 +27,9 @@ import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; public final class IdentitySslUtil { public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER; @@ -125,4 +136,69 @@ private static String extractCertificateThumbprint(Certificate certificate, Clie throw logger.logExceptionAsError(new RuntimeException(e)); } } + + public static final class SniSslSocketFactory extends SSLSocketFactory { + private final SSLSocketFactory sslSocketFactory; + private final String sniName; + + public SniSslSocketFactory(SSLSocketFactory sslSocketFactory, String sniName) { + this.sslSocketFactory = sslSocketFactory; + this.sniName = sniName; + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { + Socket sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, host, port, autoClose); + configureSni(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public String[] getDefaultCipherSuites() { + return sslSocketFactory.getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + return sslSocketFactory.getSupportedCipherSuites(); + } + + private void configureSni(Socket socket) { + if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { + SSLSocket sslSocket = (SSLSocket) socket; + SSLParameters sslParameters = sslSocket.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(new SNIHostName(sniName))); + sslSocket.setSSLParameters(sslParameters); + } + } + + } } From b162676ee798f0302b67cf957326b0b425fab047 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sun, 19 Oct 2025 02:07:12 -0700 Subject: [PATCH 2/3] integration with wicbuilder and wic --- .../identity/WorkloadIdentityCredential.java | 11 +++ .../WorkloadIdentityCredentialBuilder.java | 8 +++ .../implementation/IdentityClientOptions.java | 15 +++- .../CustomTokenProxyConfiguration.java | 6 +- .../CustomTokenProxyHttpClient.java | 72 ++++++++++--------- .../CustomTokenProxyHttpResponse.java | 32 ++------- 6 files changed, 79 insertions(+), 65 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index a0c04ff2d952..3fedf5db30d9 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -10,6 +10,9 @@ import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import com.azure.identity.implementation.IdentityClientOptions; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyConfiguration; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyHttpClient; +import com.azure.identity.implementation.customtokenproxy.ProxyConfig; import com.azure.identity.implementation.util.LoggingUtil; import com.azure.identity.implementation.util.ValidationUtil; import reactor.core.publisher.Mono; @@ -89,6 +92,14 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; + if(identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { + throw LOGGER.logExceptionAsError (new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + } + ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); + identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); + } + if (!(CoreUtils.isNullOrEmpty(tenantIdInput) || CoreUtils.isNullOrEmpty(federatedTokenFilePathInput) || CoreUtils.isNullOrEmpty(clientIdInput) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 95f16fd9628d..5520ad122bf4 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -47,6 +47,7 @@ public class WorkloadIdentityCredentialBuilder extends AadCredentialBuilderBase { private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityCredentialBuilder.class); private String tokenFilePath; + private boolean enableTokenProxy = false; /** * Creates an instance of a WorkloadIdentityCredentialBuilder. @@ -66,6 +67,11 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } + public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy(boolean enable) { + this.enableTokenProxy = enable; + return this; + } + /** * Creates new {@link WorkloadIdentityCredential} with the configured options set. * @@ -88,6 +94,8 @@ public WorkloadIdentityCredential build() { ValidationUtil.validate(this.getClass().getSimpleName(), LOGGER, "Client ID", clientIdInput, "Tenant ID", tenantIdInput, "Service Token File Path", federatedTokenFilePathInput); + identityClientOptions.setEnableKubernetesTokenProxy(this.enableTokenProxy); + return new WorkloadIdentityCredential(tenantIdInput, clientIdInput, federatedTokenFilePathInput, identityClientOptions.clone()); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index d934c23c1d67..4acab6df23ad 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -71,6 +71,7 @@ public final class IdentityClientOptions implements Cloneable { private List perRetryPolicies; private boolean instanceDiscovery; private String dacEnvConfiguredCredential; + private boolean enableKubernetesTokenProxy; private Duration credentialProcessTimeout = Duration.ofSeconds(10); @@ -833,6 +834,15 @@ public String getDACEnvConfiguredCredential() { return dacEnvConfiguredCredential; } + public boolean isKubernetesTokenProxyEnabled() { + return enableKubernetesTokenProxy; + } + + public IdentityClientOptions setEnableKubernetesTokenProxy(boolean enableTokenProxy) { + this.enableKubernetesTokenProxy = enableTokenProxy; + return this; + } + public IdentityClientOptions clone() { IdentityClientOptions clone = new IdentityClientOptions().setAdditionallyAllowedTenants(this.additionallyAllowedTenants) @@ -863,8 +873,9 @@ public IdentityClientOptions clone() { .setPerRetryPolicies(this.perRetryPolicies) .setBrowserCustomizationOptions(this.browserCustomizationOptions) .setChained(this.isChained) - .subscription(this.subscription); - + .subscription(this.subscription) + .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); + if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); clone.setEnableLegacyMsaPassthrough(this.enableMsaPassthrough); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index b0849d29bf56..0b25b64935b3 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -1,7 +1,6 @@ package com.azure.identity.implementation.customtokenproxy; import com.azure.core.util.logging.ClientLogger; -import com.azure.identity.implementation.WorkloadIdentityTokenProxyPolicy; import java.net.URI; import java.net.URL; @@ -13,7 +12,7 @@ public class CustomTokenProxyConfiguration { - private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityTokenProxyPolicy.class); + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; @@ -40,8 +39,7 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { throw LOGGER.logExceptionAsError(new IllegalStateException( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); } - throw LOGGER.logExceptionAsError(new IllegalStateException( - "AZURE_KUBERNETES_TOKEN_PROXY must be set to enable custom token proxy.")); + return null; } if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index 194d306eef5b..e37e166a02fe 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -6,15 +6,19 @@ import java.lang.reflect.MalformedParametersException; import java.lang.reflect.Proxy; import java.net.HttpURLConnection; +import java.net.MalformedURLException; import java.net.URI; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.security.KeyStore; +import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; @@ -50,34 +54,12 @@ public Mono send(HttpRequest request) { public HttpResponse sendSync(HttpRequest request, Context context) { try { HttpURLConnection connection = createConnection(request); - connection.connect(); return new CustomTokenProxyHttpResponse(request, connection); } catch (IOException e) { throw new RuntimeException("Failed to create connection to token proxy", e); } } - // private HttpURLConnection createConnection(HttpRequest request) throws IOException { - // URL updateProxyRequest = rewriteTokenRequestForProxy(request.getUrl()); - // HttpsURLConnection connection = (HttpsURLConnection) updateProxyRequest.openConnection(); - // try { - // SSLContext sslContext = getSSLContext(); - // connection.setSSLSocketFactory(sslContext.getSocketFactory()); - - // if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { - // SSLParameters sslParameters = connection.getSSLParameters(); - // } - // } - // return connection; - - // // connection.setRequestMethod(request.getMethod().toString()); - // // connection.setDoOutput(true); - // // request.getHeaders().forEach((key, values) -> { - // // values.forEach(value -> connection.addRequestProperty(key, value)); - // // }); - // // return connection; - // } - private HttpURLConnection createConnection(HttpRequest request) throws IOException { URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); @@ -95,8 +77,9 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti } connection.setRequestMethod(request.getHttpMethod().toString()); - // connection.setConnectTimeout(10_000); - // connection.setReadTimeout(20_000); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); connection.setDoOutput(true); request.getHeaders().forEach(header -> { @@ -114,7 +97,7 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti } - private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedParametersException{ + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException{ try { String originalPath = originalUrl.getPath(); String originalQuery = originalUrl.getQuery(); @@ -165,10 +148,9 @@ private SSLContext getSSLContext() { throw new IOException("CA file not found: " + proxyConfig.getCaFile()); } - byte[] currentContent; - + byte[] currentContent = Files.readAllBytes(path); + synchronized (this) { - currentContent = Files.readAllBytes(path); if (currentContent.length == 0) { throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); } @@ -182,7 +164,7 @@ private SSLContext getSSLContext() { return cachedSSLContext; } catch (Exception e) { - throw new RuntimeException("Failed to create default SSLContext", e); + throw new RuntimeException("Failed to initialize SSLContext for proxy", e); } } @@ -190,20 +172,42 @@ private SSLContext getSSLContext() { private SSLContext createSslContextFromBytes(byte[] certificateData) { try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { CertificateFactory cf = CertificateFactory.getInstance("X.509"); - X509Certificate caCert = (X509Certificate) cf.generateCertificate(inputStream); - return createSslContext(caCert); + + List certificates = new ArrayList<>(); + // while(inputStream.available() > 0) { + // X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + // certificates.add(cert); + // } + while (true) { + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + certificates.add(cert); + } catch (CertificateException e) { + break; // end of stream + } + } + + if (certificates.isEmpty()) { + throw new RuntimeException("No valid certificates found"); + } + + // X509Certificate caCert = certificates.get(0); + return createSslContext(certificates); } catch (Exception e) { throw new RuntimeException("Failed to create SSLContext from bytes", e); } } // Create SSLContext from a single X509Certificate - private SSLContext createSslContext(X509Certificate caCert) { + private SSLContext createSslContext(List certificates) { try { KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); keystore.load(null, null); - keystore.setCertificateEntry("ca-cert", caCert); - + int index = 1; + for (X509Certificate caCert : certificates) { + keystore.setCertificateEntry("ca-cert-" + index, caCert); + index++; + } TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); tmf.init(keystore); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 33d148c86927..88f5d38eb061 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -24,7 +24,7 @@ public class CustomTokenProxyHttpResponse extends HttpResponse { private final int statusCode; private final HttpHeaders headers; private final HttpURLConnection connection; - private byte[] cachedRequestBodyBytes; + private byte[] cachedResponseBodyBytes; public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { super(request); @@ -83,13 +83,13 @@ public HttpHeaders getHeaders() { @Override public Mono getBodyAsByteArray() { return Mono.fromCallable(() -> { - if (cachedRequestBodyBytes != null) { - return cachedRequestBodyBytes; + if (cachedResponseBodyBytes != null) { + return cachedResponseBodyBytes; } try (InputStream stream = getResponseStream()) { if (stream == null) { - cachedRequestBodyBytes = new byte[0]; - return cachedRequestBodyBytes; + cachedResponseBodyBytes = new byte[0]; + return cachedResponseBodyBytes; } ByteArrayOutputStream buffer = new ByteArrayOutputStream(); byte[] tmp = new byte[4096]; @@ -97,8 +97,8 @@ public Mono getBodyAsByteArray() { while ((n = stream.read(tmp)) != -1) { buffer.write(tmp, 0, n); } - cachedRequestBodyBytes = buffer.toByteArray(); - return cachedRequestBodyBytes; + cachedResponseBodyBytes = buffer.toByteArray(); + return cachedResponseBodyBytes; } }); } @@ -124,24 +124,6 @@ public void close() { connection.disconnect(); } - - - // @Override - // public Flux getBody() { - // // TODO Auto-generated method stub - // throw new UnsupportedOperationException("Unimplemented method 'getBody'"); - // } - - // @Override - // public Mono getBodyAsString() { - // return getBodyAsByteArray().map(bytes -> new String(bytes, StandardCharsets.UTF_8)); - // } - - // @Override - // public Mono getBodyAsString(Charset charset) { - // return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); - // } - private InputStream getResponseStream() throws IOException { try { return connection.getInputStream(); From 85022597fbd1f34a860bbf93662a3b61c0dd01d3 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sun, 19 Oct 2025 02:54:48 -0700 Subject: [PATCH 3/3] review comments --- .../identity/WorkloadIdentityCredential.java | 5 +- .../implementation/IdentityClientOptions.java | 2 +- .../CustomTokenProxyConfiguration.java | 27 +++--- .../CustomTokenProxyHttpClient.java | 94 +++++++++---------- .../CustomTokenProxyHttpResponse.java | 5 +- .../implementation/util/IdentitySslUtil.java | 9 +- 6 files changed, 71 insertions(+), 71 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index 3fedf5db30d9..f0653bf91522 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -92,9 +92,10 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; - if(identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (identityClientOptions.isKubernetesTokenProxyEnabled()) { if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { - throw LOGGER.logExceptionAsError (new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); } ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index 4acab6df23ad..3c2d21d08c56 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -875,7 +875,7 @@ public IdentityClientOptions clone() { .setChained(this.isChained) .subscription(this.subscription) .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); - + if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); clone.setEnableLegacyMsaPassthrough(this.enableMsaPassthrough); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index 0b25b64935b3..4670457b423b 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -13,13 +13,14 @@ public class CustomTokenProxyConfiguration { private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); - + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; - private CustomTokenProxyConfiguration() {} + private CustomTokenProxyConfiguration() { + } public static boolean isConfigured(Configuration configuration) { String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); @@ -33,8 +34,8 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { - if (!CoreUtils.isNullOrEmpty(sniName) - || !CoreUtils.isNullOrEmpty(caFile) + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) || !CoreUtils.isNullOrEmpty(caData)) { throw LOGGER.logExceptionAsError(new IllegalStateException( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); @@ -50,7 +51,7 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { URL proxyUrl = validateProxyUrl(tokenProxyUrl); byte[] caCertBytes = null; - if(!CoreUtils.isNullOrEmpty(caData)) { + if (!CoreUtils.isNullOrEmpty(caData)) { try { caCertBytes = caData.getBytes(StandardCharsets.UTF_8); } catch (Exception e) { @@ -76,23 +77,23 @@ private static URL validateProxyUrl(String endpoint) { } if (tokenProxy.getRawUserInfo() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain user info: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain user info: " + endpoint)); } if (tokenProxy.getRawQuery() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain a query: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a query: " + endpoint)); } if (tokenProxy.getRawFragment() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain a fragment: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a fragment: " + endpoint)); } if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { - tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), - tokenProxy.getPort(), "/", null, null); + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), tokenProxy.getPort(), "/", + null, null); } return tokenProxy.toURL(); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index e37e166a02fe..9794dc4a5609 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -3,8 +3,6 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.MalformedParametersException; -import java.lang.reflect.Proxy; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URI; @@ -22,7 +20,6 @@ import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManagerFactory; @@ -49,70 +46,71 @@ public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { public Mono send(HttpRequest request) { return Mono.fromCallable(() -> sendSync(request, Context.NONE)); } - + @Override public HttpResponse sendSync(HttpRequest request, Context context) { - try { - HttpURLConnection connection = createConnection(request); - return new CustomTokenProxyHttpResponse(request, connection); - } catch (IOException e) { - throw new RuntimeException("Failed to create connection to token proxy", e); - } + try { + HttpURLConnection connection = createConnection(request); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw new RuntimeException("Failed to create connection to token proxy", e); + } } private HttpURLConnection createConnection(HttpRequest request) throws IOException { - URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); - HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); - - // If SNI explicitly provided - try { - SSLContext sslContext = getSSLContext(); - SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); - if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { - sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if (!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + } + connection.setSSLSocketFactory(sslSocketFactory); + } catch (Exception e) { + throw new RuntimeException("Failed to set up SSL context for token proxy", e); } - connection.setSSLSocketFactory(sslSocketFactory); - } catch (Exception e) { - throw new RuntimeException("Failed to set up SSL context for token proxy", e); - } - connection.setRequestMethod(request.getHttpMethod().toString()); - connection.setInstanceFollowRedirects(false); - connection.setConnectTimeout(10_000); - connection.setReadTimeout(20_000); - connection.setDoOutput(true); + connection.setRequestMethod(request.getHttpMethod().toString()); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); + connection.setDoOutput(true); - request.getHeaders().forEach(header -> { - connection.addRequestProperty(header.getName(), header.getValue()); - }); + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); - if (request.getBodyAsBinaryData() != null) { - byte[] bytes = request.getBodyAsBinaryData().toBytes(); - if (bytes != null && bytes.length > 0) { - connection.getOutputStream().write(bytes); + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + connection.getOutputStream().write(bytes); + } } - } - return connection; + return connection; } - - private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException{ + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException { try { String originalPath = originalUrl.getPath(); String originalQuery = originalUrl.getQuery(); String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); - if(!tokenProxyBase.endsWith("/")) tokenProxyBase += "/"; + if (!tokenProxyBase.endsWith("/")) + tokenProxyBase += "/"; - URI combined = URI.create(tokenProxyBase).resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + URI combined = URI.create(tokenProxyBase) + .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); String combinedStr = combined.toString(); if (originalQuery != null && !originalQuery.isEmpty()) { combinedStr += "?" + originalQuery; } - return new URL(combinedStr); + return new URL(combinedStr); } catch (Exception e) { throw new RuntimeException("Failed to rewrite token request for proxy", e); @@ -121,15 +119,15 @@ private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLExce private SSLContext getSSLContext() { try { - // If no CA override provide, use default - if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) - && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + // If no CA override provided, use default + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) + && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { synchronized (this) { if (cachedSSLContext == null) { cachedSSLContext = SSLContext.getDefault(); } } - return cachedSSLContext; + return cachedSSLContext; } // If CA data provided, use it @@ -139,7 +137,7 @@ private SSLContext getSSLContext() { cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); } } - return cachedSSLContext; + return cachedSSLContext; } // If CA file provided, read it (and re-read if it changes) @@ -164,7 +162,7 @@ private SSLContext getSSLContext() { return cachedSSLContext; } catch (Exception e) { - throw new RuntimeException("Failed to initialize SSLContext for proxy", e); + throw new RuntimeException("Failed to initialize SSLContext for proxy", e); } } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 88f5d38eb061..f9fe17265891 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -46,7 +46,7 @@ private HttpHeaders extractHeaders(HttpURLConnection connection) { return headers; } - public int extractStatusCode(HttpURLConnection connection) { + private int extractStatusCode(HttpURLConnection connection) { try { return connection.getResponseCode(); } catch (IOException e) { @@ -106,7 +106,7 @@ public Mono getBodyAsByteArray() { @Override public Flux getBody() { return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); - } + } @Override public Mono getBodyAsString() { @@ -117,7 +117,6 @@ public Mono getBodyAsString() { public Mono getBodyAsString(Charset charset) { return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); } - @Override public void close() { diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index 604dba0b59bf..01d0e7edf894 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -9,7 +9,6 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SNIHostName; -import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; @@ -161,7 +160,8 @@ public Socket createSocket(String host, int port) throws IOException { } @Override - public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) + throws IOException { Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); configureSni(socket); return socket; @@ -175,7 +175,8 @@ public Socket createSocket(java.net.InetAddress host, int port) throws IOExcepti } @Override - public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, + int localPort) throws IOException { Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); configureSni(socket); return socket; @@ -191,7 +192,7 @@ public String[] getSupportedCipherSuites() { return sslSocketFactory.getSupportedCipherSuites(); } - private void configureSni(Socket socket) { + private void configureSni(Socket socket) { if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { SSLSocket sslSocket = (SSLSocket) socket; SSLParameters sslParameters = sslSocket.getSSLParameters();