22//
33// SPDX-License-Identifier: Apache-2.0
44
5- use std:: { ffi:: OsStr , path:: Path , time:: Duration } ;
5+ use std:: {
6+ ffi:: OsStr ,
7+ path:: { Path , PathBuf } ,
8+ time:: Duration ,
9+ } ;
610
7- use anyhow:: { bail, Context , Result } ;
11+ use anyhow:: { anyhow , bail, Context , Result } ;
812use cc_eventlog:: TdxEventLog as EventLog ;
9- use dstack_mr:: RtmrLog ;
13+ use dstack_mr:: { RtmrLog , TdxMeasurementDetails , TdxMeasurements } ;
1014use dstack_types:: VmConfig ;
1115use ra_tls:: attestation:: { Attestation , VerifiedAttestation } ;
16+ use serde:: { Deserialize , Serialize } ;
1217use sha2:: { Digest as _, Sha256 , Sha384 } ;
1318use tokio:: { io:: AsyncWriteExt , process:: Command } ;
14- use tracing:: info;
19+ use tracing:: { debug , info, warn } ;
1520
1621use crate :: types:: {
1722 AcpiTables , RtmrEventEntry , RtmrEventStatus , RtmrMismatch , VerificationDetails ,
@@ -143,6 +148,14 @@ fn collect_rtmr_mismatch(
143148 }
144149}
145150
151+ const MEASUREMENT_CACHE_VERSION : u32 = 1 ;
152+
153+ #[ derive( Clone , Serialize , Deserialize ) ]
154+ struct CachedMeasurement {
155+ version : u32 ,
156+ measurements : TdxMeasurements ,
157+ }
158+
146159pub struct CvmVerifier {
147160 pub image_cache_dir : String ,
148161 pub download_url : String ,
@@ -158,6 +171,178 @@ impl CvmVerifier {
158171 }
159172 }
160173
174+ fn measurement_cache_dir ( & self ) -> PathBuf {
175+ Path :: new ( & self . image_cache_dir ) . join ( "measurements" )
176+ }
177+
178+ fn measurement_cache_path ( & self , cache_key : & str ) -> PathBuf {
179+ self . measurement_cache_dir ( )
180+ . join ( format ! ( "{cache_key}.json" ) )
181+ }
182+
183+ fn vm_config_cache_key ( vm_config : & VmConfig ) -> Result < String > {
184+ let serialized = serde_json:: to_vec ( vm_config)
185+ . context ( "Failed to serialize VM config for cache key computation" ) ?;
186+ Ok ( hex:: encode ( Sha256 :: digest ( & serialized) ) )
187+ }
188+
189+ fn load_measurements_from_cache ( & self , cache_key : & str ) -> Result < Option < TdxMeasurements > > {
190+ let path = self . measurement_cache_path ( cache_key) ;
191+ if !path. exists ( ) {
192+ return Ok ( None ) ;
193+ }
194+
195+ let path_display = path. display ( ) . to_string ( ) ;
196+ let contents = match fs_err:: read ( & path) {
197+ Ok ( data) => data,
198+ Err ( e) => {
199+ warn ! ( "Failed to read measurement cache {}: {e:?}" , path_display) ;
200+ return Ok ( None ) ;
201+ }
202+ } ;
203+
204+ let cached: CachedMeasurement = match serde_json:: from_slice ( & contents) {
205+ Ok ( entry) => entry,
206+ Err ( e) => {
207+ warn ! ( "Failed to parse measurement cache {}: {e:?}" , path_display) ;
208+ return Ok ( None ) ;
209+ }
210+ } ;
211+
212+ if cached. version != MEASUREMENT_CACHE_VERSION {
213+ debug ! (
214+ "Ignoring measurement cache {} due to version mismatch (found {}, expected {})" ,
215+ path_display, cached. version, MEASUREMENT_CACHE_VERSION
216+ ) ;
217+ return Ok ( None ) ;
218+ }
219+
220+ debug ! ( "Loaded measurement cache entry {}" , cache_key) ;
221+ Ok ( Some ( cached. measurements ) )
222+ }
223+
224+ fn store_measurements_in_cache (
225+ & self ,
226+ cache_key : & str ,
227+ measurements : & TdxMeasurements ,
228+ ) -> Result < ( ) > {
229+ let cache_dir = self . measurement_cache_dir ( ) ;
230+ fs_err:: create_dir_all ( & cache_dir)
231+ . context ( "Failed to create measurement cache directory" ) ?;
232+
233+ let path = self . measurement_cache_path ( cache_key) ;
234+ let mut tmp = tempfile:: NamedTempFile :: new_in ( & cache_dir)
235+ . context ( "Failed to create temporary cache file" ) ?;
236+
237+ let entry = CachedMeasurement {
238+ version : MEASUREMENT_CACHE_VERSION ,
239+ measurements : measurements. clone ( ) ,
240+ } ;
241+ serde_json:: to_writer ( tmp. as_file_mut ( ) , & entry)
242+ . context ( "Failed to serialize measurement cache entry" ) ?;
243+ tmp. as_file_mut ( )
244+ . sync_all ( )
245+ . context ( "Failed to flush measurement cache entry to disk" ) ?;
246+
247+ tmp. persist ( & path) . map_err ( |e| {
248+ anyhow ! (
249+ "Failed to persist measurement cache to {}: {e}" ,
250+ path. display( )
251+ )
252+ } ) ?;
253+ debug ! ( "Stored measurement cache entry {}" , cache_key) ;
254+ Ok ( ( ) )
255+ }
256+
257+ fn compute_measurement_details (
258+ & self ,
259+ vm_config : & VmConfig ,
260+ fw_path : & Path ,
261+ kernel_path : & Path ,
262+ initrd_path : & Path ,
263+ kernel_cmdline : & str ,
264+ ) -> Result < TdxMeasurementDetails > {
265+ let firmware = fw_path. display ( ) . to_string ( ) ;
266+ let kernel = kernel_path. display ( ) . to_string ( ) ;
267+ let initrd = initrd_path. display ( ) . to_string ( ) ;
268+
269+ let details = dstack_mr:: Machine :: builder ( )
270+ . cpu_count ( vm_config. cpu_count )
271+ . memory_size ( vm_config. memory_size )
272+ . firmware ( & firmware)
273+ . kernel ( & kernel)
274+ . initrd ( & initrd)
275+ . kernel_cmdline ( kernel_cmdline)
276+ . root_verity ( true )
277+ . hotplug_off ( vm_config. hotplug_off )
278+ . maybe_two_pass_add_pages ( vm_config. qemu_single_pass_add_pages )
279+ . maybe_pic ( vm_config. pic )
280+ . maybe_qemu_version ( vm_config. qemu_version . clone ( ) )
281+ . maybe_pci_hole64_size ( if vm_config. pci_hole64_size > 0 {
282+ Some ( vm_config. pci_hole64_size )
283+ } else {
284+ None
285+ } )
286+ . hugepages ( vm_config. hugepages )
287+ . num_gpus ( vm_config. num_gpus )
288+ . num_nvswitches ( vm_config. num_nvswitches )
289+ . build ( )
290+ . measure_with_logs ( )
291+ . context ( "Failed to compute expected MRs" ) ?;
292+
293+ Ok ( details)
294+ }
295+
296+ fn compute_measurements (
297+ & self ,
298+ vm_config : & VmConfig ,
299+ fw_path : & Path ,
300+ kernel_path : & Path ,
301+ initrd_path : & Path ,
302+ kernel_cmdline : & str ,
303+ ) -> Result < TdxMeasurements > {
304+ self . compute_measurement_details (
305+ vm_config,
306+ fw_path,
307+ kernel_path,
308+ initrd_path,
309+ kernel_cmdline,
310+ )
311+ . map ( |details| details. measurements )
312+ }
313+
314+ fn load_or_compute_measurements (
315+ & self ,
316+ vm_config : & VmConfig ,
317+ fw_path : & Path ,
318+ kernel_path : & Path ,
319+ initrd_path : & Path ,
320+ kernel_cmdline : & str ,
321+ ) -> Result < TdxMeasurements > {
322+ let cache_key = Self :: vm_config_cache_key ( vm_config) ?;
323+
324+ if let Some ( measurements) = self . load_measurements_from_cache ( & cache_key) ? {
325+ return Ok ( measurements) ;
326+ }
327+
328+ let measurements = self . compute_measurements (
329+ vm_config,
330+ fw_path,
331+ kernel_path,
332+ initrd_path,
333+ kernel_cmdline,
334+ ) ?;
335+
336+ if let Err ( e) = self . store_measurements_in_cache ( & cache_key, & measurements) {
337+ warn ! (
338+ "Failed to write measurement cache entry for {}: {e:?}" ,
339+ cache_key
340+ ) ;
341+ }
342+
343+ Ok ( measurements)
344+ }
345+
161346 pub async fn verify ( & self , request : & VerificationRequest ) -> Result < VerificationResponse > {
162347 let quote = hex:: decode ( & request. quote ) . context ( "Failed to decode quote hex" ) ?;
163348
@@ -305,39 +490,41 @@ impl CvmVerifier {
305490 let kernel_cmdline = image_info. cmdline + " initrd=initrd" ;
306491
307492 // Use dstack-mr to compute expected MRs
308- let measurement_details = dstack_mr:: Machine :: builder ( )
309- . cpu_count ( vm_config. cpu_count )
310- . memory_size ( vm_config. memory_size )
311- . firmware ( & fw_path. display ( ) . to_string ( ) )
312- . kernel ( & kernel_path. display ( ) . to_string ( ) )
313- . initrd ( & initrd_path. display ( ) . to_string ( ) )
314- . kernel_cmdline ( & kernel_cmdline)
315- . root_verity ( true )
316- . hotplug_off ( vm_config. hotplug_off )
317- . maybe_two_pass_add_pages ( vm_config. qemu_single_pass_add_pages )
318- . maybe_pic ( vm_config. pic )
319- . maybe_qemu_version ( vm_config. qemu_version . clone ( ) )
320- . maybe_pci_hole64_size ( if vm_config. pci_hole64_size > 0 {
321- Some ( vm_config. pci_hole64_size )
322- } else {
323- None
324- } )
325- . hugepages ( vm_config. hugepages )
326- . num_gpus ( vm_config. num_gpus )
327- . num_nvswitches ( vm_config. num_nvswitches )
328- . build ( )
329- . measure_with_logs ( )
330- . context ( "Failed to compute expected MRs" ) ?;
493+ let ( mrs, expected_logs) = if debug {
494+ let TdxMeasurementDetails {
495+ measurements,
496+ rtmr_logs,
497+ acpi_tables,
498+ } = self
499+ . compute_measurement_details (
500+ vm_config,
501+ & fw_path,
502+ & kernel_path,
503+ & initrd_path,
504+ & kernel_cmdline,
505+ )
506+ . context ( "Failed to compute expected measurements" ) ?;
331507
332- let mrs = measurement_details. measurements ;
333- let expected_logs = measurement_details. rtmr_logs ;
334- if debug {
335508 details. acpi_tables = Some ( AcpiTables {
336- tables : hex:: encode ( & measurement_details . acpi_tables . tables ) ,
337- rsdp : hex:: encode ( & measurement_details . acpi_tables . rsdp ) ,
338- loader : hex:: encode ( & measurement_details . acpi_tables . loader ) ,
509+ tables : hex:: encode ( & acpi_tables. tables ) ,
510+ rsdp : hex:: encode ( & acpi_tables. rsdp ) ,
511+ loader : hex:: encode ( & acpi_tables. loader ) ,
339512 } ) ;
340- }
513+
514+ ( measurements, Some ( rtmr_logs) )
515+ } else {
516+ (
517+ self . load_or_compute_measurements (
518+ vm_config,
519+ & fw_path,
520+ & kernel_path,
521+ & initrd_path,
522+ & kernel_cmdline,
523+ )
524+ . context ( "Failed to obtain expected measurements" ) ?,
525+ None ,
526+ )
527+ } ;
341528
342529 let expected_mrs = Mrs {
343530 mrtd : mrs. mrtd . clone ( ) ,
@@ -363,6 +550,9 @@ impl CvmVerifier {
363550 if !debug {
364551 return result;
365552 }
553+ let Some ( expected_logs) = expected_logs. as_ref ( ) else {
554+ return result;
555+ } ;
366556 let mut rtmr_debug = Vec :: new ( ) ;
367557
368558 if expected_mrs. rtmr0 != verified_mrs. rtmr0 {
0 commit comments