Skip to content

Commit a055ecf

Browse files
author
Matthias Zimmermann
committed
feat: add timeout to provider builder
1 parent 6f588fa commit a055ecf

File tree

4 files changed

+247
-35
lines changed

4 files changed

+247
-35
lines changed

src/arkiv/client.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,30 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
247247

248248
async def __aenter__(self) -> AsyncArkiv:
249249
"""Enter async context manager."""
250-
# Initialize pending account if provided
251-
if self._pending_account:
252-
await self._initialize_account_async(self._pending_account)
253-
self._pending_account = None
250+
try:
251+
# Initialize pending account if provided
252+
if self._pending_account:
253+
await self._initialize_account_async(self._pending_account)
254+
self._pending_account = None
255+
256+
# Populate connection cache
257+
await self.is_connected()
258+
return self
259+
except Exception:
260+
# Best-effort cleanup if entering the context fails
261+
logger.debug(
262+
"AsyncArkiv.__aenter__ failed, attempting cleanup before re-raising"
263+
)
264+
try:
265+
await self.arkiv.cleanup_filters()
266+
except Exception:
267+
logger.exception(
268+
"Error while cleaning up filters after __aenter__ failure"
269+
)
254270

255-
# Populate connection cache
256-
await self.is_connected()
257-
return self
271+
await self._disconnect_provider()
272+
self._cached_connected = False
273+
raise
258274

259275
async def __aexit__(
260276
self,
@@ -267,9 +283,8 @@ async def __aexit__(
267283
logger.debug("Cleaning up event filters...")
268284
await self.arkiv.cleanup_filters()
269285

270-
# Disconnect provider if it has disconnect method
271-
if hasattr(self.provider, "disconnect"):
272-
await self.provider.disconnect()
286+
# Disconnect provider and close underlying resources
287+
await self._disconnect_provider()
273288

274289
# Then stop the node if managed and update cache
275290
self._cleanup_node()
@@ -299,6 +314,18 @@ async def _initialize_account_async(self, account: NamedAccount) -> None:
299314
f"Account balance for {account.name} ({account.address}): {balance_eth} ETH"
300315
)
301316

317+
async def _disconnect_provider(self) -> None:
318+
"""Best-effort async disconnect of the underlying provider, if supported."""
319+
provider = self.provider
320+
if provider is None:
321+
return
322+
323+
if hasattr(provider, "disconnect"):
324+
try:
325+
await provider.disconnect() # type: ignore[func-returns-value]
326+
except Exception:
327+
logger.exception("Error while disconnecting async provider")
328+
302329
def _cleanup_node(self) -> None:
303330
"""Cleanup node and update connection cache."""
304331
super()._cleanup_node()

src/arkiv/provider.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self) -> None:
7373
self._transport: TransportType = cast(TransportType, TRANSPORT_DEFAULT)
7474
self._port: int | None = DEFAULT_PORT # Set default port for localhost
7575
self._url: str | None = None
76+
self._timeout_in: int | None = None # timeout in seconds
7677
self._node: ArkivNode | None = None
7778
self._is_async: bool = False # Default to sync providers
7879

@@ -103,7 +104,7 @@ def kaolin(self) -> ProviderBuilder:
103104
self._port = None
104105
return self
105106

