Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package com.azure.identity.implementation.customtokenproxy;

import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.implementation.WorkloadIdentityTokenProxyPolicy;

import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.net.URISyntaxException;

import com.azure.core.util.Configuration;
import com.azure.core.util.CoreUtils;

public class CustomTokenProxyConfiguration {

private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityTokenProxyPolicy.class);

public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY";
public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE";
public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA";
public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME";

private CustomTokenProxyConfiguration() {}

public static boolean isConfigured(Configuration configuration) {
String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY);
return !CoreUtils.isNullOrEmpty(tokenProxyUrl);
}

public static ProxyConfig parseAndValidate(Configuration configuration) {
String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY);
String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE);
String caData = configuration.get(AZURE_KUBERNETES_CA_DATA);
String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME);

if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) {
if (!CoreUtils.isNullOrEmpty(sniName)
|| !CoreUtils.isNullOrEmpty(caFile)
|| !CoreUtils.isNullOrEmpty(caData)) {
throw LOGGER.logExceptionAsError(new IllegalStateException(
"AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present"));
}
throw LOGGER.logExceptionAsError(new IllegalStateException(
"AZURE_KUBERNETES_TOKEN_PROXY must be set to enable custom token proxy."));
}

if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) {
throw LOGGER.logExceptionAsError(new IllegalStateException(
"Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set."));
}

URL proxyUrl = validateProxyUrl(tokenProxyUrl);

byte[] caCertBytes = null;
if(!CoreUtils.isNullOrEmpty(caData)) {
try {
caCertBytes = caData.getBytes(StandardCharsets.UTF_8);
} catch (Exception e) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException(
"Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e));
}
}

return new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes);
}

private static URL validateProxyUrl(String endpoint) {
if (CoreUtils.isNullOrEmpty(endpoint)) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty"));
}

try {
URI tokenProxy = new URI(endpoint);

if (!"https".equals(tokenProxy.getScheme())) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException(
"Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme()));
}

if (tokenProxy.getRawUserInfo() != null) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException(
"Custom token endpoint URL must not contain user info: " + endpoint));
}

if (tokenProxy.getRawQuery() != null) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException(
"Custom token endpoint URL must not contain a query: " + endpoint));
}

if (tokenProxy.getRawFragment() != null) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException(
"Custom token endpoint URL must not contain a fragment: " + endpoint));
}

if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) {
tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(),
tokenProxy.getPort(), "/", null, null);
}

return tokenProxy.toURL();

} catch (URISyntaxException | IllegalArgumentException e) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e));
} catch (Exception e) {
throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e);
}
}

// public static String getTokenProxyUrl() {
// String tokenProxyUrl = System.getenv(AZURE_KUBERNETES_TOKEN_PROXY);
// if (tokenProxyUrl == null || tokenProxyUrl.isEmpty()) {
// throw LOGGER.logExceptionAsError(new IllegalStateException(
// String.format("Environment variable '%s' is not set or is empty. It must be set to the URL of the"
// + " token proxy.", AZURE_KUBERNETES_TOKEN_PROXY)));
// }
// return tokenProxyUrl;
// }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package com.azure.identity.implementation.customtokenproxy;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.MalformedParametersException;
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MalformedParametersException (from java.lang.reflect) is unrelated here and never thrown; remove the import and the throws clause (or replace with MalformedURLException if actually needed) to prevent confusion and inaccurate exception signaling.

Suggested change
import java.lang.reflect.MalformedParametersException;

Copilot uses AI. Check for mistakes.

import java.lang.reflect.Proxy;
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import java.lang.reflect.Proxy is unused; remove it to keep imports clean.

Suggested change
import java.lang.reflect.Proxy;

Copilot uses AI. Check for mistakes.

import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Arrays;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

import com.azure.core.http.HttpClient;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import com.azure.core.util.Context;
import com.azure.core.util.CoreUtils;
import com.azure.identity.implementation.util.IdentitySslUtil;

import reactor.core.publisher.Mono;

public class CustomTokenProxyHttpClient implements HttpClient {

private final ProxyConfig proxyConfig;
private volatile SSLContext cachedSSLContext;
private volatile byte[] cachedFileContent;

public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) {
this.proxyConfig = proxyConfig;
}

@Override
public Mono<HttpResponse> send(HttpRequest request) {
return Mono.fromCallable(() -> sendSync(request, Context.NONE));
}

@Override
public HttpResponse sendSync(HttpRequest request, Context context) {
try {
HttpURLConnection connection = createConnection(request);
connection.connect();
return new CustomTokenProxyHttpResponse(request, connection);
} catch (IOException e) {
throw new RuntimeException("Failed to create connection to token proxy", e);
}
}

// private HttpURLConnection createConnection(HttpRequest request) throws IOException {
// URL updateProxyRequest = rewriteTokenRequestForProxy(request.getUrl());
// HttpsURLConnection connection = (HttpsURLConnection) updateProxyRequest.openConnection();
// try {
// SSLContext sslContext = getSSLContext();
// connection.setSSLSocketFactory(sslContext.getSocketFactory());

