|
1 | 1 | import asyncio
|
2 | 2 | import socket
|
3 |
| -from typing import Any, List, Tuple, Type, Union |
| 3 | +import weakref |
| 4 | +from typing import Any, List, Optional, Tuple, Type, Union |
4 | 5 |
|
5 | 6 | from .abc import AbstractResolver, ResolveResult
|
6 | 7 |
|
@@ -88,7 +89,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
|
88 | 89 | if aiodns is None:
|
89 | 90 | raise RuntimeError("Resolver requires aiodns library")
|
90 | 91 |
|
91 |
| - self._resolver = aiodns.DNSResolver(*args, **kwargs) |
| 92 | + self._loop = asyncio.get_running_loop() |
| 93 | + self._manager: Optional[_DNSResolverManager] = None |
| 94 | + # If custom args are provided, create a dedicated resolver instance |
| 95 | + # This means each AsyncResolver with custom args gets its own |
| 96 | + # aiodns.DNSResolver instance |
| 97 | + if args or kwargs: |
| 98 | + self._resolver = aiodns.DNSResolver(*args, **kwargs) |
| 99 | + return |
| 100 | + # Use the shared resolver from the manager for default arguments |
| 101 | + self._manager = _DNSResolverManager() |
| 102 | + self._resolver = self._manager.get_resolver(self, self._loop) |
92 | 103 |
|
93 | 104 | async def resolve(
|
94 | 105 | self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
|
@@ -142,7 +153,78 @@ async def resolve(
|
142 | 153 | return hosts
|
143 | 154 |
|
144 | 155 | async def close(self) -> None:
|
| 156 | + if self._manager: |
| 157 | + # Release the resolver from the manager if using the shared resolver |
| 158 | + self._manager.release_resolver(self, self._loop) |
| 159 | + self._manager = None # Clear reference to manager |
| 160 | + self._resolver = None # type: ignore[assignment] # Clear reference to resolver |
| 161 | + return |
| 162 | + # Otherwise cancel our dedicated resolver |
145 | 163 | self._resolver.cancel()
|
| 164 | + self._resolver = None # type: ignore[assignment] # Clear reference |
| 165 | + |
| 166 | + |
| 167 | +class _DNSResolverManager: |
| 168 | + """Manager for aiodns.DNSResolver objects. |
| 169 | +
|
| 170 | + This class manages shared aiodns.DNSResolver instances |
| 171 | + with no custom arguments across different event loops. |
| 172 | + """ |
| 173 | + |
| 174 | + _instance: Optional["_DNSResolverManager"] = None |
| 175 | + |
| 176 | + def __new__(cls) -> "_DNSResolverManager": |
| 177 | + if cls._instance is None: |
| 178 | + cls._instance = super().__new__(cls) |
| 179 | + cls._instance._init() |
| 180 | + return cls._instance |
| 181 | + |
| 182 | + def _init(self) -> None: |
| 183 | + # Use WeakKeyDictionary to allow event loops to be garbage collected |
| 184 | + self._loop_data: weakref.WeakKeyDictionary[ |
| 185 | + asyncio.AbstractEventLoop, |
| 186 | + tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], |
| 187 | + ] = weakref.WeakKeyDictionary() |
| 188 | + |
| 189 | + def get_resolver( |
| 190 | + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| 191 | + ) -> "aiodns.DNSResolver": |
| 192 | + """Get or create the shared aiodns.DNSResolver instance for a specific event loop. |
| 193 | +
|
| 194 | + Args: |
| 195 | + client: The AsyncResolver instance requesting the resolver. |
| 196 | + This is required to track resolver usage. |
| 197 | + loop: The event loop to use for the resolver. |
| 198 | + """ |
| 199 | + # Create a new resolver and client set for this loop if it doesn't exist |
| 200 | + if loop not in self._loop_data: |
| 201 | + resolver = aiodns.DNSResolver(loop=loop) |
| 202 | + client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() |
| 203 | + self._loop_data[loop] = (resolver, client_set) |
| 204 | + else: |
| 205 | + # Get the existing resolver and client set |
| 206 | + resolver, client_set = self._loop_data[loop] |
| 207 | + |
| 208 | + # Register this client with the loop |
| 209 | + client_set.add(client) |
| 210 | + return resolver |
| 211 | + |
| 212 | + def release_resolver( |
| 213 | + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| 214 | + ) -> None: |
| 215 | + """Release the resolver for an AsyncResolver client when it's closed. |
| 216 | +
|
| 217 | + Args: |
| 218 | + client: The AsyncResolver instance to release. |
| 219 | + loop: The event loop the resolver was using. |
| 220 | + """ |
| 221 | + # Remove client from its loop's tracking |
| 222 | + resolver, client_set = self._loop_data[loop] |
| 223 | + client_set.discard(client) |
| 224 | + # If no more clients for this loop, cancel and remove its resolver |
| 225 | + if not client_set: |
| 226 | + resolver.cancel() |
| 227 | + del self._loop_data[loop] |
146 | 228 |
|
147 | 229 |
|
148 | 230 | _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
|
|
0 commit comments