Skip to content

Commit b7079b7

Browse files
authored
Fix continuation memory leak in Ares.query (#31)
* Fix continuation memory leak in Ares.query * Use class for QueryReplyHandler * Deallocate QueryReplyHandler for DNSSD * Move defer block after allocate/initialize, use class Move defer deallocation block to after initialization. Use class instead of struct for DNSSD.QueryReplyHandler.
1 parent d9afa74 commit b7079b7

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

Sources/AsyncDNSResolver/c-ares/DNSResolver_c-ares.swift

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,12 @@ class Ares {
160160
preconditionFailure("'arg' is nil. This is a bug.")
161161
}
162162

163-
let handler = QueryReplyHandler(pointer: handlerPointer)
164-
defer { handlerPointer.deallocate() }
163+
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
164+
let handler = pointer.pointee
165+
defer {
166+
pointer.deinitialize(count: 1)
167+
pointer.deallocate()
168+
}
165169

166170
handler.handle(status: status, buffer: buf, length: len)
167171
}
@@ -258,7 +262,7 @@ extension Ares {
258262

259263
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
260264
extension Ares {
261-
struct QueryReplyHandler {
265+
class QueryReplyHandler {
262266
private let _handler: (CInt, UnsafeMutablePointer<CUnsignedChar>?, CInt) -> Void
263267

264268
init<Parser: AresQueryReplyParser>(parser: Parser, _ continuation: CheckedContinuation<Parser.Reply, Error>) {
@@ -276,11 +280,6 @@ extension Ares {
276280
}
277281
}
278282

279-
init(pointer: UnsafeMutableRawPointer) {
280-
let handlerPointer = pointer.assumingMemoryBound(to: Self.self)
281-
self = handlerPointer.pointee
282-
}
283-
284283
func handle(status: CInt, buffer: UnsafeMutablePointer<CUnsignedChar>?, length: CInt) {
285284
self._handler(status, buffer, length)
286285
}

Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,25 @@ struct DNSSD {
115115
byteCount: MemoryLayout<QueryReplyHandler>.stride,
116116
alignment: MemoryLayout<QueryReplyHandler>.alignment
117117
)
118-
// The handler might be called multiple times so don't deallocate inside `callback`
119-
defer { handlerPointer.deallocate() }
120118

121119
handlerPointer.initializeMemory(as: QueryReplyHandler.self, repeating: handler, count: 1)
122120

121+
// The handler might be called multiple times so don't deallocate inside `callback`
122+
defer {
123+
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
124+
pointer.deinitialize(count: 1)
125+
pointer.deallocate()
126+
}
127+
123128
// This is called once per record received
124129
let callback: DNSServiceQueryRecordReply = { _, _, _, errorCode, _, _, _, rdlen, rdata, _, context in
125130
guard let handlerPointer = context else {
126131
preconditionFailure("'context' is nil. This is a bug.")
127132
}
128133

129-
let handler = QueryReplyHandler(pointer: handlerPointer)
134+
let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self)
135+
let handler = pointer.pointee
136+
130137
// This parses a record then adds it to the stream
131138
handler.handleRecord(errorCode: errorCode, data: rdata, length: rdlen)
132139
}
@@ -171,7 +178,7 @@ struct DNSSD {
171178

172179
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
173180
extension DNSSD {
174-
struct QueryReplyHandler {
181+
class QueryReplyHandler {
175182
private let _handleRecord: (DNSServiceErrorType, UnsafeRawPointer?, UInt16) -> Void
176183

177184
init<Handler: DNSSDQueryReplyHandler>(handler: Handler, _ continuation: AsyncThrowingStream<Handler.Record, Error>.Continuation) {
@@ -189,11 +196,6 @@ extension DNSSD {
189196
}
190197
}
191198

192-
init(pointer: UnsafeMutableRawPointer) {
193-
let handlerPointer = pointer.assumingMemoryBound(to: Self.self)
194-
self = handlerPointer.pointee
195-
}
196-
197199
func handleRecord(errorCode: DNSServiceErrorType, data: UnsafeRawPointer?, length: UInt16) {
198200
self._handleRecord(errorCode, data, length)
199201
}

0 commit comments

Comments
 (0)