Skip to content

Commit ddfd3c3

Browse files
authored
Merge pull request #339 from Dstack-TEE/fix-sdks
sdks: set roofs_hash optional in TcbInfo & increase default timeout in Python SDK
2 parents dbe916c + 70ce94f commit ddfd3c3

File tree

5 files changed

+37
-15
lines changed

5 files changed

+37
-15
lines changed

sdk/js/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@phala/dstack-sdk",
3-
"version": "0.5.5",
3+
"version": "0.5.6",
44
"description": "dstack SDK",
55
"main": "dist/node/index.js",
66
"types": "dist/node/index.d.ts",

sdk/js/src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ export interface TcbInfo {
4949
}
5050

5151
export type TcbInfoV03x = TcbInfo & {
52-
rootfs_hash: string
52+
rootfs_hash?: string
5353
}
5454

5555
export type TcbInfoV05x = TcbInfo & {

sdk/js/src/send-rpc-request.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import http from 'http'
66
import https from 'https'
77
import net from 'net'
88

9-
export const __version__ = "0.5.0"
9+
export const __version__ = "0.5.6"
1010

1111

1212
export function send_rpc_request<T = any>(endpoint: string, path: string, payload: string, timeoutMs?: number): Promise<T> {

sdk/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
[project]
66
name = "dstack-sdk"
7-
version = "0.5.1"
7+
version = "0.5.2"
88
description = "dstack SDK for Python"
99
authors = [
1010
{name = "Leechael Yim", email = "[email protected]"},

sdk/python/src/dstack_sdk/dstack_client.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
logger = logging.getLogger("dstack_sdk")
2525

26-
__version__ = "0.2.0"
26+
__version__ = "0.5.2"
2727

2828

2929
INIT_MR = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
@@ -173,7 +173,7 @@ class TcbInfo(BaseModel):
173173
class TcbInfoV03x(TcbInfo):
174174
"""TCB information for dstack OS version 0.3.x."""
175175

176-
rootfs_hash: str
176+
rootfs_hash: Optional[str] = None
177177

178178

179179
class TcbInfoV05x(TcbInfo):
@@ -226,19 +226,27 @@ class BaseClient:
226226
class AsyncDstackClient(BaseClient):
227227
PATH_PREFIX = "/"
228228

229-
def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
229+
def __init__(
230+
self,
231+
endpoint: str | None = None,
232+
*,
233+
use_sync_http: bool = False,
234+
timeout: float = 3,
235+
):
230236
"""Initialize async client with HTTP or Unix-socket transport.
231237
232238
Args:
233239
endpoint: HTTP/HTTPS URL or Unix socket path
234240
use_sync_http: If True, use sync HTTP client internally
241+
timeout: Timeout in seconds
235242
236243
"""
237244
endpoint = get_endpoint(endpoint)
238245
self.use_sync_http = use_sync_http
239246
self._client: Optional[httpx.AsyncClient] = None
240247
self._sync_client: Optional[httpx.Client] = None
241248
self._client_ref_count = 0
249+
self._timeout = timeout
242250

243251
if endpoint.startswith("http://") or endpoint.startswith("https://"):
244252
self.async_transport = httpx.AsyncHTTPTransport()
@@ -255,14 +263,18 @@ def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
255263
def _get_client(self) -> httpx.AsyncClient:
256264
if self._client is None:
257265
self._client = httpx.AsyncClient(
258-
transport=self.async_transport, base_url=self.base_url, timeout=0.5
266+
transport=self.async_transport,
267+
base_url=self.base_url,
268+
timeout=self._timeout,
259269
)
260270
return self._client
261271

262272
def _get_sync_client(self) -> httpx.Client:
263273
if self._sync_client is None:
264274
self._sync_client = httpx.Client(
265-
transport=self.sync_transport, base_url=self.base_url, timeout=0.5
275+
transport=self.sync_transport,
276+
base_url=self.base_url,
277+
timeout=self._timeout,
266278
)
267279
return self._sync_client
268280

@@ -392,13 +404,15 @@ async def is_reachable(self) -> bool:
392404
class DstackClient(BaseClient):
393405
PATH_PREFIX = "/"
394406

395-
def __init__(self, endpoint: str | None = None):
407+
def __init__(self, endpoint: str | None = None, *, timeout: float = 3):
396408
"""Initialize client with HTTP or Unix-socket transport.
397409
398410
If a non-HTTP(S) endpoint is provided, it is treated as a Unix socket
399411
path and validated for existence.
400412
"""
401-
self.async_client = AsyncDstackClient(endpoint, use_sync_http=True)
413+
self.async_client = AsyncDstackClient(
414+
endpoint, use_sync_http=True, timeout=timeout
415+
)
402416

403417
@call_async
404418
def get_key(
@@ -463,7 +477,13 @@ class AsyncTappdClient(AsyncDstackClient):
463477
DEPRECATED: Use ``AsyncDstackClient`` instead.
464478
"""
465479

466-
def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
480+
def __init__(
481+
self,
482+
endpoint: str | None = None,
483+
*,
484+
use_sync_http: bool = False,
485+
timeout: float = 3,
486+
):
467487
"""Initialize deprecated async tappd client wrapper."""
468488
if not use_sync_http:
469489
# Already warned in TappdClient.__init__
@@ -472,7 +492,7 @@ def __init__(self, endpoint: str | None = None, use_sync_http: bool = False):
472492
)
473493

474494
endpoint = get_tappd_endpoint(endpoint)
475-
super().__init__(endpoint, use_sync_http=use_sync_http)
495+
super().__init__(endpoint, use_sync_http=use_sync_http, timeout=timeout)
476496
# Set the correct path prefix for tappd
477497
self.PATH_PREFIX = "/prpc/Tappd."
478498

@@ -542,13 +562,15 @@ class TappdClient(DstackClient):
542562
DEPRECATED: Use ``DstackClient`` instead.
543563
"""
544564

545-
def __init__(self, endpoint: str | None = None):
565+
def __init__(self, endpoint: str | None = None, timeout: float = 3):
546566
"""Initialize deprecated tappd client wrapper."""
547567
emit_deprecation_warning(
548568
"TappdClient is deprecated, please use DstackClient instead"
549569
)
550570
endpoint = get_tappd_endpoint(endpoint)
551-
self.async_client = AsyncTappdClient(endpoint, use_sync_http=True)
571+
self.async_client = AsyncTappdClient(
572+
endpoint, use_sync_http=True, timeout=timeout
573+
)
552574

553575
@call_async
554576
def derive_key(

0 commit comments

Comments
 (0)