@@ -120,6 +120,7 @@ impl SnapshotDownloadCommand {
120
120
& certificate,
121
121
& message,
122
122
& snapshot_message,
123
+ & db_dir,
123
124
)
124
125
. await ?;
125
126
@@ -233,11 +234,16 @@ impl SnapshotDownloadCommand {
233
234
certificate : & MithrilCertificate ,
234
235
message : & ProtocolMessage ,
235
236
snapshot : & Snapshot ,
237
+ db_dir : & Path ,
236
238
) -> MithrilResult < ( ) > {
237
239
progress_printer. report_step ( step_number, "Verifying the snapshot signature…" ) ?;
238
240
if !certificate. match_message ( message) {
239
241
debug ! ( "Digest verification failed, removing unpacked files & directory." ) ;
240
242
243
+ if let Err ( error) = std:: fs:: remove_dir_all ( db_dir) {
244
+ warn ! ( "Error while removing unpacked files & directory: {error}." ) ;
245
+ }
246
+
241
247
return Err ( anyhow ! (
242
248
"Certificate verification failed (snapshot digest = '{}')." ,
243
249
snapshot. digest. clone( )
@@ -327,3 +333,71 @@ impl Source for SnapshotDownloadCommand {
327
333
Ok ( map)
328
334
}
329
335
}
336
+
337
+ #[ cfg( test) ]
338
+ mod tests {
339
+ use mithril_client:: {
340
+ common:: { Beacon , ProtocolMessagePartKey } ,
341
+ MithrilCertificateMetadata ,
342
+ } ;
343
+
344
+ use super :: * ;
345
+
346
+ fn dummy_certificate ( ) -> MithrilCertificate {
347
+ let mut protocol_message = ProtocolMessage :: new ( ) ;
348
+ protocol_message. set_message_part (
349
+ ProtocolMessagePartKey :: SnapshotDigest ,
350
+ Snapshot :: dummy ( ) . digest . to_string ( ) ,
351
+ ) ;
352
+ protocol_message. set_message_part (
353
+ ProtocolMessagePartKey :: NextAggregateVerificationKey ,
354
+ "whatever" . to_string ( ) ,
355
+ ) ;
356
+ MithrilCertificate {
357
+ hash : "hash" . to_string ( ) ,
358
+ previous_hash : "previous_hash" . to_string ( ) ,
359
+ beacon : Beacon :: new ( "testnet" . to_string ( ) , 10 , 100 ) ,
360
+ metadata : MithrilCertificateMetadata :: dummy ( ) ,
361
+ protocol_message : protocol_message. clone ( ) ,
362
+ signed_message : "signed_message" . to_string ( ) ,
363
+ aggregate_verification_key : String :: new ( ) ,
364
+ multi_signature : String :: new ( ) ,
365
+ genesis_signature : String :: new ( ) ,
366
+ }
367
+ }
368
+
369
+ #[ tokio:: test]
370
+ async fn verify_snapshot_signature_should_remove_db_dir_if_messages_dismatch ( ) {
371
+ let progress_printer = ProgressPrinter :: new ( ProgressOutputType :: Tty , 1 ) ;
372
+ let certificate = dummy_certificate ( ) ;
373
+ let mut message = ProtocolMessage :: new ( ) ;
374
+ message. set_message_part ( ProtocolMessagePartKey :: SnapshotDigest , "digest" . to_string ( ) ) ;
375
+ message. set_message_part (
376
+ ProtocolMessagePartKey :: NextAggregateVerificationKey ,
377
+ "avk" . to_string ( ) ,
378
+ ) ;
379
+ let snapshot = Snapshot :: dummy ( ) ;
380
+ let db_dir = std:: env:: temp_dir ( ) . join ( "db" ) ;
381
+ if db_dir. exists ( ) {
382
+ std:: fs:: remove_dir_all ( & db_dir) . unwrap ( ) ;
383
+ }
384
+ std:: fs:: create_dir_all ( & db_dir) . unwrap ( ) ;
385
+ println ! ( "db_dir: '{:?}'" , db_dir) ;
386
+
387
+ let result = SnapshotDownloadCommand :: verify_snapshot_signature (
388
+ 1 ,
389
+ & progress_printer,
390
+ & certificate,
391
+ & message,
392
+ & snapshot,
393
+ & db_dir,
394
+ )
395
+ . await ;
396
+
397
+ assert ! ( result. is_err( ) ) ;
398
+ assert ! (
399
+ !db_dir. exists( ) ,
400
+ "The db directory should have been removed but it still exists"
401
+ ) ;
402
+ }
403
+ }
0 commit comments