1
-
2
1
import asyncio
3
2
import functools
4
- import pycares
5
3
import socket
6
4
import sys
5
+ from collections .abc import Iterable , Sequence
6
+ from typing import Any , Literal , Optional , TypeVar , Union , overload
7
7
8
+ import pycares
8
9
from typing import (
9
10
Any ,
10
11
Callable ,
22
23
23
24
__all__ = ('DNSResolver' , 'error' )
24
25
26
+ _T = TypeVar ("_T" )
27
+
25
28
WINDOWS_SELECTOR_ERR_MSG = (
26
29
"aiodns needs a SelectorEventLoop on Windows. See more: "
27
30
"https://github.com/aio-libs/aiodns#note-for-windows-users"
55
58
class DNSResolver :
56
59
def __init__ (self , nameservers : Optional [Sequence [str ]] = None ,
57
60
loop : Optional [asyncio .AbstractEventLoop ] = None ,
58
- ** kwargs : Any ) -> None :
61
+ ** kwargs : Any ) -> None : # TODO(PY311): Use Unpack for kwargs.
59
62
self .loop = loop or asyncio .get_event_loop ()
60
63
assert self .loop is not None
61
64
kwargs .pop ('sock_state_cb' , None )
@@ -80,31 +83,33 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
80
83
** kwargs )
81
84
if nameservers :
82
85
self .nameservers = nameservers
83
- self ._read_fds = set () # type: Set[int]
84
- self ._write_fds = set () # type: Set[int]
85
- self ._timer = None # type : Optional[asyncio.TimerHandle]
86
+ self ._read_fds : set [ int ] = set ()
87
+ self ._write_fds : set [ int ] = set ()
88
+ self ._timer : Optional [asyncio .TimerHandle ] = None
86
89
87
90
@property
88
91
def nameservers (self ) -> Sequence [str ]:
89
92
return self ._channel .servers
90
93
91
94
@nameservers .setter
92
- def nameservers (self , value : Sequence [str ]) -> None :
93
- self ._channel .servers = value if isinstance (value , list ) else list (value )
95
+ def nameservers (self , value : Iterable [Union [str , bytes ]]) -> None :
96
+ # Remove type ignore after mypy 1.16.0
97
+ # https://github.com/python/mypy/issues/12892
98
+ self ._channel .servers = value # type: ignore[assignment]
94
99
95
100
@staticmethod
96
- def _callback (fut : asyncio .Future , result : Any , errorno : int ) -> None :
101
+ def _callback (fut : asyncio .Future [ _T ] , result : _T , errorno : Optional [ int ] ) -> None :
97
102
if fut .cancelled ():
98
103
return
99
104
if errorno is not None :
100
105
fut .set_exception (error .DNSError (errorno , pycares .errno .strerror (errorno )))
101
106
else :
102
107
fut .set_result (result )
103
108
104
- def _get_future_callback (self ) -> Tuple ["asyncio.Future[Any ]" , Callable [[Any , int ], None ]]:
109
+ def _get_future_callback (self ) -> Tuple ["asyncio.Future[_T ]" , Callable [[_T , int ], None ]]:
105
110
"""Return a future and a callback to set the result of the future."""
106
- cb : Callable [[Any , int ], None ]
107
- future : "asyncio.Future[Any ]" = self .loop .create_future ()
111
+ cb : Callable [[_T , int ], None ]
112
+ future : "asyncio.Future[_T ]" = self .loop .create_future ()
108
113
if self ._event_thread :
109
114
cb = functools .partial ( # type: ignore[assignment]
110
115
self .loop .call_soon_threadsafe ,
@@ -115,7 +120,41 @@ def _get_future_callback(self) -> Tuple["asyncio.Future[Any]", Callable[[Any, in
115
120
cb = functools .partial (self ._callback , future )
116
121
return future , cb
117
122
118
- def query (self , host : str , qtype : str , qclass : Optional [str ]= None ) -> asyncio .Future :
123
+ @overload
124
+ def query (self , host : str , qtype : Literal ["A" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_a_result ]]:
125
+ ...
126
+ @overload
127
+ def query (self , host : str , qtype : Literal ["AAAA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_aaaa_result ]]:
128
+ ...
129
+ @overload
130
+ def query (self , host : str , qtype : Literal ["CAA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_caa_result ]]:
131
+ ...
132
+ @overload
133
+ def query (self , host : str , qtype : Literal ["CNAME" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_cname_result ]]:
134
+ ...
135
+ @overload
136
+ def query (self , host : str , qtype : Literal ["MX" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_mx_result ]]:
137
+ ...
138
+ @overload
139
+ def query (self , host : str , qtype : Literal ["NAPTR" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_naptr_result ]]:
140
+ ...
141
+ @overload
142
+ def query (self , host : str , qtype : Literal ["NS" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_ns_result ]]:
143
+ ...
144
+ @overload
145
+ def query (self , host : str , qtype : Literal ["PTR" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_ptr_result ]]:
146
+ ...
147
+ @overload
148
+ def query (self , host : str , qtype : Literal ["SOA" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_soa_result ]]:
149
+ ...
150
+ @overload
151
+ def query (self , host : str , qtype : Literal ["SRV" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_srv_result ]]:
152
+ ...
153
+ @overload
154
+ def query (self , host : str , qtype : Literal ["TXT" ], qclass : Optional [str ] = ...) -> asyncio .Future [list [pycares .ares_query_txt_result ]]:
155
+ ...
156
+
157
+ def query (self , host : str , qtype : str , qclass : Optional [str ]= None ) -> asyncio .Future [list [Any ]]:
119
158
try :
120
159
qtype = query_type_map [qtype ]
121
160
except KeyError :
@@ -126,30 +165,35 @@ def query(self, host: str, qtype: str, qclass: Optional[str]=None) -> asyncio.Fu
126
165
except KeyError :
127
166
raise ValueError ('invalid query class: {}' .format (qclass ))
128
167
168
+ fut : asyncio .Future [list [Any ]]
129
169
fut , cb = self ._get_future_callback ()
130
170
self ._channel .query (host , qtype , cb , query_class = qclass )
131
171
return fut
132
172
133
- def gethostbyname (self , host : str , family : socket .AddressFamily ) -> asyncio .Future :
173
+ def gethostbyname (self , host : str , family : socket .AddressFamily ) -> asyncio .Future [pycares .ares_host_result ]:
174
+ fut : asyncio .Future [pycares .ares_host_result ]
134
175
fut , cb = self ._get_future_callback ()
135
176
self ._channel .gethostbyname (host , family , cb )
136
177
return fut
137
178
138
- def getaddrinfo (self , host : str , family : socket .AddressFamily = socket .AF_UNSPEC , port : Optional [int ] = None , proto : int = 0 , type : int = 0 , flags : int = 0 ) -> asyncio .Future :
179
+ def getaddrinfo (self , host : str , family : socket .AddressFamily = socket .AF_UNSPEC , port : Optional [int ] = None , proto : int = 0 , type : int = 0 , flags : int = 0 ) -> asyncio .Future [pycares .ares_addrinfo_result ]:
180
+ fut : asyncio .Future [pycares .ares_addrinfo_result ]
139
181
fut , cb = self ._get_future_callback ()
140
182
self ._channel .getaddrinfo (host , port , cb , family = family , type = type , proto = proto , flags = flags )
141
183
return fut
142
184
143
- def getnameinfo (self , sockaddr : Union [Tuple [str , int ], Tuple [str , int , int , int ]], flags : int = 0 ) -> asyncio .Future :
185
+ def getnameinfo (self , sockaddr : Union [tuple [str , int ], tuple [str , int , int , int ]], flags : int = 0 ) -> asyncio .Future [pycares .ares_nameinfo_result ]:
186
+ fut : asyncio .Future [pycares .ares_nameinfo_result ]
144
187
fut , cb = self ._get_future_callback ()
145
188
self ._channel .getnameinfo (sockaddr , flags , cb )
146
189
return fut
147
190
148
- def gethostbyaddr (self , name : str ) -> asyncio .Future :
191
+ def gethostbyaddr (self , name : str ) -> asyncio .Future [pycares .ares_host_result ]:
192
+ fut : asyncio .Future [pycares .ares_host_result ]
149
193
fut , cb = self ._get_future_callback ()
150
194
self ._channel .gethostbyaddr (name , cb )
151
195
return fut
152
-
196
+
153
197
def cancel (self ) -> None :
154
198
self ._channel .cancel ()
155
199
@@ -177,7 +221,7 @@ def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
177
221
self ._timer .cancel ()
178
222
self ._timer = None
179
223
180
- def _handle_event (self , fd : int , event : Any ) -> None :
224
+ def _handle_event (self , fd : int , event : int ) -> None :
181
225
read_fd = pycares .ARES_SOCKET_BAD
182
226
write_fd = pycares .ARES_SOCKET_BAD
183
227
if event == READ :
@@ -193,7 +237,7 @@ def _timer_cb(self) -> None:
193
237
else :
194
238
self ._timer = None
195
239
196
- def _start_timer (self ):
240
+ def _start_timer (self ) -> None :
197
241
timeout = self ._timeout
198
242
if timeout is None or timeout < 0 or timeout > 1 :
199
243
timeout = 1
0 commit comments