@@ -226,7 +226,7 @@ class BaseClient:
226226class AsyncDstackClient (BaseClient ):
227227 PATH_PREFIX = "/"
228228
229- def __init__ (self , endpoint : str | None = None , use_sync_http : bool = False ):
229+ def __init__ (self , endpoint : str | None = None , * , use_sync_http : bool = False , timeout : float = 3 ):
230230 """Initialize async client with HTTP or Unix-socket transport.
231231
232232 Args:
@@ -239,6 +239,7 @@ def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
239239 self ._client : Optional [httpx .AsyncClient ] = None
240240 self ._sync_client : Optional [httpx .Client ] = None
241241 self ._client_ref_count = 0
242+ self ._timeout = timeout
242243
243244 if endpoint .startswith ("http://" ) or endpoint .startswith ("https://" ):
244245 self .async_transport = httpx .AsyncHTTPTransport ()
@@ -255,14 +256,14 @@ def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
255256 def _get_client (self ) -> httpx .AsyncClient :
256257 if self ._client is None :
257258 self ._client = httpx .AsyncClient (
258- transport = self .async_transport , base_url = self .base_url , timeout = 0.5
259+ transport = self .async_transport , base_url = self .base_url , timeout = self . _timeout
259260 )
260261 return self ._client
261262
262263 def _get_sync_client (self ) -> httpx .Client :
263264 if self ._sync_client is None :
264265 self ._sync_client = httpx .Client (
265- transport = self .sync_transport , base_url = self .base_url , timeout = 0.5
266+ transport = self .sync_transport , base_url = self .base_url , timeout = self . _timeout
266267 )
267268 return self ._sync_client
268269
@@ -392,13 +393,13 @@ async def is_reachable(self) -> bool:
392393class DstackClient (BaseClient ):
393394 PATH_PREFIX = "/"
394395
395- def __init__ (self , endpoint : str | None = None ):
396+ def __init__ (self , endpoint : str | None = None , * , timeout : float = 3 ):
396397 """Initialize client with HTTP or Unix-socket transport.
397398
398399 If a non-HTTP(S) endpoint is provided, it is treated as a Unix socket
399400 path and validated for existence.
400401 """
401- self .async_client = AsyncDstackClient (endpoint , use_sync_http = True )
402+ self .async_client = AsyncDstackClient (endpoint , use_sync_http = True , timeout = timeout )
402403
403404 @call_async
404405 def get_key (
@@ -463,7 +464,7 @@ class AsyncTappdClient(AsyncDstackClient):
463464 DEPRECATED: Use ``AsyncDstackClient`` instead.
464465 """
465466
466- def __init__ (self , endpoint : str | None = None , use_sync_http : bool = False ):
467+ def __init__ (self , endpoint : str | None = None , * , use_sync_http : bool = False , timeout : float = 3 ):
467468 """Initialize deprecated async tappd client wrapper."""
468469 if not use_sync_http :
469470 # Already warned in TappdClient.__init__
@@ -472,7 +473,7 @@ def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
472473 )
473474
474475 endpoint = get_tappd_endpoint (endpoint )
475- super ().__init__ (endpoint , use_sync_http = use_sync_http )
476+ super ().__init__ (endpoint , use_sync_http = use_sync_http , timeout = timeout )
476477 # Set the correct path prefix for tappd
477478 self .PATH_PREFIX = "/prpc/Tappd."
478479
@@ -542,13 +543,13 @@ class TappdClient(DstackClient):
542543 DEPRECATED: Use ``DstackClient`` instead.
543544 """
544545
545- def __init__ (self , endpoint : str | None = None ):
546+ def __init__ (self , endpoint : str | None = None , timeout : float = 3 ):
546547 """Initialize deprecated tappd client wrapper."""
547548 emit_deprecation_warning (
548549 "TappdClient is deprecated, please use DstackClient instead"
549550 )
550551 endpoint = get_tappd_endpoint (endpoint )
551- self .async_client = AsyncTappdClient (endpoint , use_sync_http = True )
552+ self .async_client = AsyncTappdClient (endpoint , use_sync_http = True , timeout = timeout )
552553
553554 @call_async
554555 def derive_key (
0 commit comments