Skip to content

Commit 96ac391

Browse files
committed
Add cache for verifer
1 parent a24476f commit 96ac391

File tree

1 file changed

+224
-34
lines changed

1 file changed

+224
-34
lines changed

verifier/src/verification.rs

Lines changed: 224 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22
//
33
// SPDX-License-Identifier: Apache-2.0
44

5-
use std::{ffi::OsStr, path::Path, time::Duration};
5+
use std::{
6+
ffi::OsStr,
7+
path::{Path, PathBuf},
8+
time::Duration,
9+
};
610

7-
use anyhow::{bail, Context, Result};
11+
use anyhow::{anyhow, bail, Context, Result};
812
use cc_eventlog::TdxEventLog as EventLog;
9-
use dstack_mr::RtmrLog;
13+
use dstack_mr::{RtmrLog, TdxMeasurementDetails, TdxMeasurements};
1014
use dstack_types::VmConfig;
1115
use ra_tls::attestation::{Attestation, VerifiedAttestation};
16+
use serde::{Deserialize, Serialize};
1217
use sha2::{Digest as _, Sha256, Sha384};
1318
use tokio::{io::AsyncWriteExt, process::Command};
14-
use tracing::info;
19+
use tracing::{debug, info, warn};
1520

1621
use crate::types::{
1722
AcpiTables, RtmrEventEntry, RtmrEventStatus, RtmrMismatch, VerificationDetails,
@@ -143,6 +148,14 @@ fn collect_rtmr_mismatch(
143148
}
144149
}
145150

151+
const MEASUREMENT_CACHE_VERSION: u32 = 1;
152+
153+
#[derive(Clone, Serialize, Deserialize)]
154+
struct CachedMeasurement {
155+
version: u32,
156+
measurements: TdxMeasurements,
157+
}
158+
146159
pub struct CvmVerifier {
147160
pub image_cache_dir: String,
148161
pub download_url: String,
@@ -158,6 +171,178 @@ impl CvmVerifier {
158171
}
159172
}
160173

