Skip to content

Commit c856d44

Browse files
authored
Fallback to sock_state_cb if event_thread creation fails (#151)
1 parent 5c4b29c commit c856d44

File tree

4 files changed

+240
-21
lines changed

4 files changed

+240
-21
lines changed

aiodns/__init__.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import functools
3+
import logging
4+
import pycares
35
import socket
46
import sys
57
from collections.abc import Iterable, Sequence
@@ -30,6 +32,7 @@
3032
"https://github.com/aio-libs/aiodns#note-for-windows-users"
3133
)
3234

35+
_LOGGER = logging.getLogger(__name__)
3336

3437
READ = 1
3538
WRITE = 2
@@ -64,29 +67,50 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
6467
kwargs.pop('sock_state_cb', None)
6568
timeout = kwargs.pop('timeout', None)
6669
self._timeout = timeout
67-
self._event_thread = hasattr(pycares,"ares_threadsafety") and pycares.ares_threadsafety()
68-
if self._event_thread:
69-
# pycares is thread safe
70-
self._channel = pycares.Channel(event_thread=True,
71-
timeout=timeout,
72-
**kwargs)
73-
else:
74-
if sys.platform == 'win32' and not isinstance(self.loop, asyncio.SelectorEventLoop):
75-
try:
76-
import winloop
77-
if not isinstance(self.loop , winloop.Loop):
78-
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
79-
except ModuleNotFoundError:
80-
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
81-
self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
82-
timeout=timeout,
83-
**kwargs)
70+
self._event_thread, self._channel = self._make_channel(**kwargs)
8471
if nameservers:
8572
self.nameservers = nameservers
8673
self._read_fds: set[int] = set()
8774
self._write_fds: set[int] = set()
8875
self._timer: Optional[asyncio.TimerHandle] = None
8976

77+
def _make_channel(self, **kwargs: Any) -> Tuple[bool, pycares.Channel]:
78+
if hasattr(pycares, "ares_threadsafety") and pycares.ares_threadsafety():
79+
# pycares is thread safe
80+
try:
81+
return True, pycares.Channel(
82+
event_thread=True, timeout=self._timeout, **kwargs
83+
)
84+
except pycares.AresError as e:
85+
if sys.platform == "linux":
86+
_LOGGER.warning(
87+
"Failed to create a DNS resolver channel with automatic monitoring of "
88+
"resolver configuration changes, this usually means the system ran "
89+
"out of inotify watches. Falling back to socket state callback. "
90+
"Consider increasing the system inotify watch limit: %s",
91+
e,
92+
)
93+
else:
94+
_LOGGER.warning(
95+
"Failed to create a DNS resolver channel with automatic monitoring "
96+
"of resolver configuration changes. Falling back to socket state "
97+
"callback: %s",
98+
e,
99+
)
100+
if sys.platform == "win32" and not isinstance(
101+
self.loop, asyncio.SelectorEventLoop
102+
):
103+
try:
104+
import winloop
105+
106+
if not isinstance(self.loop, winloop.Loop):
107+
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
108+
except ModuleNotFoundError as ex:
109+
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
110+
return False, pycares.Channel(
111+
sock_state_cb=self._sock_state_cb, timeout=self._timeout, **kwargs
112+
)
113+
90114
@property
91115
def nameservers(self) -> Sequence[str]:
92116
return self._channel.servers

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ addopts =
1010
--showlocals
1111
# coverage reports
1212
--cov=aiodns/ --cov=tests/ --cov-report term
13+
asyncio_default_fixture_loop_scope = function
14+
asyncio_mode = auto
1315
filterwarnings =
1416
error
1517
testpaths = tests/

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
-e .
22
pycares==4.8.0
33
pytest==8.3.5
4+
pytest-asyncio==0.26.0
45
pytest-cov==6.1.1
56
uvloop==0.21.0; platform_system != "Windows" and implementation_name == "cpython"
67
winloop==0.1.8; platform_system == "Windows"

tests/test_aiodns.py

Lines changed: 196 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22

33
import asyncio
44
import ipaddress
5+
import logging
56
import unittest
67
import pytest
78
import socket
89
import sys
910
import time
1011
import unittest.mock
12+
from typing import Any
1113

1214
import pycares
1315

