|
11 | 11 | import os |
12 | 12 | from typing import Any |
13 | 13 | from typing import Dict |
| 14 | +from typing import Generic |
14 | 15 | from typing import List |
15 | 16 | from typing import Optional |
| 17 | +from typing import TypeVar |
16 | 18 | from typing import cast |
17 | 19 | import warnings |
18 | 20 |
|
@@ -157,38 +159,63 @@ class EventLog(BaseModel): |
157 | 159 |
|
158 | 160 |
|
159 | 161 | class TcbInfo(BaseModel): |
| 162 | + """Base TCB (Trusted Computing Base) information structure.""" |
| 163 | + |
160 | 164 | mrtd: str |
161 | 165 | rtmr0: str |
162 | 166 | rtmr1: str |
163 | 167 | rtmr2: str |
164 | 168 | rtmr3: str |
165 | | - os_image_hash: str = "" |
166 | | - compose_hash: str |
167 | | - device_id: str |
168 | 169 | app_compose: str |
169 | 170 | event_log: List[EventLog] |
170 | 171 |
|
171 | 172 |
|
172 | | -class InfoResponse(BaseModel): |
| 173 | +class TcbInfoV03x(TcbInfo): |
| 174 | + """TCB information for dstack OS version 0.3.x.""" |
| 175 | + |
| 176 | + rootfs_hash: str |
| 177 | + |
| 178 | + |
| 179 | +class TcbInfoV05x(TcbInfo): |
| 180 | + """TCB information for dstack OS version 0.5.x.""" |
| 181 | + |
| 182 | + mr_aggregated: str |
| 183 | + os_image_hash: str |
| 184 | + compose_hash: str |
| 185 | + device_id: str |
| 186 | + |
| 187 | + |
| 188 | +# Type variable for TCB info versions |
| 189 | +T = TypeVar("T", bound=TcbInfo) |
| 190 | + |
| 191 | + |
| 192 | +class InfoResponse(BaseModel, Generic[T]): |
173 | 193 | app_id: str |
174 | 194 | instance_id: str |
175 | 195 | app_cert: str |
176 | | - tcb_info: TcbInfo |
| 196 | + tcb_info: T |
177 | 197 | app_name: str |
178 | 198 | device_id: str |
179 | 199 | os_image_hash: str = "" |
180 | 200 | key_provider_info: str |
181 | 201 | compose_hash: str |
182 | 202 |
|
183 | 203 | @classmethod |
184 | | - def parse_response(cls, obj: Any) -> "InfoResponse": |
| 204 | + def parse_response(cls, obj: Any, tcb_info_type: type[T]) -> "InfoResponse[T]": |
| 205 | + """Parse response from service, automatically deserializing tcb_info. |
| 206 | +
|
| 207 | + Args: |
| 208 | + obj: Raw response object from service |
| 209 | + tcb_info_type: The specific TcbInfo subclass to use for parsing |
| 210 | +
|
| 211 | + """ |
185 | 212 | if ( |
186 | 213 | isinstance(obj, dict) |
187 | 214 | and "tcb_info" in obj |
188 | 215 | and isinstance(obj["tcb_info"], str) |
189 | 216 | ): |
190 | 217 | obj = dict(obj) |
191 | | - obj["tcb_info"] = TcbInfo(**json.loads(obj["tcb_info"])) |
| 218 | + obj["tcb_info"] = tcb_info_type(**json.loads(obj["tcb_info"])) |
192 | 219 | return cls(**obj) |
193 | 220 |
|
194 | 221 |
|
@@ -311,10 +338,10 @@ async def get_quote( |
311 | 338 | result = await self._send_rpc_request("GetQuote", {"report_data": hex}) |
312 | 339 | return GetQuoteResponse(**result) |
313 | 340 |
|
314 | | - async def info(self) -> InfoResponse: |
| 341 | + async def info(self) -> InfoResponse[TcbInfo]: |
315 | 342 | """Fetch service information including parsed TCB info.""" |
316 | 343 | result = await self._send_rpc_request("Info", {}) |
317 | | - return InfoResponse.parse_response(result) |
| 344 | + return InfoResponse.parse_response(result, TcbInfoV05x) |
318 | 345 |
|
319 | 346 | async def emit_event( |
320 | 347 | self, |
@@ -391,7 +418,7 @@ def get_quote( |
391 | 418 | raise NotImplementedError |
392 | 419 |
|
393 | 420 | @call_async |
394 | | - def info(self) -> InfoResponse: |
| 421 | + def info(self) -> InfoResponse[TcbInfo]: |
395 | 422 | """Fetch service information including parsed TCB info.""" |
396 | 423 | raise NotImplementedError |
397 | 424 |
|
@@ -503,6 +530,11 @@ async def tdx_quote( |
503 | 530 |
|
504 | 531 | return GetQuoteResponse(**result) |
505 | 532 |
|
| 533 | + async def info(self) -> InfoResponse[TcbInfo]: |
| 534 | + """Fetch service information including parsed TCB info.""" |
| 535 | + result = await self._send_rpc_request("Info", {}) |
| 536 | + return InfoResponse.parse_response(result, TcbInfoV03x) |
| 537 | + |
506 | 538 |
|
507 | 539 | class TappdClient(DstackClient): |
508 | 540 | """Deprecated client kept for backward compatibility. |
@@ -537,6 +569,11 @@ def tdx_quote( |
537 | 569 | """Use ``get_quote`` instead (deprecated).""" |
538 | 570 | raise NotImplementedError |
539 | 571 |
|
| 572 | + @call_async |
| 573 | + def info(self) -> InfoResponse[TcbInfo]: |
| 574 | + """Fetch service information including parsed TCB info.""" |
| 575 | + raise NotImplementedError |
| 576 | + |
540 | 577 | @call_async |
541 | 578 | def __enter__(self): |
542 | 579 | raise NotImplementedError |
|
0 commit comments