diff --git a/CHANGELOG.md b/CHANGELOG.md index 177650e2..2bbf513e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ ## Unreleased +### Changed + +- simplified TLS configuration + ## 2.22.3 - 2025-09-19 ### Fixed diff --git a/src/main/kotlin/com/coder/gateway/util/TLS.kt b/src/main/kotlin/com/coder/gateway/util/TLS.kt index 7d945f53..2baa286d 100644 --- a/src/main/kotlin/com/coder/gateway/util/TLS.kt +++ b/src/main/kotlin/com/coder/gateway/util/TLS.kt @@ -5,10 +5,6 @@ import okhttp3.internal.tls.OkHostnameVerifier import org.slf4j.LoggerFactory import java.io.File import java.io.FileInputStream -import java.net.IDN -import java.net.InetAddress -import java.net.Socket -import java.nio.charset.StandardCharsets import java.security.KeyFactory import java.security.KeyStore import java.security.cert.CertificateException @@ -21,12 +17,9 @@ import java.util.Locale import javax.net.ssl.HostnameVerifier import javax.net.ssl.KeyManager import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.SNIServerName import javax.net.ssl.SSLContext import javax.net.ssl.SSLSession -import javax.net.ssl.SSLSocket import javax.net.ssl.SSLSocketFactory -import javax.net.ssl.StandardConstants import javax.net.ssl.TrustManager import javax.net.ssl.TrustManagerFactory import javax.net.ssl.X509TrustManager @@ -60,7 +53,7 @@ fun sslContextFromPEMs( val kf = KeyFactory.getInstance("RSA") val keySpec = PKCS8EncodedKeySpec(pemBytes) kf.generatePrivate(keySpec) - } catch (e: InvalidKeySpecException) { + } catch (_: InvalidKeySpecException) { val kf = KeyFactory.getInstance("EC") val keySpec = PKCS8EncodedKeySpec(pemBytes) kf.generatePrivate(keySpec) @@ -87,11 +80,7 @@ fun sslContextFromPEMs( fun coderSocketFactory(settings: CoderTLSSettings): SSLSocketFactory { val sslContext = sslContextFromPEMs(settings.certPath, settings.keyPath, settings.caPath) - if (settings.altHostname.isBlank()) { - return sslContext.socketFactory - } - - return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.altHostname) + return sslContext.socketFactory } fun coderTrustManagers(tlsCAPath: String): Array { @@ -115,82 +104,6 @@ fun coderTrustManagers(tlsCAPath: String): Array { return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray() } -class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : - SSLSocketFactory() { - override fun getDefaultCipherSuites(): Array = delegate.defaultCipherSuites - - override fun getSupportedCipherSuites(): Array = delegate.supportedCipherSuites - - override fun createSocket(): Socket { - val socket = delegate.createSocket() as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: String?, - port: Int, - ): Socket { - val socket = delegate.createSocket(host, port) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: String?, - port: Int, - localHost: InetAddress?, - localPort: Int, - ): Socket { - val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - host: InetAddress?, - port: Int, - ): Socket { - val socket = delegate.createSocket(host, port) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - address: InetAddress?, - port: Int, - localAddress: InetAddress?, - localPort: Int, - ): Socket { - val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket - customizeSocket(socket) - return socket - } - - override fun createSocket( - s: Socket?, - host: String?, - port: Int, - autoClose: Boolean, - ): Socket { - val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket - customizeSocket(socket) - return socket - } - - private fun customizeSocket(socket: SSLSocket) { - val params = socket.sslParameters - - params.serverNames = listOf(RelaxedSNIHostname(alternateName)) - socket.sslParameters = params - } -} - -private class RelaxedSNIHostname(hostname: String) : SNIServerName( - StandardConstants.SNI_HOST_NAME, - IDN.toASCII(hostname, 0).toByteArray(StandardCharsets.UTF_8) -) - class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier { private val logger = LoggerFactory.getLogger(javaClass) @@ -238,7 +151,7 @@ class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : ) { try { otherTrustManager.checkClientTrusted(chain, authType) - } catch (e: CertificateException) { + } catch (_: CertificateException) { systemTrustManager.checkClientTrusted(chain, authType) } } @@ -249,7 +162,7 @@ class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : ) { try { otherTrustManager.checkServerTrusted(chain, authType) - } catch (e: CertificateException) { + } catch (_: CertificateException) { systemTrustManager.checkServerTrusted(chain, authType) } } diff --git a/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt b/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt deleted file mode 100644 index e16b6c57..00000000 --- a/src/test/kotlin/com/coder/gateway/util/AlternateNameSSLSocketFactoryTest.kt +++ /dev/null @@ -1,237 +0,0 @@ -package com.coder.gateway.util - -import io.mockk.Runs -import io.mockk.every -import io.mockk.just -import io.mockk.mockk -import io.mockk.verify -import java.net.InetAddress -import java.net.Socket -import javax.net.ssl.SSLParameters -import javax.net.ssl.SSLSocket -import javax.net.ssl.SSLSocketFactory -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertSame - - -class AlternateNameSSLSocketFactoryTest { - - @Test - fun `createSocket with no parameters should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket() - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with host and port should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket("original.com", 443) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket("original.com", 443) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with host port and local address should customize socket`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val localHost = mockk() - - every { mockFactory.createSocket("original.com", 443, localHost, 8080) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket("original.com", 443, localHost, 8080) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with InetAddress should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val address = mockk() - - every { mockFactory.createSocket(address, 443) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(address, 443) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with InetAddress and local address should customize socket`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val address = mockk() - val localAddress = mockk() - - every { mockFactory.createSocket(address, 443, localAddress, 8080) } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(address, 443, localAddress, 8080) - - // Then - verify { mockSocket.sslParameters = any() } - assertSame(mockSocket, result) - } - - @Test - fun `createSocket with existing socket should customize socket with alternate name`() { - // Given - val mockFactory = mockk() - val mockSSLSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - val existingSocket = mockk() - - every { mockFactory.createSocket(existingSocket, "original.com", 443, true) } returns mockSSLSocket - every { mockSSLSocket.sslParameters } returns mockParams - every { mockSSLSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "alternate.example.com") - - // When - val result = alternateFactory.createSocket(existingSocket, "original.com", 443, true) - - // Then - verify { mockSSLSocket.sslParameters = any() } - assertSame(mockSSLSocket, result) - } - - @Test - fun `customizeSocket should set SNI hostname to alternate name for valid hostname`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "valid-hostname.example.com") - - // When & Then - This should work without throwing an exception - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - } - - @Test - fun `customizeSocket should NOT throw IllegalArgumentException for hostname with underscore`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "non_compliant_hostname.example.com") - - // When & Then - This should work without throwing an exception - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - assertEquals(0, mockSocket.sslParameters.serverNames.size) - } - - @Test - fun `createSocket should work with valid international domain names`() { - // Given - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - every { mockFactory.createSocket() } returns mockSocket - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - val alternateFactory = AlternateNameSSLSocketFactory(mockFactory, "test-server.example.com") - - // When & Then - This should work as hyphens are valid - assertNotNull(alternateFactory.createSocket()) - verify { mockSocket.sslParameters = any() } - } - - private fun createMockSSLSocketFactory(): SSLSocketFactory { - val mockFactory = mockk() - val mockSocket = mockk(relaxed = true) - val mockParams = mockk(relaxed = true) - - // Setup default behavior - every { mockFactory.defaultCipherSuites } returns arrayOf("TLS_AES_256_GCM_SHA384") - every { mockFactory.supportedCipherSuites } returns arrayOf("TLS_AES_256_GCM_SHA384", "TLS_AES_128_GCM_SHA256") - - // Make all createSocket methods return our mock socket - every { mockFactory.createSocket() } returns mockSocket - every { mockFactory.createSocket(any(), any()) } returns mockSocket - every { mockFactory.createSocket(any(), any(), any(), any()) } returns mockSocket - every { mockFactory.createSocket(any(), any()) } returns mockSocket - every { - mockFactory.createSocket( - any(), - any(), - any(), - any() - ) - } returns mockSocket - every { mockFactory.createSocket(any(), any(), any(), any()) } returns mockSocket - - // Setup SSL parameters - every { mockSocket.sslParameters } returns mockParams - every { mockSocket.sslParameters = any() } just Runs - - return mockFactory - } -} \ No newline at end of file