7
7
8
8
from typing import (
9
9
Any ,
10
+ Callable ,
10
11
Optional ,
11
12
Set ,
12
13
Sequence ,
21
22
22
23
__all__ = ('DNSResolver' , 'error' )
23
24
25
+ WINDOWS_SELECTOR_ERR_MSG = (
26
+ "aiodns needs a SelectorEventLoop on Windows. See more: "
27
+ "https://github.com/aio-libs/aiodns#note-for-windows-users"
28
+ )
29
+
24
30
25
31
READ = 1
26
32
WRITE = 2
@@ -52,22 +58,26 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
52
58
** kwargs : Any ) -> None :
53
59
self .loop = loop or asyncio .get_event_loop ()
54
60
assert self .loop is not None
55
- if sys .platform == 'win32' :
56
- if not isinstance (self .loop , asyncio .SelectorEventLoop ):
61
+ kwargs .pop ('sock_state_cb' , None )
62
+ timeout = kwargs .pop ('timeout' , None )
63
+ self ._timeout = timeout
64
+ self ._event_thread = hasattr (pycares ,"ares_threadsafety" ) and pycares .ares_threadsafety ()
65
+ if self ._event_thread :
66
+ # pycares is thread safe
67
+ self ._channel = pycares .Channel (event_thread = True ,
68
+ timeout = timeout ,
69
+ ** kwargs )
70
+ else :
71
+ if sys .platform == 'win32' and not isinstance (self .loop , asyncio .SelectorEventLoop ):
57
72
try :
58
73
import winloop
59
74
if not isinstance (self .loop , winloop .Loop ):
60
- raise RuntimeError (
61
- 'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86' )
75
+ raise RuntimeError (WINDOWS_SELECTOR_ERR_MSG )
62
76
except ModuleNotFoundError :
63
- raise RuntimeError (
64
- 'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86' )
65
- kwargs .pop ('sock_state_cb' , None )
66
- timeout = kwargs .pop ('timeout' , None )
67
- self ._timeout = timeout
68
- self ._channel = pycares .Channel (sock_state_cb = self ._sock_state_cb ,
69
- timeout = timeout ,
70
- ** kwargs )
77
+ raise RuntimeError (WINDOWS_SELECTOR_ERR_MSG )
78
+ self ._channel = pycares .Channel (sock_state_cb = self ._sock_state_cb ,
79
+ timeout = timeout ,
80
+ ** kwargs )
71
81
if nameservers :
72
82
self .nameservers = nameservers
73
83
self ._read_fds = set () # type: Set[int]
@@ -91,6 +101,20 @@ def _callback(fut: asyncio.Future, result: Any, errorno: int) -> None:
91
101
else :
92
102
fut .set_result (result )
93
103
104
+ def _get_future_callback (self ) -> Tuple ["asyncio.Future[Any]" , Callable [[Any , int ], None ]]:
105
+ """Return a future and a callback to set the result of the future."""
106
+ cb : Callable [[Any , int ], None ]
107
+ future : "asyncio.Future[Any]" = self .loop .create_future ()
108
+ if self ._event_thread :
109
+ cb = functools .partial ( # type: ignore[assignment]
110
+ self .loop .call_soon_threadsafe ,
111
+ self ._callback , # type: ignore[arg-type]
112
+ future
113
+ )
114
+ else :
115
+ cb = functools .partial (self ._callback , future )
116
+ return future , cb
117
+
94
118
def query (self , host : str , qtype : str , qclass : Optional [str ]= None ) -> asyncio .Future :
95
119
try :
96
120
qtype = query_type_map [qtype ]
@@ -102,32 +126,27 @@ def query(self, host: str, qtype: str, qclass: Optional[str]=None) -> asyncio.Fu
102
126
except KeyError :
103
127
raise ValueError ('invalid query class: {}' .format (qclass ))
104
128
105
- fut = asyncio .Future (loop = self .loop ) # type: asyncio.Future
106
- cb = functools .partial (self ._callback , fut )
129
+ fut , cb = self ._get_future_callback ()
107
130
self ._channel .query (host , qtype , cb , query_class = qclass )
108
131
return fut
109
132
110
133
def gethostbyname (self , host : str , family : socket .AddressFamily ) -> asyncio .Future :
111
- fut = asyncio .Future (loop = self .loop ) # type: asyncio.Future
112
- cb = functools .partial (self ._callback , fut )
134
+ fut , cb = self ._get_future_callback ()
113
135
self ._channel .gethostbyname (host , family , cb )
114
136
return fut
115
137
116
138
def getaddrinfo (self , host : str , family : socket .AddressFamily = socket .AF_UNSPEC , port : Optional [int ] = None , proto : int = 0 , type : int = 0 , flags : int = 0 ) -> asyncio .Future :
117
- fut = asyncio .Future (loop = self .loop ) # type: asyncio.Future
118
- cb = functools .partial (self ._callback , fut )
139
+ fut , cb = self ._get_future_callback ()
119
140
self ._channel .getaddrinfo (host , port , cb , family = family , type = type , proto = proto , flags = flags )
120
141
return fut
121
142
122
143
def getnameinfo (self , sockaddr : Union [Tuple [str , int ], Tuple [str , int , int , int ]], flags : int = 0 ) -> asyncio .Future :
123
- fut = asyncio .Future (loop = self .loop ) # type: asyncio.Future
124
- cb = functools .partial (self ._callback , fut )
144
+ fut , cb = self ._get_future_callback ()
125
145
self ._channel .getnameinfo (sockaddr , flags , cb )
126
146
return fut
127
147
128
148
def gethostbyaddr (self , name : str ) -> asyncio .Future :
129
- fut = asyncio .Future (loop = self .loop ) # type: asyncio.Future
130
- cb = functools .partial (self ._callback , fut )
149
+ fut , cb = self ._get_future_callback ()
131
150
self ._channel .gethostbyaddr (name , cb )
132
151
return fut
133
152
0 commit comments