1416
import aiodns
17+
import pycares
1518

1619
try:
1720
if sys.platform == "win32":
@@ -228,12 +231,201 @@ def setUp(self) -> None:
228231
@unittest.skipIf(sys.platform != 'win32', 'Only run on Windows')
229232
def test_win32_no_selector_event_loop() -> None:
230233
"""Test DNSResolver with Windows without SelectorEventLoop."""
231-
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
234+
# Create a non-SelectorEventLoop to trigger the error
235+
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
236+
mock_loop.__class__ = (
237+
asyncio.AbstractEventLoop # type: ignore[assignment]
238+
)
239+
240+
with (
241+
pytest.raises(
242+
RuntimeError, match="aiodns needs a SelectorEventLoop on Windows"
243+
),
244+
unittest.mock.patch("aiodns.pycares.ares_threadsafety", return_value=False),
245+
unittest.mock.patch("sys.platform", "win32"),
246+
):
247+
aiodns.DNSResolver(loop=mock_loop, timeout=5.0)
248+
249+
250+
@pytest.mark.parametrize(
251+
("platform", "expected_msg_parts", "unexpected_msg_parts"),
252+
[
253+
(
254+
"linux",
255+
[
256+
"automatic monitoring of",
257+
"resolver configuration changes",
258+
"system ran out of inotify watches",
259+
"Falling back to socket state callback",
260+
"Consider increasing the system inotify watch limit",
261+
],
262+
[],
263+
),
264+
(
265+
"darwin",
266+
[
267+
"automatic monitoring",
268+
"resolver configuration changes",
269+
"Falling back to socket state callback",
270+
],
271+
[
272+
"system ran out of inotify watches",
273+
"Consider increasing the system inotify watch limit",
274+
],
275+
),
276+
(
277+
"win32",
278+
[
279+
"automatic monitoring",
280+
"resolver configuration changes",
281+
"Falling back to socket state callback",
282+
],
283+
[
284+
"system ran out of inotify watches",
285+
"Consider increasing the system inotify watch limit",
286+
],
287+
),
288+
],
289+
)
290+
async def test_make_channel_ares_error(
291+
platform: str,
292+
expected_msg_parts: list[str],
293+
unexpected_msg_parts: list[str],
294+
caplog: pytest.LogCaptureFixture,
295+
) -> None:
296+
"""Test exception handling in _make_channel on different platforms."""
297+
# Set log level to capture warnings
298+
caplog.set_level(logging.WARNING)
299+
300+
# Create a mock loop that is a SelectorEventLoop to avoid Windows-specific errors
301+
mock_loop = unittest.mock.MagicMock(spec=asyncio.SelectorEventLoop)
302+
mock_channel = unittest.mock.MagicMock()
303+
304+
with (
305+
unittest.mock.patch("sys.platform", platform),
306+
# Configure first Channel call to raise AresError, second call to return our mock
307+
unittest.mock.patch(
308+
"aiodns.pycares.Channel",
309+
side_effect=[
310+
pycares.AresError("Mock error"),
311+
mock_channel,
312+
],
313+
),
314+
unittest.mock.patch("aiodns.pycares.ares_threadsafety", return_value=True),
315+
# Also patch asyncio.get_event_loop to return our mock loop
316+
unittest.mock.patch("asyncio.get_event_loop", return_value=mock_loop),
317+
):
318+
# Create resolver which will call _make_channel
319+
resolver = aiodns.DNSResolver(loop=mock_loop)
320+
321+
# Check that event_thread is False due to exception
322+
assert resolver._event_thread is False
323+
324+
# Check expected message parts in the captured log
325+
for part in expected_msg_parts:
326+
assert part in caplog.text
327+
328+
# Check unexpected message parts aren't in the captured log
329+
for part in unexpected_msg_parts:
330+
assert part not in caplog.text
331+
332+
333+
def test_win32_import_winloop_error() -> None:
334+
"""Test handling of ModuleNotFoundError when importing winloop on Windows."""
335+
# Create a mock event loop that is not a SelectorEventLoop
336+
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
337+
338+
# Setup patching for this test
339+
original_import = __import__
340+
341+
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
342+
if name == "winloop":
343+
raise ModuleNotFoundError("No module named 'winloop'")
344+
return original_import(name, *args, **kwargs)
345+
346+
# Patch the Channel class to avoid creating real network resources
347+
mock_channel = unittest.mock.MagicMock()
348+
349+
with (
350+
unittest.mock.patch("sys.platform", "win32"),
351+
unittest.mock.patch("aiodns.pycares.ares_threadsafety", return_value=False),
352+
unittest.mock.patch("builtins.__import__", side_effect=mock_import),
353+
unittest.mock.patch("importlib.import_module", side_effect=mock_import),
354+
# Also patch Channel creation to avoid socket resource leak
355+
unittest.mock.patch("aiodns.pycares.Channel", return_value=mock_channel),
356+
pytest.raises(RuntimeError, match=aiodns.WINDOWS_SELECTOR_ERR_MSG),
357+
):
358+
aiodns.DNSResolver(loop=mock_loop)
359+
360+
361+
def test_win32_winloop_not_loop_instance() -> None:
362+
"""Test handling of a loop that is not a winloop.Loop instance on Windows."""
363+
# Create a mock event loop that is not a SelectorEventLoop
364+
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
365+
366+
original_import = __import__
367+
368+
# Create a mock winloop module with a Loop class that's an actual type
369+
class MockLoop:
370+
pass
371+
372+
mock_winloop_module = unittest.mock.MagicMock()
373+
mock_winloop_module.Loop = MockLoop
374+
375+
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
376+
if name == "winloop":
377+
return mock_winloop_module
378+
return original_import(name, *args, **kwargs)
379+
380+
# Patch the Channel class to avoid creating real network resources
381+
mock_channel = unittest.mock.MagicMock()
382+
383+
with (
384+
unittest.mock.patch("sys.platform", "win32"),
385+
unittest.mock.patch("aiodns.pycares.ares_threadsafety", return_value=False),
386+
unittest.mock.patch("builtins.__import__", side_effect=mock_import),
387+
unittest.mock.patch("importlib.import_module", side_effect=mock_import),
388+
# Also patch Channel creation to avoid socket resource leak
389+
unittest.mock.patch("aiodns.pycares.Channel", return_value=mock_channel),
390+
pytest.raises(RuntimeError, match=aiodns.WINDOWS_SELECTOR_ERR_MSG),
391+
):
392+
aiodns.DNSResolver(loop=mock_loop)
393+
394+
395+
def test_win32_winloop_loop_instance() -> None:
396+
"""Test handling of a loop that IS a winloop.Loop instance on Windows."""
397+
398+
# Create a mock winloop module with a Loop class
399+
class MockLoop:
400+
pass
401+
402+
# Create a mock event loop that IS a winloop.Loop instance
403+
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
404+
# Make isinstance check pass
405+
mock_loop.__class__ = MockLoop # type: ignore[assignment]
406+
407+
mock_winloop_module = unittest.mock.MagicMock()
408+
mock_winloop_module.Loop = MockLoop
409+
410+
original_import = __import__
411+
412+
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
413+
if name == "winloop":
414+
return mock_winloop_module
415+
return original_import(name, *args, **kwargs)
416+
417+
# Mock channel creation to avoid actual DNS resolution
418+
mock_channel = unittest.mock.MagicMock()
419+
232420
with (
233-
pytest.raises(RuntimeError, match="aiodns needs a SelectorEventLoop on Windows"),
234-
unittest.mock.patch('aiodns.pycares.ares_threadsafety', return_value=False)
421+
unittest.mock.patch("sys.platform", "win32"),
422+
unittest.mock.patch("aiodns.pycares.ares_threadsafety", return_value=False),
423+
unittest.mock.patch("builtins.__import__", side_effect=mock_import),
424+
unittest.mock.patch("importlib.import_module", side_effect=mock_import),
425+
unittest.mock.patch("aiodns.pycares.Channel", return_value=mock_channel),
235426
):
236-
aiodns.DNSResolver(loop=asyncio.new_event_loop(), timeout=5.0)
427+
# This should not raise an exception since loop is a winloop.Loop instance
428+
aiodns.DNSResolver(loop=mock_loop)
237429

238430

239431
if __name__ == "__main__": # pragma: no cover

0 commit comments

Comments
 (0)