Skip to content

Commit 135d23b

Browse files
committed
Break out TLS factory
Since it is used by the CLI manager as well.
1 parent 5aebc63 commit 135d23b

File tree

4 files changed

+225
-219
lines changed

4 files changed

+225
-219
lines changed

src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package com.coder.gateway.sdk
22

33
import com.coder.gateway.services.CoderSettings
44
import com.coder.gateway.services.CoderSettingsState
5+
import com.coder.gateway.util.CoderHostnameVerifier
56
import com.coder.gateway.util.InvalidVersionException
67
import com.coder.gateway.util.SemVer
78
import com.coder.gateway.util.OS
9+
import com.coder.gateway.util.coderSocketFactory
810
import com.coder.gateway.util.escape
911
import com.coder.gateway.util.getOS
1012
import com.coder.gateway.util.safeHost

src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt

Lines changed: 3 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -16,54 +16,27 @@ import com.coder.gateway.sdk.v2.models.WorkspaceTransition
1616
import com.coder.gateway.sdk.v2.models.toAgentModels
1717
import com.coder.gateway.services.CoderSettings
1818
import com.coder.gateway.services.CoderSettingsService
19-
import com.coder.gateway.services.CoderTLSSettings
20-
import com.coder.gateway.util.OS
21-
import com.coder.gateway.util.expand
22-
import com.coder.gateway.util.getOS
19+
import com.coder.gateway.util.CoderHostnameVerifier
20+
import com.coder.gateway.util.coderSocketFactory
21+
import com.coder.gateway.util.coderTrustManagers
2322
import com.google.gson.Gson
2423
import com.google.gson.GsonBuilder
2524
import com.intellij.ide.plugins.PluginManagerCore
2625
import com.intellij.openapi.components.Service
2726
import com.intellij.openapi.components.service
28-
import com.intellij.openapi.diagnostic.Logger
2927
import com.intellij.openapi.extensions.PluginId
3028
import com.intellij.openapi.util.SystemInfo
3129
import com.intellij.util.net.HttpConfigurable
3230
import okhttp3.Credentials
3331
import okhttp3.OkHttpClient
34-
import okhttp3.internal.tls.OkHostnameVerifier
3532
import okhttp3.logging.HttpLoggingInterceptor
36-
import org.zeroturnaround.exec.ProcessExecutor
3733
import retrofit2.Retrofit
3834
import retrofit2.converter.gson.GsonConverterFactory
39-
import java.io.File
40-
import java.io.FileInputStream
4135
import java.net.HttpURLConnection.HTTP_CREATED
42-
import java.net.InetAddress
4336
import java.net.ProxySelector
44-
import java.net.Socket
4537
import java.net.URL
46-
import java.security.KeyFactory
47-
import java.security.KeyStore
48-
import java.security.cert.CertificateException
49-
import java.security.cert.CertificateFactory
50-
import java.security.cert.X509Certificate
51-
import java.security.spec.InvalidKeySpecException
52-
import java.security.spec.PKCS8EncodedKeySpec
5338
import java.time.Instant
54-
import java.util.Base64
55-
import java.util.Locale
5639
import java.util.UUID
57-
import javax.net.ssl.HostnameVerifier
58-
import javax.net.ssl.KeyManager
59-
import javax.net.ssl.KeyManagerFactory
60-
import javax.net.ssl.SNIHostName
61-
import javax.net.ssl.SSLContext
62-
import javax.net.ssl.SSLSession
63-
import javax.net.ssl.SSLSocket
64-
import javax.net.ssl.SSLSocketFactory
65-
import javax.net.ssl.TrustManagerFactory
66-
import javax.net.ssl.TrustManager
6740
import javax.net.ssl.X509TrustManager
6841

