Skip to content

Commit a6fc40c

Browse files
[Redis] Add TLS support (#84)
Adds TLS support, pointing to a RediStack fork while we wait for the PR to be merged upstream.
1 parent d1d2655 commit a6fc40c

File tree

5 files changed

+189
-120
lines changed

5 files changed

+189
-120
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ let package = Package(
2525
.package(url: "https://github.com/alchemy-swift/cron.git", from: "2.3.2"),
2626
.package(url: "https://github.com/alchemy-swift/pluralize", from: "1.0.1"),
2727
.package(url: "https://github.com/johnsundell/Plot.git", from: "0.8.0"),
28-
.package(url: "https://github.com/Mordil/RediStack.git", from: "1.0.0"),
28+
.package(url: "https://github.com/alchemy-swift/RediStack.git", branch: "ssl-support-1.2.0"),
2929
.package(url: "https://github.com/onevcat/Rainbow", .upToNextMajor(from: "4.0.0")),
3030
.package(url: "https://github.com/vadymmarkov/Fakery", from: "5.0.0"),
3131
],
Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,8 @@
11
import NIO
22
import RediStack
33

4-
/// RedisClient conformance. See `RedisClient` for docs.
5-
extension RedisClient: RediStack.RedisClient {
4+
extension RedisClient {
65

7-
// MARK: RediStack.RedisClient
8-
9-
public var eventLoop: EventLoop {
10-
Loop.current
11-
}
12-
13-
public func logging(to logger: Logger) -> RediStack.RedisClient {
14-
provider.getClient().logging(to: logger)
15-
}
16-
17-
public func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture<RESPValue> {
18-
provider.getClient()
19-
.send(command: command, with: arguments).hop(to: Loop.current)
20-
}
21-
22-
public func subscribe(
23-
to channels: [RedisChannelName],
24-
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
25-
onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?,
26-
onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler?
27-
) -> EventLoopFuture<Void> {
28-
provider.getClient()
29-
.subscribe(
30-
to: channels,
31-
messageReceiver: receiver,
32-
onSubscribe: subscribeHandler,
33-
onUnsubscribe: unsubscribeHandler
34-
)
35-
}
36-
37-
public func psubscribe(
38-
to patterns: [String],
39-
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
40-
onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?,
41-
onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler?
42-
) -> EventLoopFuture<Void> {
43-
provider.getClient()
44-
.psubscribe(
45-
to: patterns,
46-
messageReceiver: receiver,
47-
onSubscribe: subscribeHandler,
48-
onUnsubscribe: unsubscribeHandler
49-
)
50-
}
51-
52-
public func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture<Void> {
53-
provider.getClient().unsubscribe(from: channels)
54-
}
55-
56-
public func punsubscribe(from patterns: [String]) -> EventLoopFuture<Void> {
57-
provider.getClient().punsubscribe(from: patterns)
58-
}
59-
606
// MARK: - Alchemy sugar
617

628
/// Wrapper around sending commands to Redis.
@@ -113,17 +59,3 @@ extension RedisClient: RediStack.RedisClient {
11359
}
11460
}
11561
}
116-
117-
extension RedisConnection: RedisProvider {
118-
public func getClient() -> RediStack.RedisClient {
119-
self
120-
}
121-
122-
public func shutdown() throws {
123-
try close().wait()
124-
}
125-
126-
public func transaction<T>(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T {
127-
try await transaction(self)
128-
}
129-
}
Lines changed: 161 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import NIO
22
import NIOConcurrencyHelpers
3+
import NIOSSL
34
import RediStack
45

56
/// A client for interfacing with a Redis instance.
6-
public struct RedisClient: Service {
7+
public struct RedisClient: Service, RediStack.RedisClient {
78
public struct Identifier: ServiceIdentifier {
89
private let hashable: AnyHashable
910
public init(hashable: AnyHashable) { self.hashable = hashable }
@@ -20,15 +21,77 @@ public struct RedisClient: Service {
2021
try provider.shutdown()
2122
}
2223

24+
// MARK: RediStack.RedisClient
25+
26+
public var eventLoop: EventLoop {
27+
Loop.current
28+
}
29+
30+
public func logging(to logger: Logger) -> RediStack.RedisClient {
31+
provider.logging(to: logger)
32+
}
33+
34+
public func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture<RESPValue> {
35+
wrapError {
36+
try provider.getClient()
37+
.send(command: command, with: arguments).hop(to: Loop.current)
38+
}
39+
}
40+
41+
public func subscribe(
42+
to channels: [RedisChannelName],
43+
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
44+
onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?,
45+
onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler?
46+
) -> EventLoopFuture<Void> {
47+
wrapError {
48+
try provider.getClient()
49+
.subscribe(
50+
to: channels,
51+
messageReceiver: receiver,
52+
onSubscribe: subscribeHandler,
53+
onUnsubscribe: unsubscribeHandler
54+
)
55+
}
56+
}
57+
58+
public func psubscribe(
59+
to patterns: [String],
60+
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
61+
onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?,
62+
onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler?
63+
) -> EventLoopFuture<Void> {
64+
wrapError {
65+
try provider.getClient()
66+
.psubscribe(
67+
to: patterns,
68+
messageReceiver: receiver,
69+
onSubscribe: subscribeHandler,
70+
onUnsubscribe: unsubscribeHandler
71+
)
72+
}
73+
}
74+
75+
public func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture<Void> {
76+
wrapError { try provider.getClient().unsubscribe(from: channels) }
77+
}
78+
79+
public func punsubscribe(from patterns: [String]) -> EventLoopFuture<Void> {
80+
wrapError { try provider.getClient().punsubscribe(from: patterns) }
81+
}
82+
83+
// MARK: Creating
84+
2385
/// A single redis connection
2486
public static func connection(
2587
_ host: String,
2688
port: Int = 6379,
2789
password: String? = nil,
2890
database: Int? = nil,
29-
poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1)
91+
poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1),
92+
tlsConfiguration: TLSConfiguration? = nil
3093
) -> RedisClient {
31-
return .cluster(.ip(host: host, port: port), password: password, database: database, poolSize: poolSize)
94+
return .cluster(.ip(host: host, port: port), password: password, database: database, poolSize: poolSize, tlsConfiguration: tlsConfiguration)
3295
}
3396

3497
/// Convenience initializer for creating a redis client with the
@@ -48,29 +111,21 @@ public struct RedisClient: Service {
48111
_ sockets: Socket...,
49112
password: String? = nil,
50113
database: Int? = nil,
51-
poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1)
114+
poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1),
115+
tlsConfiguration: TLSConfiguration? = nil
52116
) -> RedisClient {
53117
return .configuration(
54118
RedisConnectionPool.Configuration(
55-
initialServerConnectionAddresses: sockets.map {
56-
do {
57-
switch $0 {
58-
case let .ip(host, port):
59-
return try .makeAddressResolvingHost(host, port: port)
60-
case let .unix(path):
61-
return try .init(unixDomainSocketPath: path)
62-
}
63-
} catch {
64-
fatalError("Error generating socket address from `Socket` \(self)!")
65-
}
66-
},
119+
initialServerConnectionAddresses: [],
67120
maximumConnectionCount: poolSize,
68121
connectionFactoryConfiguration: RedisConnectionPool.ConnectionFactoryConfiguration(
69122
connectionInitialDatabase: database,
70123
connectionPassword: password,
71-
connectionDefaultLogger: Log.logger
124+
connectionDefaultLogger: Log.logger,
125+
tlsConfiguration: tlsConfiguration
72126
)
73-
)
127+
),
128+
addresses: sockets
74129
)
75130
}
76131

