Skip to content

Commit ee9287c

Browse files
authored
Merge pull request #317 from Dstack-TEE/imp-sdk-tcb-info
IMP: Typing & schema for TcbInfo in SDK
2 parents b6baa52 + 9e2f060 commit ee9287c

File tree

4 files changed

+68
-19
lines changed

4 files changed

+68
-19
lines changed

sdk/js/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@phala/dstack-sdk",
3-
"version": "0.5.4",
3+
"version": "0.5.5",
44
"description": "dstack SDK",
55
"main": "dist/node/index.js",
66
"types": "dist/node/index.d.ts",

sdk/js/src/index.ts

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,26 @@ export interface TcbInfo {
4444
rtmr1: string
4545
rtmr2: string
4646
rtmr3: string
47+
app_compose: string
4748
event_log: EventLog[]
4849
}
4950

50-
export interface InfoResponse {
51+
export type TcbInfoV03x = TcbInfo & {
52+
rootfs_hash: string
53+
}
54+
55+
export type TcbInfoV05x = TcbInfo & {
56+
mr_aggregated: string
57+
os_image_hash: string
58+
compose_hash: string
59+
device_id: string
60+
}
61+
62+
export interface InfoResponse<VersionTcbInfo extends TcbInfo> {
5163
app_id: string
5264
instance_id: string
5365
app_cert: string
54-
tcb_info: TcbInfo
66+
tcb_info: VersionTcbInfo
5567
app_name: string
5668
device_id: string
5769
os_image_hash?: string // Optional: empty if OS image is not measured by KMS
@@ -132,7 +144,7 @@ export interface TlsKeyOptions {
132144
usageClientAuth?: boolean;
133145
}
134146

135-
export class DstackClient {
147+
export class DstackClient<T extends TcbInfo = TcbInfoV05x> {
136148
protected endpoint: string
137149

138150
constructor(endpoint: string | undefined = undefined) {
@@ -210,11 +222,11 @@ export class DstackClient {
210222
return Object.freeze(result)
211223
}
212224

213-
async info(): Promise<InfoResponse> {
214-
const result = await send_rpc_request<Omit<InfoResponse, 'tcb_info'> & { tcb_info: string }>(this.endpoint, '/Info', '{}')
225+
async info(): Promise<InfoResponse<T>> {
226+
const result = await send_rpc_request<Omit<InfoResponse<TcbInfo>, 'tcb_info'> & { tcb_info: string }>(this.endpoint, '/Info', '{}')
215227
return Object.freeze({
216228
...result,
217-
tcb_info: JSON.parse(result.tcb_info) as TcbInfo,
229+
tcb_info: JSON.parse(result.tcb_info) as T,
218230
})
219231
}
220232

@@ -283,7 +295,7 @@ export class DstackClient {
283295
}
284296
}
285297

286-
export class TappdClient extends DstackClient {
298+
export class TappdClient extends DstackClient<TcbInfoV03x> {
287299
constructor(endpoint: string | undefined = undefined) {
288300
if (endpoint === undefined) {
289301
if (process.env.TAPPD_SIMULATOR_ENDPOINT) {

sdk/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
[project]
66
name = "dstack-sdk"
7-
version = "0.5.0"
7+
version = "0.5.1"
88
description = "dstack SDK for Python"
99
authors = [
1010
{name = "Leechael Yim", email = "[email protected]"},

sdk/python/src/dstack_sdk/dstack_client.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import os
1212
from typing import Any
1313
from typing import Dict
14+
from typing import Generic
1415
from typing import List
1516
from typing import Optional
17+
from typing import TypeVar
1618
from typing import cast
1719
import warnings
1820

@@ -157,38 +159,63 @@ class EventLog(BaseModel):
157159

158160

159161
class TcbInfo(BaseModel):
162+
"""Base TCB (Trusted Computing Base) information structure."""
163+
160164
mrtd: str
161165
rtmr0: str
162166
rtmr1: str
163167
rtmr2: str
164168
rtmr3: str
165-
os_image_hash: str = ""
166-
compose_hash: str
167-
device_id: str
168169
app_compose: str
169170
event_log: List[EventLog]
170171

171172

172-
class InfoResponse(BaseModel):
173+
class TcbInfoV03x(TcbInfo):
174+
"""TCB information for dstack OS version 0.3.x."""
175+
176+
rootfs_hash: str
177+
178+
179+
class TcbInfoV05x(TcbInfo):
180+
"""TCB information for dstack OS version 0.5.x."""
181+
182+
mr_aggregated: str
183+
os_image_hash: str
184+
compose_hash: str
185+
device_id: str
186+
187+
188+
# Type variable for TCB info versions
189+
T = TypeVar("T", bound=TcbInfo)
190+
191+
192+
class InfoResponse(BaseModel, Generic[T]):
173193
app_id: str
174194
instance_id: str
175195
app_cert: str
176-
tcb_info: TcbInfo
196+
tcb_info: T
177197
app_name: str
178198
device_id: str
179199
os_image_hash: str = ""
180200
key_provider_info: str
181201
compose_hash: str
182202

183203
@classmethod
184-
def parse_response(cls, obj: Any) -> "InfoResponse":
204+
def parse_response(cls, obj: Any, tcb_info_type: type[T]) -> "InfoResponse[T]":
205+
"""Parse response from service, automatically deserializing tcb_info.
206+
207+
Args:
208+
obj: Raw response object from service
209+
tcb_info_type: The specific TcbInfo subclass to use for parsing
210+
211+
"""
185212
if (
186213
isinstance(obj, dict)
187214
and "tcb_info" in obj
188215
and isinstance(obj["tcb_info"], str)
189216
):
190217
obj = dict(obj)
191-
obj["tcb_info"] = TcbInfo(**json.loads(obj["tcb_info"]))
218+
obj["tcb_info"] = tcb_info_type(**json.loads(obj["tcb_info"]))
192219
return cls(**obj)
193220

194221

@@ -311,10 +338,10 @@ async def get_quote(
311338
result = await self._send_rpc_request("GetQuote", {"report_data": hex})
312339
return GetQuoteResponse(**result)
313340

314-
async def info(self) -> InfoResponse:
341+
async def info(self) -> InfoResponse[TcbInfo]:
315342
"""Fetch service information including parsed TCB info."""
316343
result = await self._send_rpc_request("Info", {})
317-
return InfoResponse.parse_response(result)
344+
return InfoResponse.parse_response(result, TcbInfoV05x)
318345

319346
async def emit_event(
320347
self,
@@ -391,7 +418,7 @@ def get_quote(
391418
raise NotImplementedError
392419

393420
@call_async
394-
def info(self) -> InfoResponse:
421+
def info(self) -> InfoResponse[TcbInfo]:
395422
"""Fetch service information including parsed TCB info."""
396423
raise NotImplementedError
397424

@@ -503,6 +530,11 @@ async def tdx_quote(
503530

504531
return GetQuoteResponse(**result)
505532

533+
async def info(self) -> InfoResponse[TcbInfo]:
534+
"""Fetch service information including parsed TCB info."""
535+
result = await self._send_rpc_request("Info", {})
536+
return InfoResponse.parse_response(result, TcbInfoV03x)
537+
506538

507539
class TappdClient(DstackClient):
508540
"""Deprecated client kept for backward compatibility.
@@ -537,6 +569,11 @@ def tdx_quote(
537569
"""Use ``get_quote`` instead (deprecated)."""
538570
raise NotImplementedError
539571

572+
@call_async
573+
def info(self) -> InfoResponse[TcbInfo]:
574+
"""Fetch service information including parsed TCB info."""
575+
raise NotImplementedError
576+
540577
@call_async
541578
def __enter__(self):
542579
raise NotImplementedError

0 commit comments

Comments
 (0)