1- use std:: { ffi:: OsStr , path:: Path , sync:: Arc } ;
1+ use std:: {
2+ ffi:: OsStr ,
3+ path:: { Path , PathBuf } ,
4+ sync:: Arc ,
5+ } ;
26
37use anyhow:: { bail, Context , Result } ;
48use dstack_kms_rpc:: {
59 kms_server:: { KmsRpc , KmsServer } ,
6- AppId , AppKeyResponse , GetAppKeyRequest , GetKmsKeyRequest , GetMetaResponse ,
7- GetTempCaCertResponse , KmsKeyResponse , KmsKeys , PublicKeyResponse , SignCertRequest ,
8- SignCertResponse ,
10+ AppId , AppKeyResponse , ClearImageCacheRequest , GetAppKeyRequest , GetKmsKeyRequest ,
11+ GetMetaResponse , GetTempCaCertResponse , KmsKeyResponse , KmsKeys , PublicKeyResponse ,
12+ SignCertRequest , SignCertResponse ,
913} ;
1014use dstack_types:: VmConfig ;
1115use fs_err as fs;
@@ -149,8 +153,41 @@ impl RpcHandler {
149153 . await
150154 }
151155
156+ fn image_cache_dir ( & self ) -> PathBuf {
157+ self . state . config . image . cache_dir . join ( "images" )
158+ }
159+
160+ fn mr_cache_dir ( & self ) -> PathBuf {
161+ self . state . config . image . cache_dir . join ( "computed" )
162+ }
163+
164+ fn remove_cache ( & self , parent_dir : & PathBuf , sub_dir : & str ) -> Result < ( ) > {
165+ if sub_dir. is_empty ( ) {
166+ return Ok ( ( ) ) ;
167+ }
168+ if sub_dir == "all" {
169+ fs:: remove_dir_all ( parent_dir) ?;
170+ } else {
171+ let path = parent_dir. join ( sub_dir) ;
172+ if path. is_dir ( ) {
173+ fs:: remove_dir_all ( path) ?;
174+ } else {
175+ fs:: remove_file ( path) ?;
176+ }
177+ }
178+ Ok ( ( ) )
179+ }
180+
181+ fn ensure_admin ( & self , token : & str ) -> Result < ( ) > {
182+ let token_hash = sha2:: Sha256 :: new_with_prefix ( token) . finalize ( ) ;
183+ if token_hash. as_slice ( ) != self . state . config . admin_token_hash . as_slice ( ) {
184+ bail ! ( "Invalid token" ) ;
185+ }
186+ Ok ( ( ) )
187+ }
188+
152189 fn get_cached_mrs ( & self , key : & str ) -> Result < Mrs > {
153- let path = self . state . config . image . cache_dir . join ( "computed" ) . join ( key) ;
190+ let path = self . mr_cache_dir ( ) . join ( key) ;
154191 if !path. exists ( ) {
155192 bail ! ( "Cached MRs not found" ) ;
156193 }
@@ -161,7 +198,7 @@ impl RpcHandler {
161198 }
162199
163200 fn cache_mrs ( & self , key : & str , mrs : & Mrs ) -> Result < ( ) > {
164- let path = self . state . config . image . cache_dir . join ( "computed" ) . join ( key) ;
201+ let path = self . mr_cache_dir ( ) . join ( key) ;
165202 fs:: create_dir_all ( path. parent ( ) . unwrap ( ) ) . context ( "Failed to create cache directory" ) ?;
166203 safe_write:: safe_write (
167204 & path,
@@ -194,7 +231,7 @@ impl RpcHandler {
194231 }
195232
196233 // Create a directory for the image if it doesn't exist
197- let image_dir = self . state . config . image . cache_dir . join ( & hex_os_image_hash) ;
234+ let image_dir = self . image_cache_dir ( ) . join ( & hex_os_image_hash) ;
198235 // Check if metadata.json exists, if not download the image
199236 let metadata_path = image_dir. join ( "metadata.json" ) ;
200237 if !metadata_path. exists ( ) {
@@ -253,7 +290,7 @@ impl RpcHandler {
253290 . replace ( "{OS_IMAGE_HASH}" , hex_os_image_hash) ;
254291
255292 // Create a temporary directory for extraction within the cache directory
256- let cache_dir = self . state . config . image . cache_dir . join ( "tmp" ) ;
293+ let cache_dir = self . image_cache_dir ( ) . join ( "tmp" ) ;
257294 fs:: create_dir_all ( & cache_dir) . context ( "Failed to create cache directory" ) ?;
258295 let auto_delete_temp_dir = tempfile:: Builder :: new ( )
259296 . prefix ( "tmp-download-" )
@@ -578,6 +615,15 @@ impl KmsRpc for RpcHandler {
578615 ] ,
579616 } )
580617 }
618+
619+ async fn clear_image_cache ( self , request : ClearImageCacheRequest ) -> Result < ( ) > {
620+ self . ensure_admin ( & request. token ) ?;
621+ self . remove_cache ( & self . image_cache_dir ( ) , & request. image_hash )
622+ . context ( "Failed to clear image cache" ) ?;
623+ self . remove_cache ( & self . mr_cache_dir ( ) , & request. config_hash )
624+ . context ( "Failed to clear MR cache" ) ?;
625+ Ok ( ( ) )
626+ }
581627}
582628
583629impl RpcCall < KmsState > for RpcHandler {
0 commit comments