Skip to content

Commit b6cce69

Browse files
authored
Use c-ares event thread when available (#145)
1 parent bca3ae9 commit b6cce69

File tree

3 files changed

+73
-24
lines changed

3 files changed

+73
-24
lines changed

README.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,17 @@ The API is pretty simple, three functions are provided in the ``DNSResolver`` cl
5858
Note for Windows users
5959
======================
6060

61-
This library requires the asyncio loop to be a `SelectorEventLoop`, which is not the default on Windows.
61+
This library requires the use of an ``asyncio.SelectorEventLoop`` on Windows
62+
**only** when using a custom build of ``pycares`` that links against a system-
63+
provided ``c-ares`` library **without** thread-safety support. This is because
64+
non-thread-safe builds of ``c-ares`` are incompatible with the default
65+
``ProactorEventLoop`` on Windows.
6266

63-
The default can be changed as follows (do this very early in your application):
67+
If you're using the official prebuilt ``pycares`` wheels on PyPI (version 4.7.0 or
68+
later), which include a thread-safe version of ``c-ares``, this limitation does
69+
**not** apply and can be safely ignored.
70+
71+
The default event loop can be changed as follows (do this very early in your application):
6472

6573
.. code:: python
6674

aiodns/__init__.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from typing import (
99
Any,
10+
Callable,
1011
Optional,
1112
Set,
1213
Sequence,
@@ -21,6 +22,11 @@
2122

2223
__all__ = ('DNSResolver', 'error')
2324

25+
WINDOWS_SELECTOR_ERR_MSG = (
26+
"aiodns needs a SelectorEventLoop on Windows. See more: "
27+
"https://github.com/aio-libs/aiodns#note-for-windows-users"
28+
)
29+
2430

2531
READ = 1
2632
WRITE = 2
@@ -52,22 +58,26 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
5258
**kwargs: Any) -> None:
5359
self.loop = loop or asyncio.get_event_loop()
5460
assert self.loop is not None
55-
if sys.platform == 'win32':
56-
if not isinstance(self.loop, asyncio.SelectorEventLoop):
61+
kwargs.pop('sock_state_cb', None)
62+
timeout = kwargs.pop('timeout', None)
63+
self._timeout = timeout
64+
self._event_thread = hasattr(pycares,"ares_threadsafety") and pycares.ares_threadsafety()
65+
if self._event_thread:
66+
# pycares is thread safe
67+
self._channel = pycares.Channel(event_thread=True,
68+
timeout=timeout,
69+
**kwargs)
70+
else:
71+
if sys.platform == 'win32' and not isinstance(self.loop, asyncio.SelectorEventLoop):
5772
try:
5873
import winloop
5974
if not isinstance(self.loop , winloop.Loop):
60-
raise RuntimeError(
61-
'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86')
75+
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
6276
except ModuleNotFoundError:
63-
raise RuntimeError(
64-
'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86')
65-
kwargs.pop('sock_state_cb', None)
66-
timeout = kwargs.pop('timeout', None)
67-
self._timeout = timeout
68-
self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
69-
timeout=timeout,
70-
**kwargs)
77+
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
78+
self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
79+
timeout=timeout,
80+
**kwargs)
7181
if nameservers:
7282
self.nameservers = nameservers
7383
self._read_fds = set() # type: Set[int]
@@ -91,6 +101,20 @@ def _callback(fut: asyncio.Future, result: Any, errorno: int) -> None:
91101
else:
92102
fut.set_result(result)
93103

104+
def _get_future_callback(self) -> Tuple["asyncio.Future[Any]", Callable[[Any, int], None]]:
105+
"""Return a future and a callback to set the result of the future."""
106+
cb: Callable[[Any, int], None]
107+
future: "asyncio.Future[Any]" = self.loop.create_future()
108+
if self._event_thread:
109+
cb = functools.partial( # type: ignore[assignment]
110+
self.loop.call_soon_threadsafe,
111+
self._callback, # type: ignore[arg-type]
112+
future
113+
)
114+
else:
115+
cb = functools.partial(self._callback, future)
116+
return future, cb
117+
94118
def query(self, host: str, qtype: str, qclass: Optional[str]=None) -> asyncio.Future:
95119
try:
96120
qtype = query_type_map[qtype]
@@ -102,32 +126,27 @@ def query(self, host: str, qtype: str, qclass: Optional[str]=None) -> asyncio.Fu
102126
except KeyError:
103127
raise ValueError('invalid query class: {}'.format(qclass))
104128

105-
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
106-
cb = functools.partial(self._callback, fut)
129+
fut, cb = self._get_future_callback()
107130
self._channel.query(host, qtype, cb, query_class=qclass)
108131
return fut
109132

110133
def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future:
111-
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
112-
cb = functools.partial(self._callback, fut)
134+
fut, cb = self._get_future_callback()
113135
self._channel.gethostbyname(host, family, cb)
114136
return fut
115137

116138
def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future:
117-
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
118-
cb = functools.partial(self._callback, fut)
139+
fut, cb = self._get_future_callback()
119140
self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags)
120141
return fut
121142

122143
def getnameinfo(self, sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int = 0) -> asyncio.Future:
123-
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
124-
cb = functools.partial(self._callback, fut)
144+
fut, cb = self._get_future_callback()
125145
self._channel.getnameinfo(sockaddr, flags, cb)
126146
return fut
127147

128148
def gethostbyaddr(self, name: str) -> asyncio.Future:
129-
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
130-
cb = functools.partial(self._callback, fut)
149+
fut, cb = self._get_future_callback()
131150
self._channel.gethostbyaddr(name, cb)
132151
return fut
133152

tests/test_aiodns.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import asyncio
44
import ipaddress
55
import unittest
6+
import pytest
67
import socket
78
import sys
89
import time
10+
import unittest.mock
911

1012
import aiodns
1113

@@ -211,5 +213,25 @@ def setUp(self):
211213
self.resolver = aiodns.DNSResolver(loop=self.loop, timeout=5.0)
212214
self.resolver.nameservers = ['8.8.8.8']
213215

216+
217+
class TestNoEventThreadDNS(DNSTest):
218+
"""Test DNSResolver with no event thread."""
219+
220+
def setUp(self):
221+
with unittest.mock.patch('aiodns.pycares.ares_threadsafety', return_value=False):
222+
super().setUp()
223+
224+
225+
@unittest.skipIf(sys.platform != 'win32', 'Only run on Windows')
226+
def test_win32_no_selector_event_loop():
227+
"""Test DNSResolver with Windows without SelectorEventLoop."""
228+
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
229+
with (
230+
pytest.raises(RuntimeError, match="aiodns needs a SelectorEventLoop on Windows"),
231+
unittest.mock.patch('aiodns.pycares.ares_threadsafety', return_value=False)
232+
):
233+
aiodns.DNSResolver(loop=asyncio.new_event_loop(), timeout=5.0)
234+
235+
214236
if __name__ == "__main__": # pragma: no cover
215237
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)