Skip to content

Commit e70c2cf

Browse files
authored
Added support for NIOSSLCustomVerificationCallback for client connection (#1107)
This allows client apps to perform SSL Public Key Pinning, or override the certificate verification logic
1 parent 6f92056 commit e70c2cf

File tree

5 files changed

+80
-8
lines changed

5 files changed

+80
-8
lines changed

Sources/GRPC/ClientConnection.swift

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ extension Channel {
411411
connectionIdleTimeout: TimeAmount,
412412
errorDelegate: ClientErrorDelegate?,
413413
requiresZeroLengthWriteWorkaround: Bool,
414-
logger: Logger
414+
logger: Logger,
415+
customVerificationCallback: NIOSSLCustomVerificationCallback?
415416
) -> EventLoopFuture<Void> {
416417
// We add at most 8 handlers to the pipeline.
417418
var handlers: [ChannelHandler] = []
@@ -427,11 +428,20 @@ extension Channel {
427428

428429
if let tlsConfiguration = tlsConfiguration {
429430
do {
430-
let sslClientHandler = try NIOSSLClientHandler(
431-
context: try NIOSSLContext(configuration: tlsConfiguration),
432-
serverHostname: tlsServerHostname
433-
)
434-
handlers.append(sslClientHandler)
431+
if let customVerificationCallback = customVerificationCallback {
432+
let sslClientHandler = try NIOSSLClientHandler(
433+
context: try NIOSSLContext(configuration: tlsConfiguration),
434+
serverHostname: tlsServerHostname,
435+
customVerificationCallback: customVerificationCallback
436+
)
437+
handlers.append(sslClientHandler)
438+
} else {
439+
let sslClientHandler = try NIOSSLClientHandler(
440+
context: try NIOSSLContext(configuration: tlsConfiguration),
441+
serverHostname: tlsServerHostname
442+
)
443+
handlers.append(sslClientHandler)
444+
}
435445
handlers.append(TLSVerificationHandler(logger: logger))
436446
} catch {
437447
return self.eventLoop.makeFailedFuture(error)

Sources/GRPC/ConnectionManager.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,8 @@ extension ConnectionManager {
853853
group: self.eventLoop,
854854
hasTLS: self.configuration.tls != nil
855855
),
856-
logger: self.logger
856+
logger: self.logger,
857+
customVerificationCallback: self.configuration.tls?.customVerificationCallback
857858
)
858859

859860
// Run the debug initializer, if there is one.

Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,15 @@ extension ClientConnection.Builder.Secure {
235235
self.tls.certificateVerification = certificateVerification
236236
return self
237237
}
238+
239+
/// A custom verification callback that allows completely overriding the certificate verification logic.
240+
@discardableResult
241+
public func withTLSCustomVerificationCallback(
242+
_ callback: @escaping NIOSSLCustomVerificationCallback
243+
) -> Self {
244+
self.tls.customVerificationCallback = callback
245+
return self
246+
}
238247
}
239248

240249
extension ClientConnection.Builder {

Sources/GRPC/TLSConfiguration.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ extension ClientConnection.Configuration {
6868
}
6969
}
7070

71+
/// A custom verification callback that allows completely overriding the certificate verification logic for this connection.
72+
public var customVerificationCallback: NIOSSLCustomVerificationCallback?
73+
7174
/// TLS Configuration with suitable defaults for clients.
7275
///
7376
/// This is a wrapper around `NIOSSL.TLSConfiguration` to restrict input to values which comply
@@ -83,12 +86,15 @@ extension ClientConnection.Configuration {
8386
/// `.fullVerification`.
8487
/// - Parameter hostnameOverride: Value to use for TLS SNI extension; this must not be an IP
8588
/// address, defaults to `nil`.
89+
/// - Parameter customVerificationCallback: A callback to provide to override the certificate verification logic,
90+
/// defaults to `nil`.
8691
public init(
8792
certificateChain: [NIOSSLCertificateSource] = [],
8893
privateKey: NIOSSLPrivateKeySource? = nil,
8994
trustRoots: NIOSSLTrustRoots = .default,
9095
certificateVerification: CertificateVerification = .fullVerification,
91-
hostnameOverride: String? = nil
96+
hostnameOverride: String? = nil,
97+
customVerificationCallback: NIOSSLCustomVerificationCallback? = nil
9298
) {
9399
self.configuration = .forClient(
94100
minimumTLSVersion: .tlsv12,
@@ -99,6 +105,7 @@ extension ClientConnection.Configuration {
99105
applicationProtocols: GRPCApplicationProtocolIdentifier.client
100106
)
101107
self.hostnameOverride = hostnameOverride
108+
self.customVerificationCallback = customVerificationCallback
102109
}
103110

104111
/// Creates a TLS Configuration using the given `NIOSSL.TLSConfiguration`.

Tests/GRPCTests/ClientTLSFailureTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,49 @@ class ClientTLSFailureTests: GRPCTestCase {
180180
XCTFail("Expected NIOSSLExtraError.failedToValidateHostname")
181181
}
182182
}
183+
184+
func testClientConnectionFailsWhenCertificateValidationDenied() throws {
185+
let errorExpectation = self.expectation(description: "error")
186+
// 2 errors: one for the failed handshake, and another for failing the ready-channel promise
187+
// (because the handshake failed).
188+
errorExpectation.expectedFulfillmentCount = 2
189+
190+
let tlsConfiguration = ClientConnection.Configuration.TLS(
191+
certificateChain: [.certificate(SampleCertificate.client.certificate)],
192+
privateKey: .privateKey(SamplePrivateKey.client),
193+
trustRoots: .certificates([SampleCertificate.ca.certificate]),
194+
hostnameOverride: SampleCertificate.server.commonName,
195+
customVerificationCallback: { _, promise in
196+
// The certificate validation is forced to fail
197+
promise.fail(NIOSSLError.unableToValidateCertificate)
198+
}
199+
)
200+
201+
var configuration = self.makeClientConfiguration(tls: tlsConfiguration)
202+
let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
203+
configuration.errorDelegate = errorRecorder
204+
205+
let stateChangeDelegate = RecordingConnectivityDelegate()
206+
stateChangeDelegate.expectChanges(2) { changes in
207+
XCTAssertEqual(changes, [
208+
Change(from: .idle, to: .connecting),
209+
Change(from: .connecting, to: .shutdown),
210+
])
211+
}
212+
configuration.connectivityStateDelegate = stateChangeDelegate
213+
214+
// Start an RPC to trigger creating a channel.
215+
let echo = Echo_EchoClient(channel: ClientConnection(configuration: configuration))
216+
_ = echo.get(.with { $0.text = "foo" })
217+
218+
self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout)
219+
stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5))
220+
221+
if let nioSSLError = errorRecorder.errors.first as? NIOSSLError,
222+
case .handshakeFailed(.sslError) = nioSSLError {
223+
// Expected case.
224+
} else {
225+
XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)")
226+
}
227+
}
183228
}

0 commit comments

Comments
 (0)