// if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) {
// SSLParameters sslParameters = connection.getSSLParameters();
// }
// }
// return connection;

// // connection.setRequestMethod(request.getMethod().toString());
// // connection.setDoOutput(true);
// // request.getHeaders().forEach((key, values) -> {
// // values.forEach(value -> connection.addRequestProperty(key, value));
// // });
// // return connection;
// }

private HttpURLConnection createConnection(HttpRequest request) throws IOException {
URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl());
HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection();

// If SNI explicitly provided
try {
SSLContext sslContext = getSSLContext();
SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) {
sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName());
}
connection.setSSLSocketFactory(sslSocketFactory);
} catch (Exception e) {
throw new RuntimeException("Failed to set up SSL context for token proxy", e);
}

connection.setRequestMethod(request.getHttpMethod().toString());
// connection.setConnectTimeout(10_000);
// connection.setReadTimeout(20_000);
connection.setDoOutput(true);
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setDoOutput(true) is applied unconditionally; restrict this to methods that actually send a body (e.g., POST/PUT/PATCH) to avoid unintended semantics on GET/DELETE requests.

Suggested change
connection.setDoOutput(true);
// Only set doOutput for methods that support a body
String method = request.getHttpMethod().toString();
if ("POST".equalsIgnoreCase(method) || "PUT".equalsIgnoreCase(method) || "PATCH".equalsIgnoreCase(method)) {
connection.setDoOutput(true);
}

Copilot uses AI. Check for mistakes.


request.getHeaders().forEach(header -> {
connection.addRequestProperty(header.getName(), header.getValue());
});

if (request.getBodyAsBinaryData() != null) {
byte[] bytes = request.getBodyAsBinaryData().toBytes();
if (bytes != null && bytes.length > 0) {
connection.getOutputStream().write(bytes);
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output stream is not closed or flushed, leading to a potential resource leak; wrap the write in a try-with-resources: try (OutputStream os = connection.getOutputStream()) { os.write(bytes); }.

Suggested change
connection.getOutputStream().write(bytes);
try (java.io.OutputStream os = connection.getOutputStream()) {
os.write(bytes);
}

Copilot uses AI. Check for mistakes.

}
}

return connection;
}


private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedParametersException{
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MalformedParametersException (from java.lang.reflect) is unrelated here and never thrown; remove the import and the throws clause (or replace with MalformedURLException if actually needed) to prevent confusion and inaccurate exception signaling.

Copilot uses AI. Check for mistakes.

try {
String originalPath = originalUrl.getPath();
String originalQuery = originalUrl.getQuery();

String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString();
if(!tokenProxyBase.endsWith("/")) tokenProxyBase += "/";

URI combined = URI.create(tokenProxyBase).resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath);

String combinedStr = combined.toString();
if (originalQuery != null && !originalQuery.isEmpty()) {
combinedStr += "?" + originalQuery;
}

return new URL(combinedStr);

} catch (Exception e) {
throw new RuntimeException("Failed to rewrite token request for proxy", e);
}
}

private SSLContext getSSLContext() {
try {
// If no CA override provide, use default
Copy link

Copilot AI Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected wording from 'provide' to 'provided'.

Suggested change
// If no CA override provide, use default
// If no CA override provided, use default

Copilot uses AI. Check for mistakes.

if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())
&& (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) {
synchronized (this) {
if (cachedSSLContext == null) {
cachedSSLContext = SSLContext.getDefault();
}
}
return cachedSSLContext;
}

// If CA data provided, use it
if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())) {
synchronized (this) {
if (cachedSSLContext == null) {
cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData());
}
}
return cachedSSLContext;
}

// If CA file provided, read it (and re-read if it changes)
Path path = Paths.get(proxyConfig.getCaFile());
if (!Files.exists(path)) {
throw new IOException("CA file not found: " + proxyConfig.getCaFile());
}

byte[] currentContent;

synchronized (this) {
currentContent = Files.readAllBytes(path);
if (currentContent.length == 0) {
throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty");
}

if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) {
cachedSSLContext = createSslContextFromBytes(currentContent);
cachedFileContent = currentContent;
}
}

return cachedSSLContext;

} catch (Exception e) {
throw new RuntimeException("Failed to create default SSLContext", e);
}
}

// Create SSLContext from byte array containing PEM certificate data
private SSLContext createSslContextFromBytes(byte[] certificateData) {
try (InputStream inputStream = new ByteArrayInputStream(certificateData)) {
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate caCert = (X509Certificate) cf.generateCertificate(inputStream);
return createSslContext(caCert);
} catch (Exception e) {
throw new RuntimeException("Failed to create SSLContext from bytes", e);
}
}

// Create SSLContext from a single X509Certificate
private SSLContext createSslContext(X509Certificate caCert) {
try {
KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType());
keystore.load(null, null);
keystore.setCertificateEntry("ca-cert", caCert);

TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(keystore);

SSLContext context = SSLContext.getInstance("TLS");
context.init(null, tmf.getTrustManagers(), null);
return context;
} catch (Exception e) {
throw new RuntimeException("Failed to create SSLContext", e);
}
}

}
Loading
Loading