Skip to content

Commit 5160e36

Browse files
authored
Enable strict concurrency (#54)
### Motivation: Catch potential data races at build time. ### Modifications: - Enabled strict concurrency checking in the Package.swift. - Made a few types Sendable: `AresChannel`, `CAresDNSResolver`, `Ares`, `QueryProcessor`, `DNSSDDNSResolver`. ### Result: Fewer potential data races can sneak in. ### Test Plan Ran tests locally, did not see any concurrency warnings or errors, all tests passed.
1 parent 98d62d4 commit 5160e36

File tree

8 files changed

+83
-67
lines changed

8 files changed

+83
-67
lines changed

Package.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,9 @@ let package = Package(
5050
],
5151
cLanguageStandard: .gnu11
5252
)
53+
54+
for target in package.targets {
55+
var settings = target.swiftSettings ?? []
56+
settings.append(.enableExperimentalFeature("StrictConcurrency=complete"))
57+
target.swiftSettings = settings
58+
}

Sources/AsyncDNSResolver/c-ares/AresChannel.swift

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@ import Foundation
1818
// MARK: - ares_channel
1919

2020
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
21-
class AresChannel {
22-
let pointer: UnsafeMutablePointer<ares_channel?>
23-
let lock = NSLock()
21+
final class AresChannel: @unchecked Sendable {
22+
private let locked_pointer: UnsafeMutablePointer<ares_channel?>
23+
private let lock = NSLock()
2424

25-
private var underlying: ares_channel? {
26-
self.pointer.pointee
25+
// For testing only.
26+
var underlying: ares_channel? {
27+
self.locked_pointer.pointee
2728
}
2829

2930
deinit {
30-
ares_destroy(pointer.pointee)
31-
pointer.deallocate()
31+
// Safe to perform without the lock, as in deinit we know that no more
32+
// strong references to self exist, so nobody can be holding the lock.
33+
ares_destroy(locked_pointer.pointee)
34+
locked_pointer.deallocate()
3235
ares_library_cleanup()
3336
}
3437

@@ -49,7 +52,7 @@ class AresChannel {
4952
try checkAresResult { ares_set_sortlist(pointer.pointee, sortlist) }
5053
}
5154

52-
self.pointer = pointer
55+
self.locked_pointer = pointer
5356
}
5457

5558
func withChannel(_ body: (ares_channel) -> Void) {

Sources/AsyncDNSResolver/c-ares/AresOptions.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import CAsyncDNSResolver
1919
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
2020
extension CAresDNSResolver {
2121
/// Options for ``CAresDNSResolver``.
22-
public struct Options {
22+
public struct Options: Sendable {
2323
public static var `default`: Options {
2424
.init()
2525
}
@@ -91,7 +91,7 @@ extension CAresDNSResolver {
9191

9292
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
9393
extension CAresDNSResolver.Options {
94-
public struct Flags: OptionSet {
94+
public struct Flags: OptionSet, Sendable {
9595
public let rawValue: Int32
9696

9797
public init(rawValue: Int32) {

Sources/AsyncDNSResolver/c-ares/DNSResolver_c-ares.swift

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
import CAsyncDNSResolver
16+
import Foundation
1617

1718
/// ``DNSResolver`` implementation backed by c-ares C library.
1819
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
19-
public class CAresDNSResolver: DNSResolver {
20+
public final class CAresDNSResolver: DNSResolver, Sendable {
2021
let options: Options
2122
let ares: Ares
2223

@@ -121,18 +122,15 @@ extension QueryType {
121122
// MARK: - c-ares query wrapper
122123

123124
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
124-
class Ares {
125+
final class Ares: Sendable {
125126
typealias QueryCallback = @convention(c) (
126127
UnsafeMutableRawPointer?, CInt, CInt, UnsafeMutablePointer<CUnsignedChar>?, CInt
127128
) -> Void
128129

129-
let options: AresOptions
130-
let channel: AresChannel
131-
130+
private let channel: AresChannel
132131
private let queryProcessor: QueryProcessor
133132

134133
init(options: AresOptions) throws {
135-
self.options = options
136134
self.channel = try AresChannel(options: options)
137135

138136
// Need to call `ares_process` or `ares_process_fd` for query callbacks to happen
@@ -145,7 +143,8 @@ class Ares {
145143
name: String,
146144
replyParser: ReplyParser
147145
) async throws -> ReplyParser.Reply {
148-
try await withTaskCancellationHandler(
146+
let channel = self.channel
147+
return try await withTaskCancellationHandler(
149148
operation: {
150149
try await withCheckedThrowingContinuation { continuation in
151150
let handler = QueryReplyHandler(parser: replyParser, continuation)
@@ -178,7 +177,7 @@ class Ares {
178177
}
179178
},
180179
onCancel: {
181-
self.channel.withChannel { channel in
180+
channel.withChannel { channel in
182181
ares_cancel(channel)
183182
}
184183
}
@@ -198,16 +197,18 @@ extension Ares {
198197
// https://github.com/dimbleby/c-ares-resolver/blob/master/src/unix/eventloop.rs // ignore-unacceptable-language
199198
// https://github.com/dimbleby/rust-c-ares/blob/master/src/channel.rs // ignore-unacceptable-language
200199
// https://github.com/dimbleby/rust-c-ares/blob/master/examples/event-loop.rs // ignore-unacceptable-language
201-
class QueryProcessor {
200+
final class QueryProcessor: @unchecked Sendable {
202201
static let defaultPollInterval: UInt64 = 10 * 1_000_000 // 10ms
203202

204203
private let channel: AresChannel
205204
private let pollIntervalNanos: UInt64
206205

207-
private var pollingTask: Task<Void, Error>?
206+
private let lock = NSLock()
207+
private var locked_pollingTask: Task<Void, Error>?
208208

209209
deinit {
210-
self.pollingTask?.cancel()
210+
// No need to lock here as there can exist no more strong references to self.
211+
self.locked_pollingTask?.cancel()
211212
}
212213

213214
init(channel: AresChannel, pollIntervalNanos: UInt64 = QueryProcessor.defaultPollInterval) {
@@ -218,7 +219,7 @@ extension Ares {
218219
/// Asks c-ares for the set of socket descriptors we are waiting on for the `ares_channel`'s pending queries
219220
/// then call `ares_process_fd` if any is ready for read and/or write.
220221
/// c-ares returns up to `ARES_GETSOCK_MAXNUM` socket descriptors only. If more are in use (unlikely) they are not reported back.
221-
func poll() async {
222+
func poll() {
222223
var socks = [ares_socket_t](repeating: ares_socket_t(), count: Int(ARES_GETSOCK_MAXNUM))
223224

224225
self.channel.withChannel { channel in
@@ -249,12 +250,14 @@ extension Ares {
249250
}
250251

251252
private func schedule() {
252-
self.pollingTask = Task { [weak self] in
253+
self.lock.lock()
254+
defer { self.lock.unlock() }
255+
self.locked_pollingTask = Task { [weak self] in
253256
guard let s = self else {
254257
return
255258
}
256259
try await Task.sleep(nanoseconds: s.pollIntervalNanos)
257-
await s.poll()
260+
s.poll()
258261
}
259262
}
260263
}
@@ -291,7 +294,7 @@ extension Ares {
291294
// MARK: - c-ares query reply parsers
292295

293296
protocol AresQueryReplyParser {
294-
associatedtype Reply
297+
associatedtype Reply: Sendable
295298

296299
func parse(buffer: UnsafeMutablePointer<CUnsignedChar>?, length: CInt) throws -> Reply
297300
}

Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import dnssd
1717

1818
/// ``DNSResolver`` implementation backed by dnssd framework.
1919
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
20-
public struct DNSSDDNSResolver: DNSResolver {
20+
public struct DNSSDDNSResolver: DNSResolver, Sendable {
2121
let dnssd: DNSSD
2222

2323
init() {
@@ -100,7 +100,7 @@ extension QueryType {
100100
// MARK: - dnssd query wrapper
101101

102102
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
103-
struct DNSSD {
103+
struct DNSSD: Sendable {
104104
// Reference: https://gist.github.com/fikeminkel/a9c4bc4d0348527e8df3690e242038d3
105105
func query<ReplyHandler: DNSSDQueryReplyHandler>(
106106
type: QueryType,
@@ -225,7 +225,7 @@ extension DNSSD {
225225
// MARK: - dnssd query reply handlers
226226

227227
protocol DNSSDQueryReplyHandler {
228-
associatedtype Record
228+
associatedtype Record: Sendable
229229
associatedtype Reply
230230

231231
func parseRecord(data: UnsafeRawPointer?, length: UInt16) throws -> Record?

Tests/AsyncDNSResolverTests/c-ares/AresChannelTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ final class AresChannelTests: XCTestCase {
2929
guard let channel = try? AresChannel(options: options) else {
3030
return XCTFail("Channel not initialized")
3131
}
32-
guard let _ = channel.pointer.pointee else {
32+
guard let _ = channel.underlying else {
3333
return XCTFail("Underlying ares_channel is nil")
3434
}
3535
}

Tests/AsyncDNSResolverTests/c-ares/CAresDNSResolverTests.swift

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ final class CAresDNSResolverTests: XCTestCase {
124124
func test_concurrency() async throws {
125125
func run(
126126
times: Int = 100,
127-
_ query: @escaping (_ index: Int) async throws -> Void
127+
_ query: @Sendable @escaping (_ index: Int) async throws -> Void
128128
) async throws {
129129
try await withThrowingTaskGroup(of: Void.self) { group in
130130
for i in 1...times {
@@ -136,65 +136,67 @@ final class CAresDNSResolverTests: XCTestCase {
136136
}
137137
}
138138

139+
let resolver = self.resolver!
140+
let verbose = self.verbose
139141
try await run { i in
140-
let reply = try await self.resolver.queryA(name: "apple.com")
141-
if self.verbose {
142+
let reply = try await resolver.queryA(name: "apple.com")
143+
if verbose {
142144
print("[A] run #\(i) result: \(reply)")
143145
}
144146
}
145147

146148
try await run { i in
147-
let reply = try await self.resolver.queryAAAA(name: "apple.com")
148-
if self.verbose {
149+
let reply = try await resolver.queryAAAA(name: "apple.com")
150+
if verbose {
149151
print("[AAAA] run #\(i) result: \(reply)")
150152
}
151153
}
152154

153155
try await run { i in
154-
let reply = try await self.resolver.queryNS(name: "apple.com")
155-
if self.verbose {
156+
let reply = try await resolver.queryNS(name: "apple.com")
157+
if verbose {
156158
print("[NS] run #\(i) result: \(reply)")
157159
}
158160
}
159161

160162
try await run { i in
161-
let reply = try await self.resolver.queryCNAME(name: "www.apple.com")
162-
if self.verbose {
163+
let reply = try await resolver.queryCNAME(name: "www.apple.com")
164+
if verbose {
163165
print("[CNAME] run #\(i) result: \(String(describing: reply))")
164166
}
165167
}
166168

167169
try await run { i in
168-
let reply = try await self.resolver.querySOA(name: "apple.com")
169-
if self.verbose {
170+
let reply = try await resolver.querySOA(name: "apple.com")
171+
if verbose {
170172
print("[SOA] run #\(i) result: \(String(describing: reply))")
171173
}
172174
}
173175

174176
try await run { i in
175-
let reply = try await self.resolver.queryPTR(name: "47.224.172.17.in-addr.arpa")
176-
if self.verbose {
177+
let reply = try await resolver.queryPTR(name: "47.224.172.17.in-addr.arpa")
178+
if verbose {
177179
print("[PTR] run #\(i) result: \(reply)")
178180
}
179181
}
180182

181183
try await run { i in
182-
let reply = try await self.resolver.queryMX(name: "apple.com")
183-
if self.verbose {
184+
let reply = try await resolver.queryMX(name: "apple.com")
185+
if verbose {
184186
print("[MX] run #\(i) result: \(reply)")
185187
}
186188
}
187189

188190
try await run { i in
189-
let reply = try await self.resolver.queryTXT(name: "apple.com")
190-
if self.verbose {
191+
let reply = try await resolver.queryTXT(name: "apple.com")
192+
if verbose {
191193
print("[TXT] run #\(i) result: \(reply)")
192194
}
193195
}
194196

195197
try await run { i in
196-
let reply = try await self.resolver.querySRV(name: "_caldavs._tcp.google.com")
197-
if self.verbose {
198+
let reply = try await resolver.querySRV(name: "_caldavs._tcp.google.com")
199+
if verbose {
198200
print("[SRV] run #\(i) result: \(reply)")
199201
}
200202
}

0 commit comments

Comments
 (0)