|
1 | 1 | import asyncio
|
2 | 2 | import socket
|
| 3 | +import weakref |
3 | 4 | from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
|
4 | 5 |
|
5 | 6 | from .abc import AbstractResolver, ResolveResult
|
@@ -93,7 +94,17 @@ def __init__(
|
93 | 94 | if aiodns is None:
|
94 | 95 | raise RuntimeError("Resolver requires aiodns library")
|
95 | 96 |
|
96 |
| - self._resolver = aiodns.DNSResolver(*args, **kwargs) |
| 97 | + self._loop = asyncio.get_running_loop() |
| 98 | + self._manager: Optional[_DNSResolverManager] = None |
| 99 | + # If custom args are provided, create a dedicated resolver instance |
| 100 | + # This means each AsyncResolver with custom args gets its own |
| 101 | + # aiodns.DNSResolver instance |
| 102 | + if args or kwargs: |
| 103 | + self._resolver = aiodns.DNSResolver(*args, **kwargs) |
| 104 | + return |
| 105 | + # Use the shared resolver from the manager for default arguments |
| 106 | + self._manager = _DNSResolverManager() |
| 107 | + self._resolver = self._manager.get_resolver(self, self._loop) |
97 | 108 |
|
98 | 109 | if not hasattr(self._resolver, "gethostbyname"):
|
99 | 110 | # aiodns 1.1 is not available, fallback to DNSResolver.query
|
@@ -180,7 +191,78 @@ async def _resolve_with_query(
|
180 | 191 | return hosts
|
181 | 192 |
|
182 | 193 | async def close(self) -> None:
|
| 194 | + if self._manager: |
| 195 | + # Release the resolver from the manager if using the shared resolver |
| 196 | + self._manager.release_resolver(self, self._loop) |
| 197 | + self._manager = None # Clear reference to manager |
| 198 | + self._resolver = None # type: ignore[assignment] # Clear reference to resolver |
| 199 | + return |
| 200 | + # Otherwise cancel our dedicated resolver |
183 | 201 | self._resolver.cancel()
|
| 202 | + self._resolver = None # type: ignore[assignment] # Clear reference |
| 203 | + |
| 204 | + |
| 205 | +class _DNSResolverManager: |
| 206 | + """Manager for aiodns.DNSResolver objects. |
| 207 | +
|
| 208 | + This class manages shared aiodns.DNSResolver instances |
| 209 | + with no custom arguments across different event loops. |
| 210 | + """ |
| 211 | + |
| 212 | + _instance: Optional["_DNSResolverManager"] = None |
| 213 | + |
| 214 | + def __new__(cls) -> "_DNSResolverManager": |
| 215 | + if cls._instance is None: |
| 216 | + cls._instance = super().__new__(cls) |
| 217 | + cls._instance._init() |
| 218 | + return cls._instance |
| 219 | + |
| 220 | + def _init(self) -> None: |
| 221 | + # Use WeakKeyDictionary to allow event loops to be garbage collected |
| 222 | + self._loop_data: weakref.WeakKeyDictionary[ |
| 223 | + asyncio.AbstractEventLoop, |
| 224 | + tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], |
| 225 | + ] = weakref.WeakKeyDictionary() |
| 226 | + |
| 227 | + def get_resolver( |
| 228 | + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| 229 | + ) -> "aiodns.DNSResolver": |
| 230 | + """Get or create the shared aiodns.DNSResolver instance for a specific event loop. |
| 231 | +
|
| 232 | + Args: |
| 233 | + client: The AsyncResolver instance requesting the resolver. |
| 234 | + This is required to track resolver usage. |
| 235 | + loop: The event loop to use for the resolver. |
| 236 | + """ |
| 237 | + # Create a new resolver and client set for this loop if it doesn't exist |
| 238 | + if loop not in self._loop_data: |
| 239 | + resolver = aiodns.DNSResolver(loop=loop) |
| 240 | + client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() |
| 241 | + self._loop_data[loop] = (resolver, client_set) |
| 242 | + else: |
| 243 | + # Get the existing resolver and client set |
| 244 | + resolver, client_set = self._loop_data[loop] |
| 245 | + |
| 246 | + # Register this client with the loop |
| 247 | + client_set.add(client) |
| 248 | + return resolver |
| 249 | + |
| 250 | + def release_resolver( |
| 251 | + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| 252 | + ) -> None: |
| 253 | + """Release the resolver for an AsyncResolver client when it's closed. |
| 254 | +
|
| 255 | + Args: |
| 256 | + client: The AsyncResolver instance to release. |
| 257 | + loop: The event loop the resolver was using. |
| 258 | + """ |
| 259 | + # Remove client from its loop's tracking |
| 260 | + resolver, client_set = self._loop_data[loop] |
| 261 | + client_set.discard(client) |
| 262 | + # If no more clients for this loop, cancel and remove its resolver |
| 263 | + if not client_set: |
| 264 | + resolver.cancel() |
| 265 | + del self._loop_data[loop] |
184 | 266 |
|
185 | 267 |
|
186 | 268 | _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
|
|
0 commit comments