174+
fn measurement_cache_dir(&self) -> PathBuf {
175+
Path::new(&self.image_cache_dir).join("measurements")
176+
}
177+
178+
fn measurement_cache_path(&self, cache_key: &str) -> PathBuf {
179+
self.measurement_cache_dir()
180+
.join(format!("{cache_key}.json"))
181+
}
182+
183+
fn vm_config_cache_key(vm_config: &VmConfig) -> Result<String> {
184+
let serialized = serde_json::to_vec(vm_config)
185+
.context("Failed to serialize VM config for cache key computation")?;
186+
Ok(hex::encode(Sha256::digest(&serialized)))
187+
}
188+
189+
fn load_measurements_from_cache(&self, cache_key: &str) -> Result<Option<TdxMeasurements>> {
190+
let path = self.measurement_cache_path(cache_key);
191+
if !path.exists() {
192+
return Ok(None);
193+
}
194+
195+
let path_display = path.display().to_string();
196+
let contents = match fs_err::read(&path) {
197+
Ok(data) => data,
198+
Err(e) => {
199+
warn!("Failed to read measurement cache {}: {e:?}", path_display);
200+
return Ok(None);
201+
}
202+
};
203+
204+
let cached: CachedMeasurement = match serde_json::from_slice(&contents) {
205+
Ok(entry) => entry,
206+
Err(e) => {
207+
warn!("Failed to parse measurement cache {}: {e:?}", path_display);
208+
return Ok(None);
209+
}
210+
};
211+
212+
if cached.version != MEASUREMENT_CACHE_VERSION {
213+
debug!(
214+
"Ignoring measurement cache {} due to version mismatch (found {}, expected {})",
215+
path_display, cached.version, MEASUREMENT_CACHE_VERSION
216+
);
217+
return Ok(None);
218+
}
219+
220+
debug!("Loaded measurement cache entry {}", cache_key);
221+
Ok(Some(cached.measurements))
222+
}
223+
224+
fn store_measurements_in_cache(
225+
&self,
226+
cache_key: &str,
227+
measurements: &TdxMeasurements,
228+
) -> Result<()> {
229+
let cache_dir = self.measurement_cache_dir();
230+
fs_err::create_dir_all(&cache_dir)
231+
.context("Failed to create measurement cache directory")?;
232+
233+
let path = self.measurement_cache_path(cache_key);
234+
let mut tmp = tempfile::NamedTempFile::new_in(&cache_dir)
235+
.context("Failed to create temporary cache file")?;
236+
237+
let entry = CachedMeasurement {
238+
version: MEASUREMENT_CACHE_VERSION,
239+
measurements: measurements.clone(),
240+
};
241+
serde_json::to_writer(tmp.as_file_mut(), &entry)
242+
.context("Failed to serialize measurement cache entry")?;
243+
tmp.as_file_mut()
244+
.sync_all()
245+
.context("Failed to flush measurement cache entry to disk")?;
246+
247+
tmp.persist(&path).map_err(|e| {
248+
anyhow!(
249+
"Failed to persist measurement cache to {}: {e}",
250+
path.display()
251+
)
252+
})?;
253+
debug!("Stored measurement cache entry {}", cache_key);
254+
Ok(())
255+
}
256+
257+
fn compute_measurement_details(
258+
&self,
259+
vm_config: &VmConfig,
260+
fw_path: &Path,
261+
kernel_path: &Path,
262+
initrd_path: &Path,
263+
kernel_cmdline: &str,
264+
) -> Result<TdxMeasurementDetails> {
265+
let firmware = fw_path.display().to_string();
266+
let kernel = kernel_path.display().to_string();
267+
let initrd = initrd_path.display().to_string();
268+
269+
let details = dstack_mr::Machine::builder()
270+
.cpu_count(vm_config.cpu_count)
271+
.memory_size(vm_config.memory_size)
272+
.firmware(&firmware)
273+
.kernel(&kernel)
274+
.initrd(&initrd)
275+
.kernel_cmdline(kernel_cmdline)
276+
.root_verity(true)
277+
.hotplug_off(vm_config.hotplug_off)
278+
.maybe_two_pass_add_pages(vm_config.qemu_single_pass_add_pages)
279+
.maybe_pic(vm_config.pic)
280+
.maybe_qemu_version(vm_config.qemu_version.clone())
281+
.maybe_pci_hole64_size(if vm_config.pci_hole64_size > 0 {
282+
Some(vm_config.pci_hole64_size)
283+
} else {
284+
None
285+
})
286+
.hugepages(vm_config.hugepages)
287+
.num_gpus(vm_config.num_gpus)
288+
.num_nvswitches(vm_config.num_nvswitches)
289+
.build()
290+
.measure_with_logs()
291+
.context("Failed to compute expected MRs")?;
292+
293+
Ok(details)
294+
}
295+
296+
fn compute_measurements(
297+
&self,
298+
vm_config: &VmConfig,
299+
fw_path: &Path,
300+
kernel_path: &Path,
301+
initrd_path: &Path,
302+
kernel_cmdline: &str,
303+
) -> Result<TdxMeasurements> {
304+
self.compute_measurement_details(
305+
vm_config,
306+
fw_path,
307+
kernel_path,
308+
initrd_path,
309+
kernel_cmdline,
310+
)
311+
.map(|details| details.measurements)
312+
}
313+
314+
fn load_or_compute_measurements(
315+
&self,
316+
vm_config: &VmConfig,
317+
fw_path: &Path,
318+
kernel_path: &Path,
319+
initrd_path: &Path,
320+
kernel_cmdline: &str,
321+
) -> Result<TdxMeasurements> {
322+
let cache_key = Self::vm_config_cache_key(vm_config)?;
323+
324+
if let Some(measurements) = self.load_measurements_from_cache(&cache_key)? {
325+
return Ok(measurements);
326+
}
327+
328+
let measurements = self.compute_measurements(
329+
vm_config,
330+
fw_path,
331+
kernel_path,
332+
initrd_path,
333+
kernel_cmdline,
334+
)?;
335+
336+
if let Err(e) = self.store_measurements_in_cache(&cache_key, &measurements) {
337+
warn!(
338+
"Failed to write measurement cache entry for {}: {e:?}",
339+
cache_key
340+
);
341+
}
342+
343+
Ok(measurements)
344+
}
345+
161346
pub async fn verify(&self, request: &VerificationRequest) -> Result<VerificationResponse> {
162347
let quote = hex::decode(&request.quote).context("Failed to decode quote hex")?;
163348

@@ -305,39 +490,41 @@ impl CvmVerifier {
305490
let kernel_cmdline = image_info.cmdline + " initrd=initrd";
306491

307492
// Use dstack-mr to compute expected MRs
308-
let measurement_details = dstack_mr::Machine::builder()
309-
.cpu_count(vm_config.cpu_count)
310-
.memory_size(vm_config.memory_size)
311-
.firmware(&fw_path.display().to_string())
312-
.kernel(&kernel_path.display().to_string())
313-
.initrd(&initrd_path.display().to_string())
314-
.kernel_cmdline(&kernel_cmdline)
315-
.root_verity(true)
316-
.hotplug_off(vm_config.hotplug_off)
317-
.maybe_two_pass_add_pages(vm_config.qemu_single_pass_add_pages)
318-
.maybe_pic(vm_config.pic)
319-
.maybe_qemu_version(vm_config.qemu_version.clone())
320-
.maybe_pci_hole64_size(if vm_config.pci_hole64_size > 0 {
321-
Some(vm_config.pci_hole64_size)
322-
} else {
323-
None
324-
})
325-
.hugepages(vm_config.hugepages)
326-
.num_gpus(vm_config.num_gpus)
327-
.num_nvswitches(vm_config.num_nvswitches)
328-
.build()
329-
.measure_with_logs()
330-
.context("Failed to compute expected MRs")?;
493+
let (mrs, expected_logs) = if debug {
494+
let TdxMeasurementDetails {
495+
measurements,
496+
rtmr_logs,
497+
acpi_tables,
498+
} = self
499+
.compute_measurement_details(
500+
vm_config,
501+
&fw_path,
502+
&kernel_path,
503+
&initrd_path,
504+
&kernel_cmdline,
505+
)
506+
.context("Failed to compute expected measurements")?;
331507

332-
let mrs = measurement_details.measurements;
333-
let expected_logs = measurement_details.rtmr_logs;
334-
if debug {
335508
details.acpi_tables = Some(AcpiTables {
336-
tables: hex::encode(&measurement_details.acpi_tables.tables),
337-
rsdp: hex::encode(&measurement_details.acpi_tables.rsdp),
338-
loader: hex::encode(&measurement_details.acpi_tables.loader),
509+
tables: hex::encode(&acpi_tables.tables),
510+
rsdp: hex::encode(&acpi_tables.rsdp),
511+
loader: hex::encode(&acpi_tables.loader),
339512
});
340-
}
513+
514+
(measurements, Some(rtmr_logs))
515+
} else {
516+
(
517+
self.load_or_compute_measurements(
518+
vm_config,
519+
&fw_path,
520+
&kernel_path,
521+
&initrd_path,
522+
&kernel_cmdline,
523+
)
524+
.context("Failed to obtain expected measurements")?,
525+
None,
526+
)
527+
};
341528

342529
let expected_mrs = Mrs {
343530
mrtd: mrs.mrtd.clone(),
@@ -363,6 +550,9 @@ impl CvmVerifier {
363550
if !debug {
364551
return result;
365552
}
553+
let Some(expected_logs) = expected_logs.as_ref() else {
554+
return result;
555+
};
366556
let mut rtmr_debug = Vec::new();
367557

368558
if expected_mrs.rtmr0 != verified_mrs.rtmr0 {

0 commit comments

Comments
 (0)