1- use std:: { path:: Path , sync:: Arc } ;
1+ use std:: { ffi :: OsStr , path:: Path , sync:: Arc } ;
22
33use anyhow:: { bail, Context , Result } ;
44use dstack_kms_rpc:: {
@@ -17,8 +17,10 @@ use ra_tls::{
1717 kdf,
1818} ;
1919use scale:: Decode ;
20+ use serde:: { Deserialize , Serialize } ;
2021use sha2:: Digest ;
2122use tokio:: { io:: AsyncWriteExt , process:: Command } ;
23+ use tracing:: info;
2224use upgrade_authority:: BootInfo ;
2325
2426use crate :: {
@@ -83,6 +85,49 @@ struct BootConfig {
8385 mr_image : Vec < u8 > ,
8486}
8587
88+ #[ derive( Debug , Serialize , Deserialize , Clone , Eq , PartialEq ) ]
89+ struct Mrs {
90+ mrtd : String ,
91+ rtmr0 : String ,
92+ rtmr1 : String ,
93+ rtmr2 : String ,
94+ }
95+
96+ impl Mrs {
97+ fn assert_eq ( & self , other : & Self ) -> Result < ( ) > {
98+ let Self {
99+ mrtd,
100+ rtmr0,
101+ rtmr1,
102+ rtmr2,
103+ } = self ;
104+ if mrtd != & other. mrtd {
105+ bail ! ( "MRTD does not match" ) ;
106+ }
107+ if rtmr0 != & other. rtmr0 {
108+ bail ! ( "RTMR0 does not match" ) ;
109+ }
110+ if rtmr1 != & other. rtmr1 {
111+ bail ! ( "RTMR1 does not match" ) ;
112+ }
113+ if rtmr2 != & other. rtmr2 {
114+ bail ! ( "RTMR2 does not match" ) ;
115+ }
116+ Ok ( ( ) )
117+ }
118+ }
119+
120+ impl From < & BootInfo > for Mrs {
121+ fn from ( report : & BootInfo ) -> Self {
122+ Self {
123+ mrtd : hex:: encode ( & report. mrtd ) ,
124+ rtmr0 : hex:: encode ( & report. rtmr0 ) ,
125+ rtmr1 : hex:: encode ( & report. rtmr1 ) ,
126+ rtmr2 : hex:: encode ( & report. rtmr2 ) ,
127+ }
128+ }
129+ }
130+
86131impl RpcHandler {
87132 fn ensure_attested ( & self ) -> Result < & VerifiedAttestation > {
88133 let Some ( attestation) = & self . attestation else {
@@ -104,16 +149,56 @@ impl RpcHandler {
104149 . await
105150 }
106151
152+ fn get_cached_mrs ( & self , key : & str ) -> Result < Mrs > {
153+ let path = self . state . config . image . cache_dir . join ( "computed" ) . join ( key) ;
154+ if !path. exists ( ) {
155+ bail ! ( "Cached MRs not found" ) ;
156+ }
157+ let content = fs:: read_to_string ( path) . context ( "Failed to read cached MRs" ) ?;
158+ let cached_mrs: Mrs =
159+ serde_json:: from_str ( & content) . context ( "Failed to parse cached MRs" ) ?;
160+ Ok ( cached_mrs)
161+ }
162+
163+ fn cache_mrs ( & self , key : & str , mrs : & Mrs ) -> Result < ( ) > {
164+ let path = self . state . config . image . cache_dir . join ( "computed" ) . join ( key) ;
165+ fs:: create_dir_all ( path. parent ( ) . unwrap ( ) ) . context ( "Failed to create cache directory" ) ?;
166+ safe_write:: safe_write (
167+ & path,
168+ serde_json:: to_string ( mrs) . context ( "Failed to serialize cached MRs" ) ?,
169+ )
170+ . context ( "Failed to write cached MRs" ) ?;
171+ Ok ( ( ) )
172+ }
173+
107174 async fn verify_mr_image ( & self , vm_config : & VmConfig , report : & BootInfo ) -> Result < ( ) > {
108175 if !self . state . config . image . verify {
176+ info ! ( "Image verification is disabled" ) ;
109177 return Ok ( ( ) ) ;
110178 }
111179 let hex_mr_image = hex:: encode ( & vm_config. mr_image ) ;
180+ info ! ( "Verifying image {hex_mr_image}" ) ;
181+
182+ let verified_mrs: Mrs = report. into ( ) ;
183+
184+ let cache_key = {
185+ let vm_config =
186+ serde_json:: to_vec ( vm_config) . context ( "Failed to serialize VM config" ) ?;
187+ hex:: encode ( sha2:: Sha256 :: new_with_prefix ( & vm_config) . finalize ( ) )
188+ } ;
189+ if let Ok ( cached_mrs) = self . get_cached_mrs ( & cache_key) {
190+ cached_mrs
191+ . assert_eq ( & verified_mrs)
192+ . context ( "MRs do not match (cached)" ) ?;
193+ return Ok ( ( ) ) ;
194+ }
195+
112196 // Create a directory for the image if it doesn't exist
113197 let image_dir = self . state . config . image . cache_dir . join ( & hex_mr_image) ;
114198 // Check if metadata.json exists, if not download the image
115199 let metadata_path = image_dir. join ( "metadata.json" ) ;
116200 if !metadata_path. exists ( ) {
201+ info ! ( "Image {} not found, downloading" , hex_mr_image) ;
117202 tokio:: time:: timeout (
118203 self . state . config . image . download_timeout ,
119204 self . download_image ( & hex_mr_image, & image_dir) ,
@@ -148,44 +233,13 @@ impl RpcHandler {
148233 }
149234
150235 // Parse the expected MRs
151- let expected_mrs: serde_json :: Value =
236+ let expected_mrs: Mrs =
152237 serde_json:: from_slice ( & output. stdout ) . context ( "Failed to parse dstack-mr output" ) ?;
153-
154- // Compare MRs
155- let expected_mrtd = expected_mrs[ "mrtd" ]
156- . as_str ( )
157- . context ( "Missing mrtd in expected MRs" ) ?;
158- let expected_rtmr0 = expected_mrs[ "rtmr0" ]
159- . as_str ( )
160- . context ( "Missing rtmr0 in expected MRs" ) ?;
161- let expected_rtmr1 = expected_mrs[ "rtmr1" ]
162- . as_str ( )
163- . context ( "Missing rtmr1 in expected MRs" ) ?;
164- let expected_rtmr2 = expected_mrs[ "rtmr2" ]
165- . as_str ( )
166- . context ( "Missing rtmr2 in expected MRs" ) ?;
167-
168- let report_mrtd = hex:: encode ( & report. mrtd ) ;
169- let report_rtmr0 = hex:: encode ( & report. rtmr0 ) ;
170- let report_rtmr1 = hex:: encode ( & report. rtmr1 ) ;
171- let report_rtmr2 = hex:: encode ( & report. rtmr2 ) ;
172-
173- if report_mrtd != expected_mrtd {
174- bail ! ( "MRTD mismatch: {} != {}" , report_mrtd, expected_mrtd) ;
175- }
176-
177- if report_rtmr0 != expected_rtmr0 {
178- bail ! ( "RTMR0 mismatch: {} != {}" , report_rtmr0, expected_rtmr0) ;
179- }
180-
181- if report_rtmr1 != expected_rtmr1 {
182- bail ! ( "RTMR1 mismatch: {} != {}" , report_rtmr1, expected_rtmr1) ;
183- }
184-
185- if report_rtmr2 != expected_rtmr2 {
186- bail ! ( "RTMR2 mismatch: {} != {}" , report_rtmr2, expected_rtmr2) ;
187- }
188-
238+ self . cache_mrs ( & cache_key, & expected_mrs)
239+ . context ( "Failed to cache MRs" ) ?;
240+ expected_mrs
241+ . assert_eq ( & verified_mrs)
242+ . context ( "MRs do not match" ) ?;
189243 Ok ( ( ) )
190244 }
191245
@@ -198,9 +252,13 @@ impl RpcHandler {
198252 . download_url
199253 . replace ( "{MR_IMAGE}" , hex_mr_image) ;
200254
201- // Create a temporary directory for extraction
202- let auto_delete_temp_dir =
203- tempfile:: tempdir ( ) . context ( "Failed to create temporary directory" ) ?;
255+ // Create a temporary directory for extraction within the cache directory
256+ let cache_dir = self . state . config . image . cache_dir . join ( "tmp" ) ;
257+ fs:: create_dir_all ( & cache_dir) . context ( "Failed to create cache directory" ) ?;
258+ let auto_delete_temp_dir = tempfile:: Builder :: new ( )
259+ . prefix ( "tmp-download-" )
260+ . tempdir_in ( & cache_dir)
261+ . context ( "Failed to create temporary directory" ) ?;
204262 let tmp_dir = auto_delete_temp_dir. path ( ) ;
205263 // Download the image tarball
206264 let client = reqwest:: Client :: new ( ) ;
@@ -229,14 +287,14 @@ impl RpcHandler {
229287 . context ( "Failed to write chunk to file" ) ?;
230288 }
231289
232- let tmp_x_dir = tmp_dir. join ( "extracted" ) ;
233- fs:: create_dir_all ( & tmp_x_dir ) . context ( "Failed to create extraction directory" ) ?;
290+ let extracted_dir = tmp_dir. join ( "extracted" ) ;
291+ fs:: create_dir_all ( & extracted_dir ) . context ( "Failed to create extraction directory" ) ?;
234292
235293 // Extract the tarball
236294 let output = Command :: new ( "tar" )
237295 . arg ( "xzf" )
238296 . arg ( & tarball_path)
239- . current_dir ( & tmp_x_dir )
297+ . current_dir ( & extracted_dir )
240298 . output ( )
241299 . await
242300 . context ( "Failed to extract tarball" ) ?;
@@ -252,7 +310,7 @@ impl RpcHandler {
252310 let output = Command :: new ( "sha256sum" )
253311 . arg ( "-c" )
254312 . arg ( "sha256sum.txt" )
255- . current_dir ( & tmp_x_dir )
313+ . current_dir ( & extracted_dir )
256314 . output ( )
257315 . await
258316 . context ( "Failed to verify checksum" ) ?;
@@ -264,22 +322,23 @@ impl RpcHandler {
264322 ) ;
265323 }
266324 // Remove the files that are not listed in sha256sum.txt
267- let sha256sum_path = tmp_x_dir . join ( "sha256sum.txt" ) ;
325+ let sha256sum_path = extracted_dir . join ( "sha256sum.txt" ) ;
268326 let files_doc =
269327 fs:: read_to_string ( & sha256sum_path) . context ( "Failed to read sha256sum.txt" ) ?;
270- let listed_files = files_doc
328+ let listed_files: Vec < & OsStr > = files_doc
271329 . lines ( )
272330 . flat_map ( |line| line. split_whitespace ( ) . nth ( 1 ) )
273- . collect :: < Vec < _ > > ( ) ;
274- let files = fs:: read_dir ( & tmp_x_dir) . context ( "Failed to read directory" ) ?;
331+ . map ( |s| s. as_ref ( ) )
332+ . collect ( ) ;
333+ let files = fs:: read_dir ( & extracted_dir) . context ( "Failed to read directory" ) ?;
275334 for file in files {
276335 let file = file. context ( "Failed to read directory entry" ) ?;
277- let path = file. path ( ) ;
278- if !listed_files. contains ( & path . display ( ) . to_string ( ) . as_str ( ) ) {
279- if path. is_dir ( ) {
280- fs:: remove_dir_all ( & path) . context ( "Failed to remove directory" ) ?;
336+ let filename = file. file_name ( ) ;
337+ if !listed_files. contains ( & filename . as_os_str ( ) ) {
338+ if file . path ( ) . is_dir ( ) {
339+ fs:: remove_dir_all ( file . path ( ) ) . context ( "Failed to remove directory" ) ?;
281340 } else {
282- fs:: remove_file ( & path) . context ( "Failed to remove file" ) ?;
341+ fs:: remove_file ( file . path ( ) ) . context ( "Failed to remove file" ) ?;
283342 }
284343 }
285344 }
@@ -291,7 +350,7 @@ impl RpcHandler {
291350 }
292351
293352 // Move the extracted files to the destination directory
294- let metadata_path = tmp_x_dir . join ( "metadata.json" ) ;
353+ let metadata_path = extracted_dir . join ( "metadata.json" ) ;
295354 if !metadata_path. exists ( ) {
296355 bail ! ( "metadata.json not found in the extracted archive" ) ;
297356 }
@@ -302,7 +361,7 @@ impl RpcHandler {
302361 let dst_dir_parent = dst_dir. parent ( ) . context ( "Failed to get parent directory" ) ?;
303362 fs:: create_dir_all ( dst_dir_parent) . context ( "Failed to create parent directory" ) ?;
304363 // Move the extracted files to the destination directory
305- fs:: rename ( tmp_x_dir , dst_dir)
364+ fs:: rename ( extracted_dir , dst_dir)
306365 . context ( "Failed to move extracted files to destination directory" ) ?;
307366 Ok ( ( ) )
308367 }
0 commit comments