Skip to content

Commit caddb8a

Browse files
committed
imp(sdk/python): Schema for TcbInfo
1 parent c769b2d commit caddb8a

File tree

1 file changed

+47
-10
lines changed

1 file changed

+47
-10
lines changed

sdk/python/src/dstack_sdk/dstack_client.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import os
1212
from typing import Any
1313
from typing import Dict
14+
from typing import Generic
1415
from typing import List
1516
from typing import Optional
17+
from typing import TypeVar
1618
from typing import cast
1719
import warnings
1820

@@ -157,38 +159,63 @@ class EventLog(BaseModel):
157159

158160

159161
class TcbInfo(BaseModel):
162+
"""Base TCB (Trusted Computing Base) information structure."""
163+
160164
mrtd: str
161165
rtmr0: str
162166
rtmr1: str
163167
rtmr2: str
164168
rtmr3: str
165-
os_image_hash: str = ""
166-
compose_hash: str
167-
device_id: str
168169
app_compose: str
169170
event_log: List[EventLog]
170171

171172

172-
class InfoResponse(BaseModel):
173+
class TcbInfoV03x(TcbInfo):
174+
"""TCB information for dstack OS version 0.3.x."""
175+
176+
rootfs_hash: str
177+
178+
179+
class TcbInfoV05x(TcbInfo):
180+
"""TCB information for dstack OS version 0.5.x."""
181+
182+
mr_aggregated: str
183+
os_image_hash: str
184+
compose_hash: str
185+
device_id: str
186+
187+
188+
# Type variable for TCB info versions
189+
T = TypeVar("T", bound=TcbInfo)
190+
191+
192+
class InfoResponse(BaseModel, Generic[T]):
173193
app_id: str
174194
instance_id: str
175195
app_cert: str
176-
tcb_info: TcbInfo
196+
tcb_info: T
177197
app_name: str
178198
device_id: str
179199
os_image_hash: str = ""
180200
key_provider_info: str
181201
compose_hash: str
182202

183203
@classmethod
184-
def parse_response(cls, obj: Any) -> "InfoResponse":
204+
def parse_response(cls, obj: Any, tcb_info_type: type[T]) -> "InfoResponse[T]":
205+
"""Parse response from service, automatically deserializing tcb_info.
206+
207+
Args:
208+
obj: Raw response object from service
209+
tcb_info_type: The specific TcbInfo subclass to use for parsing
210+
211+
"""
185212
if (
186213
isinstance(obj, dict)
187214
and "tcb_info" in obj
188215
and isinstance(obj["tcb_info"], str)
189216
):
190217
obj = dict(obj)
191-
obj["tcb_info"] = TcbInfo(**json.loads(obj["tcb_info"]))
218+
obj["tcb_info"] = tcb_info_type(**json.loads(obj["tcb_info"]))
192219
return cls(**obj)
193220

194221

@@ -311,10 +338,10 @@ async def get_quote(
311338
result = await self._send_rpc_request("GetQuote", {"report_data": hex})
312339
return GetQuoteResponse(**result)
313340

314-
async def info(self) -> InfoResponse:
341+
async def info(self) -> InfoResponse[TcbInfo]:
315342
"""Fetch service information including parsed TCB info."""
316343
result = await self._send_rpc_request("Info", {})
317-
return InfoResponse.parse_response(result)
344+
return InfoResponse.parse_response(result, TcbInfoV05x)
318345

319346
async def emit_event(
320347
self,
@@ -391,7 +418,7 @@ def get_quote(
391418
raise NotImplementedError
392419

393420
@call_async
394-
def info(self) -> InfoResponse:
421+
def info(self) -> InfoResponse[TcbInfo]:
395422
"""Fetch service information including parsed TCB info."""
396423
raise NotImplementedError
397424

@@ -503,6 +530,11 @@ async def tdx_quote(
503530

504531
return GetQuoteResponse(**result)
505532

533+
async def info(self) -> InfoResponse[TcbInfo]:
534+
"""Fetch service information including parsed TCB info."""
535+
result = await self._send_rpc_request("Info", {})
536+
return InfoResponse.parse_response(result, TcbInfoV03x)
537+
506538

507539
class TappdClient(DstackClient):
508540
"""Deprecated client kept for backward compatibility.
@@ -537,6 +569,11 @@ def tdx_quote(
537569
"""Use ``get_quote`` instead (deprecated)."""
538570
raise NotImplementedError
539571

572+
@call_async
573+
def info(self) -> InfoResponse[TcbInfo]:
574+
"""Fetch service information including parsed TCB info."""
575+
raise NotImplementedError
576+
540577
@call_async
541578
def __enter__(self):
542579
raise NotImplementedError

0 commit comments

Comments
 (0)