@@ -5,6 +5,7 @@ use std::{
55} ;
66
77use anyhow:: { anyhow, bail, Context , Result } ;
8+ use dcap_qvl:: quote:: { Quote , Report } ;
89use dstack_kms_rpc as rpc;
910use dstack_types:: {
1011 shared_filenames:: {
@@ -17,7 +18,7 @@ use fs_err as fs;
1718use ra_rpc:: client:: { CertInfo , RaClient , RaClientConfig } ;
1819use ra_tls:: cert:: generate_ra_cert;
1920use serde:: { Deserialize , Serialize } ;
20- use tdx_attest:: extend_rtmr3;
21+ use tdx_attest:: { extend_rtmr3, get_quote } ;
2122use tracing:: { info, warn} ;
2223
2324use crate :: {
@@ -364,6 +365,7 @@ impl<'a> Stage0<'a> {
364365 let ( _, ca_pem) = x509_parser:: pem:: parse_x509_pem ( keys. ca_cert . as_bytes ( ) )
365366 . context ( "Failed to parse ca cert" ) ?;
366367 let x509 = ca_pem. parse_x509 ( ) . context ( "Failed to parse ca cert" ) ?;
368+ self . ensure_provider_id_matches ( x509. public_key ( ) . raw ) ?;
367369 let id = hex:: encode ( x509. public_key ( ) . raw ) ;
368370 let provider_info = KeyProviderInfo :: new ( "kms" . into ( ) , id) ;
369371 emit_key_provider_info ( & provider_info) ?;
@@ -373,6 +375,20 @@ impl<'a> Stage0<'a> {
373375 Ok ( ( ) )
374376 }
375377
378+ fn ensure_provider_id_matches ( & self , provider_id : & [ u8 ] ) -> Result < ( ) > {
379+ let expected_key_provider_id = & self . shared . app_compose . key_provider_id ;
380+ if expected_key_provider_id. is_empty ( ) {
381+ return Ok ( ( ) ) ;
382+ } ;
383+ if expected_key_provider_id != provider_id {
384+ bail ! (
385+ "Unexpected key provider id: {:?}, expected: {:?}" ,
386+ hex_fmt:: HexFmt ( provider_id) ,
387+ hex_fmt:: HexFmt ( expected_key_provider_id)
388+ ) ;
389+ }
390+ Ok ( ( ) )
391+ }
376392 async fn get_keys_from_local_key_provider ( & self ) -> Result < ( ) > {
377393 info ! ( "Getting keys from local key provider" ) ;
378394 let provision = self
@@ -387,6 +403,7 @@ impl<'a> Stage0<'a> {
387403 fs:: write ( self . app_keys_file ( ) , keys_json) . context ( "Failed to write app keys" ) ?;
388404
389405 // write to RTMR
406+ self . ensure_provider_id_matches ( & provision. mr ) ?;
390407 let provider_info = KeyProviderInfo :: new ( "local-sgx" . into ( ) , hex:: encode ( provision. mr ) ) ;
391408 emit_key_provider_info ( & provider_info) ?;
392409 Ok ( ( ) )
@@ -470,6 +487,8 @@ impl<'a> Stage0<'a> {
470487 let key_provider = self . shared . app_compose . key_provider ( ) ;
471488 let mut instance_info = self . shared . instance_info . clone ( ) ;
472489
490+ validate_compose_hash ( & compose_hash) . context ( "Failed to validate compose hash" ) ?;
491+
473492 if instance_info. app_id . is_empty ( ) {
474493 instance_info. app_id = truncated_compose_hash. to_vec ( ) ;
475494 }
@@ -536,6 +555,25 @@ impl<'a> Stage0<'a> {
536555 }
537556}
538557
558+ fn validate_compose_hash ( compose_hash : & [ u8 ] ) -> Result < ( ) > {
559+ // If configid is not all zero, use it as compose_hash
560+ let ( _, quote) = get_quote ( & [ 0u8 ; 64 ] , None ) . context ( "Failed to get quote" ) ?;
561+ let quote = Quote :: parse ( & quote) . context ( "Failed to parse quote" ) ?;
562+ let configid = match quote. report {
563+ Report :: SgxEnclave ( _report) => bail ! ( "SGX quote is not supported" ) ,
564+ Report :: TD10 ( report) => report. mr_config_id ,
565+ Report :: TD15 ( report) => report. base . mr_config_id ,
566+ } ;
567+ info ! ( "mr_config_id: {}" , hex_fmt:: HexFmt ( & configid) ) ;
568+ if configid == [ 0u8 ; 48 ] {
569+ return Ok ( ( ) ) ;
570+ }
571+ if & configid[ ..32 ] != compose_hash {
572+ bail ! ( "mr_config_id does not match compose hash" ) ;
573+ }
574+ Ok ( ( ) )
575+ }
576+
539577impl Stage1 < ' _ > {
540578 fn resolve ( & self , path : & str ) -> String {
541579 path. to_string ( )
0 commit comments