Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
_T = TypeVar('_T')

WINDOWS_SELECTOR_ERR_MSG = (
'aiodns needs a SelectorEventLoop on Windows. See more: '
'aiodns cannot use ProactorEventLoop on Windows'
' if pycares has no threadsafety. See more: '
'https://github.com/aio-libs/aiodns#note-for-windows-users'
)

Expand Down Expand Up @@ -106,16 +107,13 @@ def _make_channel(self, **kwargs: Any) -> tuple[bool, pycares.Channel]:
'Falling back to socket state callback: %s',
e,
)
if sys.platform == 'win32' and not isinstance(
self.loop, asyncio.SelectorEventLoop
):
try:
import winloop
if sys.platform == 'win32':
if (
hasattr(asyncio, 'ProactorEventLoop')
and type(self.loop) is asyncio.ProactorEventLoop
):
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)

if not isinstance(self.loop, winloop.Loop):
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
except ModuleNotFoundError as ex:
raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
return False, pycares.Channel(
sock_state_cb=self._sock_state_cb, timeout=self._timeout, **kwargs
)
Expand Down
156 changes: 14 additions & 142 deletions tests/test_aiodns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import time
import unittest
import unittest.mock
from typing import Any

import pycares
import pytest
Expand Down Expand Up @@ -328,27 +327,6 @@ def setUp(self) -> None:
super().setUp()


@unittest.skipIf(sys.platform != 'win32', 'Only run on Windows')
def test_win32_no_selector_event_loop() -> None:
"""Test DNSResolver with Windows without SelectorEventLoop."""
# Create a non-SelectorEventLoop to trigger the error
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
mock_loop.__class__ = (
asyncio.AbstractEventLoop # type: ignore[assignment]
)

with (
pytest.raises(
RuntimeError, match='aiodns needs a SelectorEventLoop on Windows'
),
unittest.mock.patch(
'aiodns.pycares.ares_threadsafety', return_value=False
),
unittest.mock.patch('sys.platform', 'win32'),
):
aiodns.DNSResolver(loop=mock_loop, timeout=5.0)


@pytest.mark.parametrize(
('platform', 'expected_msg_parts', 'unexpected_msg_parts'),
[
Expand Down Expand Up @@ -440,132 +418,26 @@ async def test_make_channel_ares_error(
resolver._closed = True


def test_win32_import_winloop_error() -> None:
"""Test winloop import error on Windows.

Test handling of ModuleNotFoundError when importing
winloop on Windows.
"""
# Create a mock event loop that is not a SelectorEventLoop
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)

# Setup patching for this test
original_import = __import__

def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'winloop':
raise ModuleNotFoundError("No module named 'winloop'")
return original_import(name, *args, **kwargs)

# Patch the Channel class to avoid creating real network resources
mock_channel = unittest.mock.MagicMock()

with (
unittest.mock.patch('sys.platform', 'win32'),
unittest.mock.patch(
'aiodns.pycares.ares_threadsafety', return_value=False
),
unittest.mock.patch('builtins.__import__', side_effect=mock_import),
unittest.mock.patch(
'importlib.import_module', side_effect=mock_import
),
# Also patch Channel creation to avoid socket resource leak
unittest.mock.patch(
'aiodns.pycares.Channel', return_value=mock_channel
),
pytest.raises(RuntimeError, match=aiodns.WINDOWS_SELECTOR_ERR_MSG),
):
aiodns.DNSResolver(loop=mock_loop)


def test_win32_winloop_not_loop_instance() -> None:
"""Test non-winloop.Loop instance on Windows.

Test handling of a loop that is not a winloop.Loop
instance on Windows.
"""
# Create a mock event loop that is not a SelectorEventLoop
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)

original_import = __import__

# Create a mock winloop module with a Loop class that's an actual type
class MockLoop:
pass

mock_winloop_module = unittest.mock.MagicMock()
mock_winloop_module.Loop = MockLoop

def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'winloop':
return mock_winloop_module
return original_import(name, *args, **kwargs)

# Patch the Channel class to avoid creating real network resources
mock_channel = unittest.mock.MagicMock()

with (
unittest.mock.patch('sys.platform', 'win32'),
unittest.mock.patch(
'aiodns.pycares.ares_threadsafety', return_value=False
),
unittest.mock.patch('builtins.__import__', side_effect=mock_import),
unittest.mock.patch(
'importlib.import_module', side_effect=mock_import
),
# Also patch Channel creation to avoid socket resource leak
unittest.mock.patch(
'aiodns.pycares.Channel', return_value=mock_channel
),
pytest.raises(RuntimeError, match=aiodns.WINDOWS_SELECTOR_ERR_MSG),
):
aiodns.DNSResolver(loop=mock_loop)


def test_win32_winloop_loop_instance() -> None:
"""Test winloop.Loop instance on Windows.

Test handling of a loop that IS a winloop.Loop instance on Windows.
"""

# Create a mock winloop module with a Loop class
class MockLoop:
pass

# Create a mock event loop that IS a winloop.Loop instance
mock_loop = unittest.mock.MagicMock(spec=asyncio.AbstractEventLoop)
# Make isinstance check pass
mock_loop.__class__ = MockLoop # type: ignore[assignment]

mock_winloop_module = unittest.mock.MagicMock()
mock_winloop_module.Loop = MockLoop

original_import = __import__

def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'winloop':
return mock_winloop_module
return original_import(name, *args, **kwargs)

# Mock channel creation to avoid actual DNS resolution
mock_channel = unittest.mock.MagicMock()

@pytest.mark.skipif(
sys.platform != 'win32',
reason="ProactorEventLoop can't be simulated",
)
@pytest.mark.asyncio
async def test_runtime_error_if_windows_proactor_event_loop() -> None:
with (
unittest.mock.patch('sys.platform', 'win32'),
unittest.mock.patch(
'aiodns.pycares.ares_threadsafety', return_value=False
),
unittest.mock.patch('builtins.__import__', side_effect=mock_import),
unittest.mock.patch(
'importlib.import_module', side_effect=mock_import
),
unittest.mock.patch(
'aiodns.pycares.Channel', return_value=mock_channel
pytest.raises(
RuntimeError,
match=r'aiodns cannot use ProactorEventLoop on Windows'
r' if pycares has no threadsafety. See more: '
r'https://github.com/aio-libs/aiodns#note-for-windows-users',
),
):
# This should not raise an exception since loop
# is a winloop.Loop instance
aiodns.DNSResolver(loop=mock_loop)
# The ProactorEventLoop is chosen by default
async with aiodns.DNSResolver():
pass


@pytest.mark.asyncio
Expand Down
Loading