106-
def custom(self, url: str) -> ProviderBuilder:
107+
def custom(self, url: str, fallback_url: str | None = None) -> ProviderBuilder:
107108
"""
108109
Configure with custom RPC URL.
109110
@@ -186,6 +187,16 @@ def ws(self) -> ProviderBuilder:
186187
self._transport = cast(TransportType, WS)
187188
return self
188189

190+
def timeout(self, seconds: int) -> ProviderBuilder:
191+
"""
192+
Sets the request timeout for the provider.
193+
194+
Args:
195+
seconds: Timeout duration in seconds
196+
"""
197+
self._timeout_in = seconds
198+
return self
199+
189200
def async_mode(self, async_provider: bool = True) -> ProviderBuilder:
190201
"""
191202
Sets the async provider mode.
@@ -276,9 +287,32 @@ def build(self) -> BaseProvider | AsyncBaseProvider:
276287
if self._transport == HTTP:
277288
# Consider async mode
278289
if self._is_async:
279-
return AsyncHTTPProvider(url)
290+
if self._timeout_in is not None:
291+
import aiohttp
292+
293+
timeout = aiohttp.ClientTimeout(total=self._timeout_in)
294+
return AsyncHTTPProvider(url, request_kwargs={"timeout": timeout})
295+
else:
296+
return AsyncHTTPProvider(url)
280297
else:
281-
return HTTPProvider(url)
298+
if self._timeout_in is not None:
299+
return HTTPProvider(
300+
url, request_kwargs={"timeout": self._timeout_in}
301+
)
302+
else:
303+
return HTTPProvider(url)
282304
# Web socket transport (always async)
283305
else:
306+
if self._timeout_in is not None:
307+
return cast(
308+
AsyncBaseProvider,
309+
WebSocketProvider(
310+
url,
311+
request_timeout=self._timeout_in,
312+
# websocket_kwargs={
313+
# "ping_interval": self._timeout_in,
314+
# "ping_timeout": self._timeout_in * 2,
315+
# },
316+
),
317+
)
284318
return cast(AsyncBaseProvider, WebSocketProvider(url))

tests/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
Provides either external node connections (via env vars) or containerized test nodes.
55
"""
66

7+
import json
78
import logging
89
import os
10+
import socketserver
11+
import threading
12+
import time
913
from collections.abc import AsyncGenerator, Generator
1014
from pathlib import Path
1115

@@ -27,6 +31,10 @@
2731
WALLET_FILE_ENV_PREFIX = "WALLET_FILE"
2832
WALLET_PASSWORD_ENV_PREFIX = "WALLET_PASSWORD"
2933

34+
SLOW_LOCAL_SERVER_HOST = "127.0.0.1"
35+
SLOW_LOCAL_SERVER_PORT = 9876
36+
SLOW_LOCAL_SERVER_TIMEOUT = 5
37+
3038
ALICE = "alice"
3139
BOB = "bob"
3240

@@ -50,6 +58,40 @@ def _load_env_if_available() -> None:
5058
_load_env_if_available()
5159

5260

61+
@pytest.fixture(scope="session")
62+
def delayed_rpc_server():
63+
class DelayedHandler(socketserver.BaseRequestHandler):
64+
def handle(self):
65+
time.sleep(SLOW_LOCAL_SERVER_TIMEOUT) # Delay > timeout
66+
self.request.sendall(
67+
json.dumps({"jsonrpc": "2.0", "id": 1, "result": 42}).encode()
68+
)
69+
70+
httpd = socketserver.TCPServer(
71+
(SLOW_LOCAL_SERVER_HOST, SLOW_LOCAL_SERVER_PORT), DelayedHandler
72+
)
73+
74+
httpd.allow_reuse_address = True
75+
76+
server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
77+
server_thread.start()
78+
79+
# Wait & verify server is listening
80+
time.sleep(1.0)
81+
import socket
82+
83+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
84+
assert sock.connect_ex((SLOW_LOCAL_SERVER_HOST, SLOW_LOCAL_SERVER_PORT)) == 0, (
85+
"Server not listening"
86+
)
87+
sock.close()
88+
89+
yield f"http://{SLOW_LOCAL_SERVER_HOST}:{SLOW_LOCAL_SERVER_PORT}"
90+
91+
httpd.shutdown()
92+
httpd.server_close()
93+
94+
5395
@pytest.fixture(scope="session")
5496
def arkiv_container() -> Generator[ArkivNode, None, None]:
5597
"""

0 commit comments

Comments
 (0)