@@ -83,44 +138,36 @@ public struct RedisClient: Service {
83138
public static func configuration(_ config: RedisConnectionPool.Configuration) -> RedisClient {
84139
return RedisClient(provider: ConnectionPool(config: config))
85140
}
86-
}
87-
88-
/// Under the hood provider for `Redis`. Used so either connection pools
89-
/// or connections can be injected into `Redis` for accessing redis.
90-
public protocol RedisProvider {
91-
/// Get a redis client for running commands.
92-
func getClient() -> RediStack.RedisClient
93141

94-
/// Shut down.
95-
func shutdown() throws
96-
97-
/// Runs a transaction on the redis client using a given closure.
98-
///
99-
/// - Parameter transaction: An asynchronous transaction to run on
100-
/// the connection.
101-
/// - Returns: The resulting value of the transaction.
102-
func transaction<T>(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T
142+
fileprivate static func configuration(_ config: RedisConnectionPool.Configuration, addresses: [Socket]) -> RedisClient {
143+
return RedisClient(provider: ConnectionPool(config: config, lazyAddresses: addresses))
144+
}
103145
}
104146

105147
/// A connection pool is a redis provider with a pool per `EventLoop`.
106-
private final class ConnectionPool: RedisProvider {
148+
private final class ConnectionPool: RedisProvider, RediStack.RedisClient {
107149
/// Map of `EventLoop` identifiers to respective connection pools.
108150
private var poolStorage: [ObjectIdentifier: RedisConnectionPool] = [:]
109151
private var poolLock = Lock()
152+
private var lazyAddresses: [Socket]?
153+
private var logger: Logger?
110154

111155
/// The configuration to create pools with.
112156
private var config: RedisConnectionPool.Configuration
113157

114-
init(config: RedisConnectionPool.Configuration) {
158+
init(config: RedisConnectionPool.Configuration, lazyAddresses: [Socket]? = nil) {
115159
self.config = config
160+
self.lazyAddresses = lazyAddresses
116161
}
162+
163+
// MARK: - RedisProvider
117164

118-
func getClient() -> RediStack.RedisClient {
119-
getPool()
165+
func getClient() throws -> RediStack.RedisClient {
166+
try getPool()
120167
}
121168

122169
func transaction<T>(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T {
123-
let pool = getPool()
170+
let pool = try getPool()
124171
return try await pool.leaseConnection { conn in
125172
pool.eventLoop.asyncSubmit { try await transaction(conn) }
126173
}.get()
@@ -140,17 +187,88 @@ private final class ConnectionPool: RedisProvider {
140187
///
141188
/// - Returns: A `RedisConnectionPool` associated with the current
142189
/// `EventLoop` for sending commands to.
143-
private func getPool() -> RedisConnectionPool {
190+
private func getPool() throws -> RedisConnectionPool {
144191
let loop = Loop.current
145192
let key = ObjectIdentifier(loop)
146-
return poolLock.withLock {
193+
return try poolLock.withLock {
147194
if let pool = self.poolStorage[key] {
148195
return pool
149196
} else {
150-
let newPool = RedisConnectionPool(configuration: self.config, boundEventLoop: loop)
197+
var config = self.config
198+
if let lazyAddresses = lazyAddresses {
199+
let initialAddresses: [SocketAddress] = try lazyAddresses.map {
200+
switch $0 {
201+
case let .ip(host, port):
202+
return try .makeAddressResolvingHost(host, port: port)
203+
case let .unix(path):
204+
return try .init(unixDomainSocketPath: path)
205+
}
206+
}
207+
208+
config = RedisConnectionPool.Configuration(
209+
initialServerConnectionAddresses: initialAddresses,
210+
maximumConnectionCount: config.maximumConnectionCount,
211+
connectionFactoryConfiguration: config.factoryConfiguration,
212+
minimumConnectionCount: config.minimumConnectionCount,
213+
connectionBackoffFactor: config.connectionRetryConfiguration.backoff.factor,
214+
initialConnectionBackoffDelay: config.connectionRetryConfiguration.backoff.initialDelay,
215+
connectionRetryTimeout: config.connectionRetryConfiguration.timeout,
216+
poolDefaultLogger: config.poolDefaultLogger)
217+
}
218+
219+
let newPool = RedisConnectionPool(configuration: config, boundEventLoop: loop)
151220
self.poolStorage[key] = newPool
152-
return newPool
221+
if let logger = logger {
222+
return newPool.logging(to: logger) as? RedisConnectionPool ?? newPool
223+
} else {
224+
return newPool
225+
}
153226
}
154227
}
155228
}
229+
230+
// MARK: RediStack.RedisClient
231+
232+
var eventLoop: EventLoop { Loop.current }
233+
234+
func logging(to logger: Logger) -> RediStack.RedisClient {
235+
self.logger = logger
236+
return self
237+
}
238+
239+
func punsubscribe(from patterns: [String]) -> EventLoopFuture<Void> {
240+
wrapError { try getClient().punsubscribe(from: patterns) }
241+
}
242+
243+
func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture<Void> {
244+
wrapError { try getClient().unsubscribe(from: channels) }
245+
}
246+
247+
func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture<RESPValue> {
248+
wrapError { try getClient().send(command: command, with: arguments) }
249+
}
250+
251+
private func wrapError<T>(_ closure: () throws -> EventLoopFuture<T>) -> EventLoopFuture<T> {
252+
do { return try closure() }
253+
catch { return Loop.current.makeFailedFuture(error) }
254+
}
255+
}
256+
257+
extension RedisConnection: RedisProvider {
258+
public func getClient() -> RediStack.RedisClient {
259+
self
260+
}
261+
262+
public func shutdown() throws {
263+
try close().wait()
264+
}
265+
266+
public func transaction<T>(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T {
267+
try await transaction(self)
268+
}
269+
}
270+
271+
private func wrapError<T>(_ closure: () throws -> EventLoopFuture<T>) -> EventLoopFuture<T> {
272+
do { return try closure() }
273+
catch { return Loop.current.makeFailedFuture(error) }
156274
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import RediStack
2+
3+
/// Under the hood provider for `Redis`. Used so either connection pools
4+
/// or connections can be injected into `Redis` for accessing redis.
5+
public protocol RedisProvider {
6+
/// Get a redis client for running commands.
7+
func getClient() throws -> RediStack.RedisClient
8+
9+
/// Log with the given logger.
10+
func logging(to logger: Logger) -> RediStack.RedisClient
11+
12+
/// Shut down.
13+
func shutdown() throws
14+
15+
/// Runs a transaction on the redis client using a given closure.
16+
///
17+
/// - Parameter transaction: An asynchronous transaction to run on
18+
/// the connection.
19+
/// - Returns: The resulting value of the transaction.
20+
func transaction<T>(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T
21+
}

0 commit comments

Comments
 (0)