@@ -201,18 +201,18 @@ impl SnapshotService for MithrilClientSnapshotService {
201
201
async fn download (
202
202
& self ,
203
203
snapshot_entity : & SignedEntity < Snapshot > ,
204
- pathdir : & Path ,
204
+ download_dir : & Path ,
205
205
genesis_verification_key : & str ,
206
206
progress_target : ProgressDrawTarget ,
207
207
) -> StdResult < PathBuf > {
208
208
debug ! ( "Snapshot service: download." ) ;
209
209
210
- let unpack_dir = pathdir . join ( "db" ) ;
210
+ let db_dir = download_dir . join ( "db" ) ;
211
211
let progress_bar = MultiProgress :: with_draw_target ( progress_target) ;
212
212
progress_bar. println ( "1/7 - Checking local disk info…" ) ?;
213
213
let unpacker = SnapshotUnpacker ;
214
214
215
- if let Err ( e) = unpacker. check_prerequisites ( & unpack_dir , snapshot_entity. artifact . size ) {
215
+ if let Err ( e) = unpacker. check_prerequisites ( & db_dir , snapshot_entity. artifact . size ) {
216
216
self . check_disk_space_error ( e) ?;
217
217
}
218
218
@@ -237,49 +237,53 @@ impl SnapshotService for MithrilClientSnapshotService {
237
237
. unwrap ( )
238
238
. with_key ( "eta" , |state : & ProgressState , w : & mut dyn Write | write ! ( w, "{:.1}s" , state. eta( ) . as_secs_f64( ) ) . unwrap ( ) )
239
239
. progress_chars ( "#>-" ) ) ;
240
- let filepath = self
240
+ let snapshot_path = self
241
241
. snapshot_client
242
- . download ( & snapshot_entity. artifact , pathdir , pb)
242
+ . download ( & snapshot_entity. artifact , download_dir , pb)
243
243
. await
244
- . map_err ( |e| format ! ( "Could not download file in '{}': {e}" , pathdir. display( ) ) ) ?;
244
+ . map_err ( |e| {
245
+ format ! (
246
+ "Could not download file in '{}': {e}" ,
247
+ download_dir. display( )
248
+ )
249
+ } ) ?;
245
250
246
251
progress_bar. println ( "5/7 - Unpacking the snapshot…" ) ?;
247
- let unpacker = unpacker. unpack_snapshot ( & filepath , & unpack_dir ) ;
252
+ let unpacker = unpacker. unpack_snapshot ( & snapshot_path , & db_dir ) ;
248
253
self . wait_spinner ( & progress_bar, unpacker) . await ?;
249
254
250
255
progress_bar. println ( "6/7 - Computing the snapshot digest…" ) ?;
251
256
let unpacked_snapshot_digest = self
252
257
. immutable_digester
253
- . compute_digest ( & unpack_dir , & certificate. beacon )
258
+ . compute_digest ( & db_dir , & certificate. beacon )
254
259
. await
255
- . map_err ( |e| {
256
- format ! (
257
- "Could not compute digest in '{}': {e}" ,
258
- unpack_dir. display( )
259
- )
260
- } ) ?;
260
+ . map_err ( |e| format ! ( "Could not compute digest in '{}': {e}" , db_dir. display( ) ) ) ?;
261
261
262
262
progress_bar. println ( "7/7 - Verifying the snapshot signature…" ) ?;
263
- let mut protocol_message = certificate. protocol_message . clone ( ) ;
264
- protocol_message. set_message_part (
265
- ProtocolMessagePartKey :: SnapshotDigest ,
266
- unpacked_snapshot_digest,
267
- ) ;
268
- if protocol_message. compute_hash ( ) != certificate. signed_message {
263
+ let expected_message = {
264
+ let mut protocol_message = certificate. protocol_message . clone ( ) ;
265
+ protocol_message. set_message_part (
266
+ ProtocolMessagePartKey :: SnapshotDigest ,
267
+ unpacked_snapshot_digest,
268
+ ) ;
269
+ protocol_message. compute_hash ( )
270
+ } ;
271
+
272
+ if expected_message != certificate. signed_message {
269
273
debug ! ( "Digest verification failed, removing unpacked files & directory." ) ;
270
274
271
- if let Err ( e ) = std:: fs:: remove_dir_all ( & unpack_dir ) {
272
- warn ! ( "Error while removing unpacked files & directory: {e }." ) ;
275
+ if let Err ( error ) = std:: fs:: remove_dir_all ( & db_dir ) {
276
+ warn ! ( "Error while removing unpacked files & directory: {error }." ) ;
273
277
}
274
278
275
279
return Err ( SnapshotServiceError :: CouldNotVerifySnapshot {
276
280
digest : snapshot_entity. artifact . digest . clone ( ) ,
277
- path : filepath . canonicalize ( ) . unwrap ( ) ,
281
+ path : snapshot_path . canonicalize ( ) . unwrap ( ) ,
278
282
}
279
283
. into ( ) ) ;
280
284
}
281
285
282
- Ok ( unpack_dir )
286
+ Ok ( db_dir )
283
287
}
284
288
}
285
289
@@ -298,34 +302,44 @@ mod tests {
298
302
test_utils:: fake_data,
299
303
} ;
300
304
use std:: {
305
+ ffi:: OsStr ,
301
306
fs:: { create_dir_all, File } ,
302
307
io:: Write ,
303
308
} ;
304
309
305
310
use crate :: {
306
311
aggregator_client:: { AggregatorClient , MockAggregatorHTTPClient } ,
307
312
dependencies:: DependenciesBuilder ,
313
+ services:: mock:: * ,
308
314
FromSnapshotMessageAdapter ,
309
315
} ;
310
316
311
- use super :: super :: mock:: * ;
312
-
313
317
use super :: * ;
314
318
315
319
/// see [`archive_file_path`] to see where the dummy will be created
316
320
fn build_dummy_snapshot ( digest : & str , data_expected : & str , test_dir : & Path ) {
321
+ // create a fake file to archive
322
+ let data_file_path = {
323
+ let data_file_path = test_dir. join ( "db" ) . join ( "test_data.txt" ) ;
324
+ create_dir_all ( data_file_path. parent ( ) . unwrap ( ) ) . unwrap ( ) ;
325
+
326
+ let mut source_file = File :: create ( data_file_path. as_path ( ) ) . unwrap ( ) ;
327
+ write ! ( source_file, "{data_expected}" ) . unwrap ( ) ;
328
+
329
+ data_file_path
330
+ } ;
331
+
332
+ // create the archive
317
333
let archive_file_path = test_dir. join ( format ! ( "snapshot-{digest}" ) ) ;
318
- let data_file_path = test_dir. join ( Path :: new ( "db/test_data.txt" ) ) ;
319
- create_dir_all ( data_file_path. parent ( ) . unwrap ( ) ) . unwrap ( ) ;
320
- let mut source_file = File :: create ( data_file_path. as_path ( ) ) . unwrap ( ) ;
321
- write ! ( source_file, "{data_expected}" ) . unwrap ( ) ;
322
334
let archive_file = File :: create ( archive_file_path) . unwrap ( ) ;
323
335
let archive_encoder = GzEncoder :: new ( & archive_file, Compression :: default ( ) ) ;
324
336
let mut archive_builder = tar:: Builder :: new ( archive_encoder) ;
325
337
archive_builder
326
338
. append_dir_all ( "." , data_file_path. parent ( ) . unwrap ( ) )
327
339
. unwrap ( ) ;
328
340
archive_builder. into_inner ( ) . unwrap ( ) . finish ( ) . unwrap ( ) ;
341
+
342
+ // remove the fake file
329
343
let _ = std:: fs:: remove_dir_all ( data_file_path. parent ( ) . unwrap ( ) ) ;
330
344
}
331
345
@@ -361,6 +375,36 @@ mod tests {
361
375
}
362
376
}
363
377
378
+ fn get_mocks_for_snapshot_service_configured_to_make_download_succeed ( ) -> (
379
+ MockAggregatorHTTPClient ,
380
+ MockCertificateVerifierImpl ,
381
+ DumbImmutableDigester ,
382
+ ) {
383
+ let mut http_client = MockAggregatorHTTPClient :: new ( ) ;
384
+ http_client. expect_probe ( ) . returning ( |_| Ok ( ( ) ) ) ;
385
+ http_client
386
+ . expect_download ( )
387
+ . returning ( move |_, _, _| Ok ( ( ) ) )
388
+ . times ( 1 ) ;
389
+ http_client. expect_get_content ( ) . returning ( |_| {
390
+ let mut message = CertificateMessage :: dummy ( ) ;
391
+ message. signed_message = message. protocol_message . compute_hash ( ) ;
392
+ let message = serde_json:: to_string ( & message) . unwrap ( ) ;
393
+
394
+ Ok ( message)
395
+ } ) ;
396
+
397
+ let mut certificate_verifier = MockCertificateVerifierImpl :: new ( ) ;
398
+ certificate_verifier
399
+ . expect_verify_certificate_chain ( )
400
+ . returning ( |_, _, _| Ok ( ( ) ) )
401
+ . times ( 1 ) ;
402
+
403
+ let dumb_digester = DumbImmutableDigester :: new ( "snapshot-digest-123" , true ) ;
404
+
405
+ ( http_client, certificate_verifier, dumb_digester)
406
+ }
407
+
364
408
fn get_dep_builder ( http_client : Arc < dyn AggregatorClient > ) -> DependenciesBuilder {
365
409
let config_builder: ConfigBuilder < DefaultState > = ConfigBuilder :: default ( ) ;
366
410
let config = config_builder
@@ -472,40 +516,25 @@ mod tests {
472
516
async fn test_download_snapshot_ok ( ) {
473
517
let test_path = std:: env:: temp_dir ( ) . join ( "test_download_snapshot_ok" ) ;
474
518
let _ = std:: fs:: remove_dir_all ( & test_path) ;
475
- let mut http_client = MockAggregatorHTTPClient :: new ( ) ;
476
- http_client. expect_probe ( ) . returning ( |_| Ok ( ( ) ) ) ;
477
- http_client
478
- . expect_download ( )
479
- . returning ( move |_, _, _| Ok ( ( ) ) )
480
- . times ( 1 ) ;
481
- http_client. expect_get_content ( ) . returning ( |_| {
482
- let mut message = CertificateMessage :: dummy ( ) ;
483
- message. signed_message = message. protocol_message . compute_hash ( ) ;
484
- let message = serde_json:: to_string ( & message) . unwrap ( ) ;
485
519
486
- Ok ( message)
487
- } ) ;
520
+ let ( http_client, certificate_verifier, digester) =
521
+ get_mocks_for_snapshot_service_configured_to_make_download_succeed ( ) ;
522
+
488
523
let mut builder = get_dep_builder ( Arc :: new ( http_client) ) ;
489
- let mut certificate_verifier = MockCertificateVerifierImpl :: new ( ) ;
490
- certificate_verifier
491
- . expect_verify_certificate_chain ( )
492
- . returning ( |_, _, _| Ok ( ( ) ) )
493
- . times ( 1 ) ;
494
524
builder. certificate_verifier = Some ( Arc :: new ( certificate_verifier) ) ;
495
- builder. immutable_digester = Some ( Arc :: new ( DumbImmutableDigester :: new (
496
- "snapshot-digest-123" ,
497
- true ,
498
- ) ) ) ;
499
- let snapshot = FromSnapshotMessageAdapter :: adapt ( get_snapshot_message ( ) ) ;
525
+ builder. immutable_digester = Some ( Arc :: new ( digester) ) ;
500
526
let snapshot_service = builder. get_snapshot_service ( ) . await . unwrap ( ) ;
501
527
502
- let ( _, verifier) = setup_genesis ( ) ;
503
- let genesis_verification_key = verifier. to_verification_key ( ) ;
528
+ let snapshot = FromSnapshotMessageAdapter :: adapt ( get_snapshot_message ( ) ) ;
504
529
build_dummy_snapshot (
505
530
"digest-10.tar.gz" ,
506
531
"1234567890" . repeat ( 124 ) . as_str ( ) ,
507
532
& test_path,
508
533
) ;
534
+
535
+ let ( _, verifier) = setup_genesis ( ) ;
536
+ let genesis_verification_key = verifier. to_verification_key ( ) ;
537
+
509
538
let filepath = snapshot_service
510
539
. download (
511
540
& snapshot,
@@ -515,42 +544,27 @@ mod tests {
515
544
)
516
545
. await
517
546
. expect ( "Snapshot download should succeed." ) ;
518
- assert ! ( filepath. exists( ) ) ;
519
- let unpack_dir = filepath
520
- . parent ( )
521
- . expect ( "Test downloaded file must be in a directory." )
522
- . join ( "db" ) ;
523
- assert ! ( unpack_dir. is_dir( ) ) ;
547
+ assert ! (
548
+ filepath. is_dir( ) ,
549
+ "Unpacked location must be in a directory."
550
+ ) ;
551
+ assert_eq ! ( Some ( OsStr :: new( "db" ) ) , filepath. file_name( ) ) ;
524
552
}
525
553
526
554
#[ tokio:: test]
527
555
async fn test_download_snapshot_invalid_digest ( ) {
528
556
let test_path = std:: env:: temp_dir ( ) . join ( "test_download_snapshot_invalid_digest" ) ;
529
557
let _ = std:: fs:: remove_dir_all ( & test_path) ;
530
- let mut http_client = MockAggregatorHTTPClient :: new ( ) ;
531
- http_client. expect_probe ( ) . returning ( |_| Ok ( ( ) ) ) ;
532
- http_client
533
- . expect_download ( )
534
- . returning ( move |_, _, _| Ok ( ( ) ) )
535
- . times ( 1 ) ;
536
- http_client. expect_get_content ( ) . returning ( |_| {
537
- let mut message = CertificateMessage :: dummy ( ) ;
538
- message. signed_message = message. protocol_message . compute_hash ( ) ;
539
- let message = serde_json:: to_string ( & message) . unwrap ( ) ;
540
558
541
- Ok ( message)
542
- } ) ;
543
- let http_client = Arc :: new ( http_client) ;
544
- let mut dep_builder = get_dep_builder ( http_client) ;
545
- let mut certificate_verifier = MockCertificateVerifierImpl :: new ( ) ;
546
- certificate_verifier
547
- . expect_verify_certificate_chain ( )
548
- . returning ( |_, _, _| Ok ( ( ) ) )
549
- . times ( 1 ) ;
559
+ let ( http_client, certificate_verifier, _) =
560
+ get_mocks_for_snapshot_service_configured_to_make_download_succeed ( ) ;
550
561
let immutable_digester = DumbImmutableDigester :: new ( "snapshot-digest-KO" , true ) ;
562
+
563
+ let mut dep_builder = get_dep_builder ( Arc :: new ( http_client) ) ;
551
564
dep_builder. certificate_verifier = Some ( Arc :: new ( certificate_verifier) ) ;
552
565
dep_builder. immutable_digester = Some ( Arc :: new ( immutable_digester) ) ;
553
566
let snapshot_service = dep_builder. get_snapshot_service ( ) . await . unwrap ( ) ;
567
+
554
568
let mut signed_entity = FromSnapshotMessageAdapter :: adapt ( get_snapshot_message ( ) ) ;
555
569
signed_entity. artifact . digest = "digest-10" . to_string ( ) ;
556
570
@@ -561,6 +575,7 @@ mod tests {
561
575
"1234567890" . repeat ( 124 ) . as_str ( ) ,
562
576
& test_path,
563
577
) ;
578
+
564
579
let err = snapshot_service
565
580
. download (
566
581
& signed_entity,
@@ -600,14 +615,15 @@ mod tests {
600
615
let test_path = std:: env:: temp_dir ( ) . join ( "test_download_snapshot_dir_already_exists" ) ;
601
616
let _ = std:: fs:: remove_dir_all ( & test_path) ;
602
617
create_dir_all ( test_path. join ( "db" ) ) . unwrap ( ) ;
618
+
603
619
let http_client = MockAggregatorHTTPClient :: new ( ) ;
604
- let http_client = Arc :: new ( http_client) ;
605
- let mut dep_builder = get_dep_builder ( http_client) ;
620
+ let mut dep_builder = get_dep_builder ( Arc :: new ( http_client) ) ;
606
621
let snapshot_service = dep_builder. get_snapshot_service ( ) . await . unwrap ( ) ;
607
622
608
623
let ( _, verifier) = setup_genesis ( ) ;
609
624
let genesis_verification_key = verifier. to_verification_key ( ) ;
610
625
let snapshot = FromSnapshotMessageAdapter :: adapt ( get_snapshot_message ( ) ) ;
626
+
611
627
let err = snapshot_service
612
628
. download (
613
629
& snapshot,
0 commit comments