Skip to content

Commit 4004641

Browse files
committed
Fix mr_image verification issues
1 parent c8e78eb commit 4004641

File tree

2 files changed

+116
-57
lines changed

2 files changed

+116
-57
lines changed

kms/src/main_service.rs

Lines changed: 115 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{path::Path, sync::Arc};
1+
use std::{ffi::OsStr, path::Path, sync::Arc};
22

33
use anyhow::{bail, Context, Result};
44
use dstack_kms_rpc::{
@@ -17,8 +17,10 @@ use ra_tls::{
1717
kdf,
1818
};
1919
use scale::Decode;
20+
use serde::{Deserialize, Serialize};
2021
use sha2::Digest;
2122
use tokio::{io::AsyncWriteExt, process::Command};
23+
use tracing::info;
2224
use upgrade_authority::BootInfo;
2325

2426
use crate::{
@@ -83,6 +85,49 @@ struct BootConfig {
8385
mr_image: Vec<u8>,
8486
}
8587

88+
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
89+
struct Mrs {
90+
mrtd: String,
91+
rtmr0: String,
92+
rtmr1: String,
93+
rtmr2: String,
94+
}
95+
96+
impl Mrs {
97+
fn assert_eq(&self, other: &Self) -> Result<()> {
98+
let Self {
99+
mrtd,
100+
rtmr0,
101+
rtmr1,
102+
rtmr2,
103+
} = self;
104+
if mrtd != &other.mrtd {
105+
bail!("MRTD does not match");
106+
}
107+
if rtmr0 != &other.rtmr0 {
108+
bail!("RTMR0 does not match");
109+
}
110+
if rtmr1 != &other.rtmr1 {
111+
bail!("RTMR1 does not match");
112+
}
113+
if rtmr2 != &other.rtmr2 {
114+
bail!("RTMR2 does not match");
115+
}
116+
Ok(())
117+
}
118+
}
119+
120+
impl From<&BootInfo> for Mrs {
121+
fn from(report: &BootInfo) -> Self {
122+
Self {
123+
mrtd: hex::encode(&report.mrtd),
124+
rtmr0: hex::encode(&report.rtmr0),
125+
rtmr1: hex::encode(&report.rtmr1),
126+
rtmr2: hex::encode(&report.rtmr2),
127+
}
128+
}
129+
}
130+
86131
impl RpcHandler {
87132
fn ensure_attested(&self) -> Result<&VerifiedAttestation> {
88133
let Some(attestation) = &self.attestation else {
@@ -104,16 +149,56 @@ impl RpcHandler {
104149
.await
105150
}
106151

152+
fn get_cached_mrs(&self, key: &str) -> Result<Mrs> {
153+
let path = self.state.config.image.cache_dir.join("computed").join(key);
154+
if !path.exists() {
155+
bail!("Cached MRs not found");
156+
}
157+
let content = fs::read_to_string(path).context("Failed to read cached MRs")?;
158+
let cached_mrs: Mrs =
159+
serde_json::from_str(&content).context("Failed to parse cached MRs")?;
160+
Ok(cached_mrs)
161+
}
162+
163+
fn cache_mrs(&self, key: &str, mrs: &Mrs) -> Result<()> {
164+
let path = self.state.config.image.cache_dir.join("computed").join(key);
165+
fs::create_dir_all(path.parent().unwrap()).context("Failed to create cache directory")?;
166+
safe_write::safe_write(
167+
&path,
168+
serde_json::to_string(mrs).context("Failed to serialize cached MRs")?,
169+
)
170+
.context("Failed to write cached MRs")?;
171+
Ok(())
172+
}
173+
107174
async fn verify_mr_image(&self, vm_config: &VmConfig, report: &BootInfo) -> Result<()> {
108175
if !self.state.config.image.verify {
176+
info!("Image verification is disabled");
109177
return Ok(());
110178
}
111179
let hex_mr_image = hex::encode(&vm_config.mr_image);
180+
info!("Verifying image {hex_mr_image}");
181+
182+
let verified_mrs: Mrs = report.into();
183+
184+
let cache_key = {
185+
let vm_config =
186+
serde_json::to_vec(vm_config).context("Failed to serialize VM config")?;
187+
hex::encode(sha2::Sha256::new_with_prefix(&vm_config).finalize())
188+
};
189+
if let Ok(cached_mrs) = self.get_cached_mrs(&cache_key) {
190+
cached_mrs
191+
.assert_eq(&verified_mrs)
192+
.context("MRs do not match (cached)")?;
193+
return Ok(());
194+
}
195+
112196
// Create a directory for the image if it doesn't exist
113197
let image_dir = self.state.config.image.cache_dir.join(&hex_mr_image);
114198
// Check if metadata.json exists, if not download the image
115199
let metadata_path = image_dir.join("metadata.json");
116200
if !metadata_path.exists() {
201+
info!("Image {} not found, downloading", hex_mr_image);
117202
tokio::time::timeout(
118203
self.state.config.image.download_timeout,
119204
self.download_image(&hex_mr_image, &image_dir),
@@ -148,44 +233,13 @@ impl RpcHandler {
148233
}
149234

150235
// Parse the expected MRs
151-
let expected_mrs: serde_json::Value =
236+
let expected_mrs: Mrs =
152237
serde_json::from_slice(&output.stdout).context("Failed to parse dstack-mr output")?;
153-
154-
// Compare MRs
155-
let expected_mrtd = expected_mrs["mrtd"]
156-
.as_str()
157-
.context("Missing mrtd in expected MRs")?;
158-
let expected_rtmr0 = expected_mrs["rtmr0"]
159-
.as_str()
160-
.context("Missing rtmr0 in expected MRs")?;
161-
let expected_rtmr1 = expected_mrs["rtmr1"]
162-
.as_str()
163-
.context("Missing rtmr1 in expected MRs")?;
164-
let expected_rtmr2 = expected_mrs["rtmr2"]
165-
.as_str()
166-
.context("Missing rtmr2 in expected MRs")?;
167-
168-
let report_mrtd = hex::encode(&report.mrtd);
169-
let report_rtmr0 = hex::encode(&report.rtmr0);
170-
let report_rtmr1 = hex::encode(&report.rtmr1);
171-
let report_rtmr2 = hex::encode(&report.rtmr2);
172-
173-
if report_mrtd != expected_mrtd {
174-
bail!("MRTD mismatch: {} != {}", report_mrtd, expected_mrtd);
175-
}
176-
177-
if report_rtmr0 != expected_rtmr0 {
178-
bail!("RTMR0 mismatch: {} != {}", report_rtmr0, expected_rtmr0);
179-
}
180-
181-
if report_rtmr1 != expected_rtmr1 {
182-
bail!("RTMR1 mismatch: {} != {}", report_rtmr1, expected_rtmr1);
183-
}
184-
185-
if report_rtmr2 != expected_rtmr2 {
186-
bail!("RTMR2 mismatch: {} != {}", report_rtmr2, expected_rtmr2);
187-
}
188-
238+
self.cache_mrs(&cache_key, &expected_mrs)
239+
.context("Failed to cache MRs")?;
240+
expected_mrs
241+
.assert_eq(&verified_mrs)
242+
.context("MRs do not match")?;
189243
Ok(())
190244
}
191245

@@ -198,9 +252,13 @@ impl RpcHandler {
198252
.download_url
199253
.replace("{MR_IMAGE}", hex_mr_image);
200254

201-
// Create a temporary directory for extraction
202-
let auto_delete_temp_dir =
203-
tempfile::tempdir().context("Failed to create temporary directory")?;
255+
// Create a temporary directory for extraction within the cache directory
256+
let cache_dir = self.state.config.image.cache_dir.join("tmp");
257+
fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;
258+
let auto_delete_temp_dir = tempfile::Builder::new()
259+
.prefix("tmp-download-")
260+
.tempdir_in(&cache_dir)
261+
.context("Failed to create temporary directory")?;
204262
let tmp_dir = auto_delete_temp_dir.path();
205263
// Download the image tarball
206264
let client = reqwest::Client::new();
@@ -229,14 +287,14 @@ impl RpcHandler {
229287
.context("Failed to write chunk to file")?;
230288
}
231289

232-
let tmp_x_dir = tmp_dir.join("extracted");
233-
fs::create_dir_all(&tmp_x_dir).context("Failed to create extraction directory")?;
290+
let extracted_dir = tmp_dir.join("extracted");
291+
fs::create_dir_all(&extracted_dir).context("Failed to create extraction directory")?;
234292

235293
// Extract the tarball
236294
let output = Command::new("tar")
237295
.arg("xzf")
238296
.arg(&tarball_path)
239-
.current_dir(&tmp_x_dir)
297+
.current_dir(&extracted_dir)
240298
.output()
241299
.await
242300
.context("Failed to extract tarball")?;
@@ -252,7 +310,7 @@ impl RpcHandler {
252310
let output = Command::new("sha256sum")
253311
.arg("-c")
254312
.arg("sha256sum.txt")
255-
.current_dir(&tmp_x_dir)
313+
.current_dir(&extracted_dir)
256314
.output()
257315
.await
258316
.context("Failed to verify checksum")?;
@@ -264,22 +322,23 @@ impl RpcHandler {
264322
);
265323
}
266324
// Remove the files that are not listed in sha256sum.txt
267-
let sha256sum_path = tmp_x_dir.join("sha256sum.txt");
325+
let sha256sum_path = extracted_dir.join("sha256sum.txt");
268326
let files_doc =
269327
fs::read_to_string(&sha256sum_path).context("Failed to read sha256sum.txt")?;
270-
let listed_files = files_doc
328+
let listed_files: Vec<&OsStr> = files_doc
271329
.lines()
272330
.flat_map(|line| line.split_whitespace().nth(1))
273-
.collect::<Vec<_>>();
274-
let files = fs::read_dir(&tmp_x_dir).context("Failed to read directory")?;
331+
.map(|s| s.as_ref())
332+
.collect();
333+
let files = fs::read_dir(&extracted_dir).context("Failed to read directory")?;
275334
for file in files {
276335
let file = file.context("Failed to read directory entry")?;
277-
let path = file.path();
278-
if !listed_files.contains(&path.display().to_string().as_str()) {
279-
if path.is_dir() {
280-
fs::remove_dir_all(&path).context("Failed to remove directory")?;
336+
let filename = file.file_name();
337+
if !listed_files.contains(&filename.as_os_str()) {
338+
if file.path().is_dir() {
339+
fs::remove_dir_all(file.path()).context("Failed to remove directory")?;
281340
} else {
282-
fs::remove_file(&path).context("Failed to remove file")?;
341+
fs::remove_file(file.path()).context("Failed to remove file")?;
283342
}
284343
}
285344
}
@@ -291,7 +350,7 @@ impl RpcHandler {
291350
}
292351

293352
// Move the extracted files to the destination directory
294-
let metadata_path = tmp_x_dir.join("metadata.json");
353+
let metadata_path = extracted_dir.join("metadata.json");
295354
if !metadata_path.exists() {
296355
bail!("metadata.json not found in the extracted archive");
297356
}
@@ -302,7 +361,7 @@ impl RpcHandler {
302361
let dst_dir_parent = dst_dir.parent().context("Failed to get parent directory")?;
303362
fs::create_dir_all(dst_dir_parent).context("Failed to create parent directory")?;
304363
// Move the extracted files to the destination directory
305-
fs::rename(tmp_x_dir, dst_dir)
364+
fs::rename(extracted_dir, dst_dir)
306365
.context("Failed to move extracted files to destination directory")?;
307366
Ok(())
308367
}

vmm/src/app.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ impl App {
464464
let vm_config = serde_json::to_string(&json!({
465465
"mr_image": image.digest.unwrap_or_default(),
466466
"cpu_count": manifest.vcpu,
467-
"memory_size": manifest.memory * 1024 * 1024,
467+
"memory_size": manifest.memory as u64 * 1024 * 1024,
468468
}))?;
469469
json!({
470470
"kms_urls": cfg.cvm.kms_urls,

0 commit comments

Comments
 (0)