diff --git a/Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift b/Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift index 29c306e..bc857ad 100644 --- a/Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift +++ b/Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift @@ -110,21 +110,6 @@ struct DNSSD { let recordStream = AsyncThrowingStream { continuation in let handler = QueryReplyHandler(handler: replyHandler, continuation) - // Wrap `handler` into a pointer so we can pass it to DNSServiceQueryRecord - let handlerPointer = UnsafeMutableRawPointer.allocate( - byteCount: MemoryLayout.stride, - alignment: MemoryLayout.alignment - ) - - handlerPointer.initializeMemory(as: QueryReplyHandler.self, repeating: handler, count: 1) - - // The handler might be called multiple times so don't deallocate inside `callback` - defer { - let pointer = handlerPointer.assumingMemoryBound(to: QueryReplyHandler.self) - pointer.deinitialize(count: 1) - pointer.deallocate() - } - // This is called once per record received let callback: DNSServiceQueryRecordReply = { _, _, _, errorCode, _, _, _, rdlen, rdata, _, context in guard let handlerPointer = context else { @@ -138,32 +123,54 @@ struct DNSSD { handler.handleRecord(errorCode: errorCode, data: rdata, length: rdlen) } - let serviceRefPtr = UnsafeMutablePointer.allocate(capacity: 1) - defer { serviceRefPtr.deallocate() } + let serviceRefPointer = UnsafeMutablePointer.allocate(capacity: 1) + + // Wrap 'handler' into a pointer so we can pass it to DNSServiceQueryRecord + let replyHandlerPointer = UnsafeMutablePointer.allocate(capacity: 1) + replyHandlerPointer.initialize(to: handler) // Run the query let _code = DNSServiceQueryRecord( - serviceRefPtr, + serviceRefPointer, kDNSServiceFlagsTimeout, 0, name, UInt16(type.kDNSServiceType), UInt16(kDNSServiceClass_IN), callback, - handlerPointer + replyHandlerPointer ) // Check if query completed successfully guard _code == kDNSServiceErr_NoError else { + DNSSD.deallocatePointers(serviceRefPointer: serviceRefPointer, replyHandlerPointer: replyHandlerPointer) return continuation.finish(throwing: AsyncDNSResolver.Error(dnssdCode: _code)) } - // Read reply from the socket (blocking) then call reply handler - DNSServiceProcessResult(serviceRefPtr.pointee) - DNSServiceRefDeallocate(serviceRefPtr.pointee) + let serviceSockFD = DNSServiceRefSockFD(serviceRefPointer.pointee) + guard serviceSockFD != -1 else { + DNSSD.deallocatePointers(serviceRefPointer: serviceRefPointer, replyHandlerPointer: replyHandlerPointer) + return continuation.finish(throwing: AsyncDNSResolver.Error(code: .internalError, message: "Failed to access the DNSSD service socket")) + } + + let readSource = DispatchSource.makeReadSource(fileDescriptor: serviceSockFD) + readSource.setEventHandler { + // Read reply from the socket (blocking) then call reply handler + DNSServiceProcessResult(serviceRefPointer.pointee) + + readSource.cancel() + continuation.finish() + } + + readSource.setCancelHandler { + DNSSD.deallocatePointers(serviceRefPointer: serviceRefPointer, replyHandlerPointer: replyHandlerPointer) + close(serviceSockFD) + } + readSource.resume() - // Streaming done - continuation.finish() + continuation.onTermination = { _ in + readSource.cancel() + } } // Build reply using records received @@ -173,6 +180,12 @@ struct DNSSD { return try replyHandler.generateReply(records: records) } + + private static func deallocatePointers(serviceRefPointer: UnsafeMutablePointer, replyHandlerPointer: UnsafeMutablePointer) { + DNSServiceRefDeallocate(serviceRefPointer.pointee) + serviceRefPointer.deallocate() + replyHandlerPointer.deallocate() + } } // MARK: - dnssd query reply handler