From 7b156cbb5ba6704d839bd91af6dd59dc1b525ef0 Mon Sep 17 00:00:00 2001 From: Faur Ioan-Aurel Date: Wed, 24 Sep 2025 23:52:19 +0300 Subject: [PATCH] refactor: remove unnecessary SNI manipulation from SSL socket factory In our codebase we currently have two layers of custom logic: - one that alters the SNI in the ClientHello (via a custom SSLSocketFactory) - another that compares an alternate hostname against the SAN entries during client-side certificate verification. This work was done for one of Coder's clients that wants to do auth via certificates instead of API tokens. After recent discussions it turns out the SNI manipulation is not needed, we only need to do custom certificate validation. --- CHANGELOG.md | 4 + src/main/kotlin/com/coder/gateway/util/TLS.kt | 95 +------ .../util/AlternateNameSSLSocketFactoryTest.kt | 237 ------------------ 3 files changed, 8 insertions(+), 328 deletions(-) delete mode 100644 src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 177650e2..2bbf513e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ ## Unreleased +### Changed + +- simplified TLS configuration + ## 2.22.3 - 2025-09-19 ### Fixed diff --git a/src/main/kotlin/com/coder/gateway/util/TLS.kt b/src/main/kotlin/com/coder/gateway/util/TLS.kt index 7d945f53..2baa286d 100644 --- a/src/main/kotlin/com/coder/gateway/util/TLS.kt +++ b/src/main/kotlin/com/coder/gateway/util/TLS.kt @@ -5,10 +5,6 @@ import okhttp3.internal.tls.OkHostnameVerifier import org.slf4j.LoggerFactory import java.io.File import java.io.FileInputStream -import java.net.IDN -import java.net.InetAddress -import java.net.Socket -import java.nio.charset.StandardCharsets import java.security.KeyFactory import java.security.KeyStore import java.security.cert.CertificateException @@ -21,12 +17,9 @@ import java.util.Locale import javax.net.ssl.HostnameVerifier import javax.net.ssl.KeyManager import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.SNIServerName import javax.net.ssl.SSLContext import javax.net.ssl.SSLSession -import javax.net.ssl.SSLSocket import javax.net.ssl.SSLSocketFactory -import javax.net.ssl.StandardConstants import javax.net.ssl.TrustManager import javax.net.ssl.TrustManagerFactory import javax.net.ssl.X509TrustManager @@ -60,7 +53,7 @@ fun sslContextFromPEMs( val kf = KeyFactory.getInstance("RSA") val keySpec = PKCS8EncodedKeySpec(pemBytes) kf.generatePrivate(keySpec) - } catch (e: InvalidKeySpecException) { + } catch (_: InvalidKeySpecException) { val kf = KeyFactory.getInstance("EC") val keySpec = PKCS8EncodedKeySpec(pemBytes) kf.generatePrivate(keySpec) @@ -87,11 +80,7 @@ fun sslContextFromPEMs( fun coderSocketFactory(settings: CoderTLSSettings): SSLSocketFactory { val sslContext = sslContextFromPEMs(settings.certPath, settings.keyPath, settings.caPath) - if (settings.altHostname.isBlank()) { - return sslContext.socketFactory - } - - return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.altHostname) + return sslContext.socketFactory } fun coderTrustManagers(tlsCAPath: String): Array { @@ -115,82 +104,6 @@ fun coderTrustManagers(tlsCAPath: String): Array { return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray() } -class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : - SSLSocketFactory() { - override fun getDefaultCipherSuites(): Array = delegate.defaultCipherSuites - - override fun getSupportedCipherSuites(): Array = delegate.supportedCipherSuites - - override fun createSocket(): Socket { - val socket = delegate.createSocket() as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: String?, - port: Int, - ): Socket { - val socket = delegate.createSocket(host, port) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: String?, - port: Int, - localHost: InetAddress?, - localPort: Int, - ): Socket { - val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: InetAddress?, - port: Int, - ): Socket { - val socket = delegate.createSocket(host, port) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - address: InetAddress?, - port: Int, - localAddress: InetAddress?, - localPort: Int, - ): Socket { - val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - s: Socket?, - host: String?, - port: Int, - autoClose: Boolean, - ): Socket { - val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket - customizeSocket(socket) - return socket - } - - private fun customizeSocket(socket: SSLSocket) { - val params = socket.sslParameters - - params.serverNames = listOf(RelaxedSNIHostname(alternateName)) - socket.sslParameters = params - } -} - -private class RelaxedSNIHostname(hostname: String) : SNIServerName( - StandardConstants.SNI_HOST_NAME, - IDN.toASCII(hostname, 0).toByteArray(StandardCharsets.UTF_8) -) - class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier { private val logger = LoggerFactory.getLogger(javaClass) @@ -238,7 +151,7 @@ class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : ) { try { otherTrustManager.checkClientTrusted(chain, authType) - } catch (e: CertificateException) { + } catch (_: CertificateException) { systemTrustManager.checkClientTrusted(chain, authType) } } @@ -249,7 +162,7 @@ class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : ) { try { otherTrustManager.checkServerTrusted(chain, authType) - } catch (e: CertificateException) { + } catch (_: CertificateException) { systemTrustManager.checkServerTrusted(chain, authType) } } diff --git a/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt b/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt deleted file mode 100644 index e16b6c57..00000000 --- a/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt +++ /dev/null @@ -1,237 +0,0 @@ -package com.coder.gateway.util - -import io.mockk.Runs -import io.mockk.every -import io.mockk.just -import io.mockk.mockk -import io.mockk.verify -import java.net.InetAddress -import java.net.Socket -import javax.net.ssl.SSLParameters -import javax.net.ssl.SSLSocket -import javax.net.ssl.SSLSocketFactory -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertSame - - -class AlternateNameSSLSocketFactoryTest { - - @Test - fun `createSocket with no parameters should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket() - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with host and port should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket("original.com", 443) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket("original.com", 443) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with host port and local address should customize socket`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val localHost = mockk() - - every { mockFactory.createSocket("original.com", 443, localHost, 8080) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket("original.com", 443, localHost, 8080) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with InetAddress should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val address = mockk() - - every { mockFactory.createSocket(address, 443) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(address, 443) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with InetAddress and local address should customize socket`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val address = mockk() - val localAddress = mockk() - - every { mockFactory.createSocket(address, 443, localAddress, 8080) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(address, 443, localAddress, 8080) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with existing socket should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSSLSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val existingSocket = mockk() - - every { mockFactory.createSocket(existingSocket, "original.com", 443, true) } returns mockSSLSocket - every { mockSSLSocket.sslParameters } returns mockParams - every { mockSSLSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(existingSocket, "original.com", 443, true) - - // Then - verify { mockSSLSocket.sslParameters = any() } - assertSame(mockSSLSocket, result) - } - - @Test - fun `customizeSocket should set SNI hostname to alternate name for valid hostname`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "valid-hostname.example.com") - - // When & Then - This should work without throwing an exception - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - } - - @Test - fun `customizeSocket should NOT throw IllegalArgumentException for hostname with underscore`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "non_compliant_hostname.example.com") - - // When & Then - This should work without throwing an exception - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - assertEquals(0, mockSocket.sslParameters.serverNames.size) - } - - @Test - fun `createSocket should work with valid international domain names`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "test-server.example.com") - - // When & Then - This should work as hyphens are valid - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - } - - private fun createMockSSLSocketFactory(): SSLSocketFactory { - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - // Setup default behavior - every { mockFactory.defaultCipherSuites } returns arrayOf("TLS_AES_256_GCM_SHA384") - every { mockFactory.supportedCipherSuites } returns arrayOf("TLS_AES_256_GCM_SHA384", "TLS_AES_128_GCM_SHA256") - - // Make all createSocket methods return our mock socket - every { mockFactory.createSocket() } returns mockSocket - every { mockFactory.createSocket(any(), any()) } returns mockSocket - every { mockFactory.createSocket(any(), any(), any(), any()) } returns mockSocket - every { mockFactory.createSocket(any(), any()) } returns mockSocket - every { - mockFactory.createSocket( - any(), - any(), - any(), - any() - ) - } returns mockSocket - every { mockFactory.createSocket(any(), any(), any(), any()) } returns mockSocket - - // Setup SSL parameters - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - return mockFactory - } -} \ No newline at end of file