6942
@Service(Service.Level.APP)
@@ -297,191 +270,3 @@ open class CoderRestClient @JvmOverloads constructor(
297270
}
298271
}
299272
}
300-
301-
fun SSLContextFromPEMs(certPath: String, keyPath: String, caPath: String) : SSLContext {
302-
var km: Array<KeyManager>? = null
303-
if (certPath.isNotBlank() && keyPath.isNotBlank()) {
304-
val certificateFactory = CertificateFactory.getInstance("X.509")
305-
val certInputStream = FileInputStream(expand(certPath))
306-
val certChain = certificateFactory.generateCertificates(certInputStream)
307-
certInputStream.close()
308-
309-
// ideally we would use something like PemReader from BouncyCastle, but
310-
// BC is used by the IDE. This makes using BC very impractical since
311-
// type casting will mismatch due to the different class loaders.
312-
val privateKeyPem = File(expand(keyPath)).readText()
313-
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----")
314-
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start)
315-
val pemBytes: ByteArray = Base64.getDecoder().decode(
316-
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end)
317-
.replace("\\s+".toRegex(), "")
318-
)
319-
320-
val privateKey = try {
321-
val kf = KeyFactory.getInstance("RSA")
322-
val keySpec = PKCS8EncodedKeySpec(pemBytes)
323-
kf.generatePrivate(keySpec)
324-
} catch (e: InvalidKeySpecException) {
325-
val kf = KeyFactory.getInstance("EC")
326-
val keySpec = PKCS8EncodedKeySpec(pemBytes)
327-
kf.generatePrivate(keySpec)
328-
}
329-
330-
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType())
331-
keyStore.load(null)
332-
certChain.withIndex().forEach {
333-
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
334-
}
335-
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray())
336-
337-
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
338-
keyManagerFactory.init(keyStore, null)
339-
km = keyManagerFactory.keyManagers
340-
}
341-
342-
val sslContext = SSLContext.getInstance("TLS")
343-
344-
val trustManagers = coderTrustManagers(caPath)
345-
sslContext.init(km, trustManagers, null)
346-
return sslContext
347-
}
348-
349-
fun coderSocketFactory(settings: CoderTLSSettings) : SSLSocketFactory {
350-
val sslContext = SSLContextFromPEMs(settings.certPath, settings.keyPath, settings.caPath)
351-
if (settings.altHostname.isBlank()) {
352-
return sslContext.socketFactory
353-
}
354-
355-
return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.altHostname)
356-
}
357-
358-
fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> {
359-
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
360-
if (tlsCAPath.isBlank()) {
361-
// return default trust managers
362-
trustManagerFactory.init(null as KeyStore?)
363-
return trustManagerFactory.trustManagers
364-
}
365-
366-
367-
val certificateFactory = CertificateFactory.getInstance("X.509")
368-
val caInputStream = FileInputStream(expand(tlsCAPath))
369-
val certChain = certificateFactory.generateCertificates(caInputStream)
370-
371-
val truststore = KeyStore.getInstance(KeyStore.getDefaultType())
372-
truststore.load(null)
373-
certChain.withIndex().forEach {
374-
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
375-
}
376-
trustManagerFactory.init(truststore)
377-
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray()
378-
}
379-
380-
class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() {
381-
override fun getDefaultCipherSuites(): Array<String> {
382-
return delegate.defaultCipherSuites
383-
}
384-
385-
override fun getSupportedCipherSuites(): Array<String> {
386-
return delegate.supportedCipherSuites
387-
}
388-
389-
override fun createSocket(): Socket {
390-
val socket = delegate.createSocket() as SSLSocket
391-
customizeSocket(socket)
392-
return socket
393-
}
394-
395-
override fun createSocket(host: String?, port: Int): Socket {
396-
val socket = delegate.createSocket(host, port) as SSLSocket
397-
customizeSocket(socket)
398-
return socket
399-
}
400-
401-
override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket {
402-
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
403-
customizeSocket(socket)
404-
return socket
405-
}
406-
407-
override fun createSocket(host: InetAddress?, port: Int): Socket {
408-
val socket = delegate.createSocket(host, port) as SSLSocket
409-
customizeSocket(socket)
410-
return socket
411-
}
412-
413-
override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket {
414-
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
415-
customizeSocket(socket)
416-
return socket
417-
}
418-
419-
override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket {
420-
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
421-
customizeSocket(socket)
422-
return socket
423-
}
424-
425-
private fun customizeSocket(socket: SSLSocket) {
426-
val params = socket.sslParameters
427-
params.serverNames = listOf(SNIHostName(alternateName))
428-
socket.sslParameters = params
429-
}
430-
}
431-
432-
class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier {
433-
val logger = Logger.getInstance(CoderRestClientService::class.java.simpleName)
434-
override fun verify(host: String, session: SSLSession): Boolean {
435-
if (alternateName.isEmpty()) {
436-
return OkHostnameVerifier.verify(host, session)
437-
}
438-
val certs = session.peerCertificates ?: return false
439-
for (cert in certs) {
440-
if (cert !is X509Certificate) {
441-
continue
442-
}
443-
val entries = cert.subjectAlternativeNames ?: continue
444-
for (entry in entries) {
445-
val kind = entry[0] as Int
446-
if (kind != 2) { // DNS Name
447-
continue
448-
}
449-
val hostname = entry[1] as String
450-
logger.debug("Found cert hostname: $hostname")
451-
if (hostname.lowercase(Locale.getDefault()) == alternateName) {
452-
return true
453-
}
454-
}
455-
}
456-
return false
457-
}
458-
}
459-
460-
class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager {
461-
private val systemTrustManager : X509TrustManager
462-
init {
463-
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
464-
trustManagerFactory.init(null as KeyStore?)
465-
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
466-
}
467-
468-
override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) {
469-
try {
470-
otherTrustManager.checkClientTrusted(chain, authType)
471-
} catch (e: CertificateException) {
472-
systemTrustManager.checkClientTrusted(chain, authType)
473-
}
474-
}
475-
476-
override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) {
477-
try {
478-
otherTrustManager.checkServerTrusted(chain, authType)
479-
} catch (e: CertificateException) {
480-
systemTrustManager.checkServerTrusted(chain, authType)
481-
}
482-
}
483-
484-
override fun getAcceptedIssuers(): Array<X509Certificate> {
485-
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
486-
}
487-
}

0 commit comments

Comments
 (0)