Skip to content

Commit 7b65665

Browse files
committed
Tidy up client snapshot service code & tests
1 parent cd56f53 commit 7b65665

File tree

1 file changed

+97
-81
lines changed

1 file changed

+97
-81
lines changed

mithril-client/src/services/snapshot.rs

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,18 @@ impl SnapshotService for MithrilClientSnapshotService {
201201
async fn download(
202202
&self,
203203
snapshot_entity: &SignedEntity<Snapshot>,
204-
pathdir: &Path,
204+
download_dir: &Path,
205205
genesis_verification_key: &str,
206206
progress_target: ProgressDrawTarget,
207207
) -> StdResult<PathBuf> {
208208
debug!("Snapshot service: download.");
209209

210-
let unpack_dir = pathdir.join("db");
210+
let db_dir = download_dir.join("db");
211211
let progress_bar = MultiProgress::with_draw_target(progress_target);
212212
progress_bar.println("1/7 - Checking local disk info…")?;
213213
let unpacker = SnapshotUnpacker;
214214

215-
if let Err(e) = unpacker.check_prerequisites(&unpack_dir, snapshot_entity.artifact.size) {
215+
if let Err(e) = unpacker.check_prerequisites(&db_dir, snapshot_entity.artifact.size) {
216216
self.check_disk_space_error(e)?;
217217
}
218218

@@ -237,49 +237,53 @@ impl SnapshotService for MithrilClientSnapshotService {
237237
.unwrap()
238238
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
239239
.progress_chars("#>-"));
240-
let filepath = self
240+
let snapshot_path = self
241241
.snapshot_client
242-
.download(&snapshot_entity.artifact, pathdir, pb)
242+
.download(&snapshot_entity.artifact, download_dir, pb)
243243
.await
244-
.map_err(|e| format!("Could not download file in '{}': {e}", pathdir.display()))?;
244+
.map_err(|e| {
245+
format!(
246+
"Could not download file in '{}': {e}",
247+
download_dir.display()
248+
)
249+
})?;
245250

246251
progress_bar.println("5/7 - Unpacking the snapshot…")?;
247-
let unpacker = unpacker.unpack_snapshot(&filepath, &unpack_dir);
252+
let unpacker = unpacker.unpack_snapshot(&snapshot_path, &db_dir);
248253
self.wait_spinner(&progress_bar, unpacker).await?;
249254

250255
progress_bar.println("6/7 - Computing the snapshot digest…")?;
251256
let unpacked_snapshot_digest = self
252257
.immutable_digester
253-
.compute_digest(&unpack_dir, &certificate.beacon)
258+
.compute_digest(&db_dir, &certificate.beacon)
254259
.await
255-
.map_err(|e| {
256-
format!(
257-
"Could not compute digest in '{}': {e}",
258-
unpack_dir.display()
259-
)
260-
})?;
260+
.map_err(|e| format!("Could not compute digest in '{}': {e}", db_dir.display()))?;
261261

262262
progress_bar.println("7/7 - Verifying the snapshot signature…")?;
263-
let mut protocol_message = certificate.protocol_message.clone();
264-
protocol_message.set_message_part(
265-
ProtocolMessagePartKey::SnapshotDigest,
266-
unpacked_snapshot_digest,
267-
);
268-
if protocol_message.compute_hash() != certificate.signed_message {
263+
let expected_message = {
264+
let mut protocol_message = certificate.protocol_message.clone();
265+
protocol_message.set_message_part(
266+
ProtocolMessagePartKey::SnapshotDigest,
267+
unpacked_snapshot_digest,
268+
);
269+
protocol_message.compute_hash()
270+
};
271+
272+
if expected_message != certificate.signed_message {
269273
debug!("Digest verification failed, removing unpacked files & directory.");
270274

271-
if let Err(e) = std::fs::remove_dir_all(&unpack_dir) {
272-
warn!("Error while removing unpacked files & directory: {e}.");
275+
if let Err(error) = std::fs::remove_dir_all(&db_dir) {
276+
warn!("Error while removing unpacked files & directory: {error}.");
273277
}
274278

275279
return Err(SnapshotServiceError::CouldNotVerifySnapshot {
276280
digest: snapshot_entity.artifact.digest.clone(),
277-
path: filepath.canonicalize().unwrap(),
281+
path: snapshot_path.canonicalize().unwrap(),
278282
}
279283
.into());
280284
}
281285

282-
Ok(unpack_dir)
286+
Ok(db_dir)
283287
}
284288
}
285289

@@ -298,34 +302,44 @@ mod tests {
298302
test_utils::fake_data,
299303
};
300304
use std::{
305+
ffi::OsStr,
301306
fs::{create_dir_all, File},
302307
io::Write,
303308
};
304309

305310
use crate::{
306311
aggregator_client::{AggregatorClient, MockAggregatorHTTPClient},
307312
dependencies::DependenciesBuilder,
313+
services::mock::*,
308314
FromSnapshotMessageAdapter,
309315
};
310316

311-
use super::super::mock::*;
312-
313317
use super::*;
314318

315319
/// see [`archive_file_path`] to see where the dummy will be created
316320
fn build_dummy_snapshot(digest: &str, data_expected: &str, test_dir: &Path) {
321+
// create a fake file to archive
322+
let data_file_path = {
323+
let data_file_path = test_dir.join("db").join("test_data.txt");
324+
create_dir_all(data_file_path.parent().unwrap()).unwrap();
325+
326+
let mut source_file = File::create(data_file_path.as_path()).unwrap();
327+
write!(source_file, "{data_expected}").unwrap();
328+
329+
data_file_path
330+
};
331+
332+
// create the archive
317333
let archive_file_path = test_dir.join(format!("snapshot-{digest}"));
318-
let data_file_path = test_dir.join(Path::new("db/test_data.txt"));
319-
create_dir_all(data_file_path.parent().unwrap()).unwrap();
320-
let mut source_file = File::create(data_file_path.as_path()).unwrap();
321-
write!(source_file, "{data_expected}").unwrap();
322334
let archive_file = File::create(archive_file_path).unwrap();
323335
let archive_encoder = GzEncoder::new(&archive_file, Compression::default());
324336
let mut archive_builder = tar::Builder::new(archive_encoder);
325337
archive_builder
326338
.append_dir_all(".", data_file_path.parent().unwrap())
327339
.unwrap();
328340
archive_builder.into_inner().unwrap().finish().unwrap();
341+
342+
// remove the fake file
329343
let _ = std::fs::remove_dir_all(data_file_path.parent().unwrap());
330344
}
331345

@@ -361,6 +375,36 @@ mod tests {
361375
}
362376
}
363377

378+
fn get_mocks_for_snapshot_service_configured_to_make_download_succeed() -> (
379+
MockAggregatorHTTPClient,
380+
MockCertificateVerifierImpl,
381+
DumbImmutableDigester,
382+
) {
383+
let mut http_client = MockAggregatorHTTPClient::new();
384+
http_client.expect_probe().returning(|_| Ok(()));
385+
http_client
386+
.expect_download()
387+
.returning(move |_, _, _| Ok(()))
388+
.times(1);
389+
http_client.expect_get_content().returning(|_| {
390+
let mut message = CertificateMessage::dummy();
391+
message.signed_message = message.protocol_message.compute_hash();
392+
let message = serde_json::to_string(&message).unwrap();
393+
394+
Ok(message)
395+
});
396+
397+
let mut certificate_verifier = MockCertificateVerifierImpl::new();
398+
certificate_verifier
399+
.expect_verify_certificate_chain()
400+
.returning(|_, _, _| Ok(()))
401+
.times(1);
402+
403+
let dumb_digester = DumbImmutableDigester::new("snapshot-digest-123", true);
404+
405+
(http_client, certificate_verifier, dumb_digester)
406+
}
407+
364408
fn get_dep_builder(http_client: Arc<dyn AggregatorClient>) -> DependenciesBuilder {
365409
let config_builder: ConfigBuilder<DefaultState> = ConfigBuilder::default();
366410
let config = config_builder
@@ -472,40 +516,25 @@ mod tests {
472516
async fn test_download_snapshot_ok() {
473517
let test_path = std::env::temp_dir().join("test_download_snapshot_ok");
474518
let _ = std::fs::remove_dir_all(&test_path);
475-
let mut http_client = MockAggregatorHTTPClient::new();
476-
http_client.expect_probe().returning(|_| Ok(()));
477-
http_client
478-
.expect_download()
479-
.returning(move |_, _, _| Ok(()))
480-
.times(1);
481-
http_client.expect_get_content().returning(|_| {
482-
let mut message = CertificateMessage::dummy();
483-
message.signed_message = message.protocol_message.compute_hash();
484-
let message = serde_json::to_string(&message).unwrap();
485519

486-
Ok(message)
487-
});
520+
let (http_client, certificate_verifier, digester) =
521+
get_mocks_for_snapshot_service_configured_to_make_download_succeed();
522+
488523
let mut builder = get_dep_builder(Arc::new(http_client));
489-
let mut certificate_verifier = MockCertificateVerifierImpl::new();
490-
certificate_verifier
491-
.expect_verify_certificate_chain()
492-
.returning(|_, _, _| Ok(()))
493-
.times(1);
494524
builder.certificate_verifier = Some(Arc::new(certificate_verifier));
495-
builder.immutable_digester = Some(Arc::new(DumbImmutableDigester::new(
496-
"snapshot-digest-123",
497-
true,
498-
)));
499-
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
525+
builder.immutable_digester = Some(Arc::new(digester));
500526
let snapshot_service = builder.get_snapshot_service().await.unwrap();
501527

502-
let (_, verifier) = setup_genesis();
503-
let genesis_verification_key = verifier.to_verification_key();
528+
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
504529
build_dummy_snapshot(
505530
"digest-10.tar.gz",
506531
"1234567890".repeat(124).as_str(),
507532
&test_path,
508533
);
534+
535+
let (_, verifier) = setup_genesis();
536+
let genesis_verification_key = verifier.to_verification_key();
537+
509538
let filepath = snapshot_service
510539
.download(
511540
&snapshot,
@@ -515,42 +544,27 @@ mod tests {
515544
)
516545
.await
517546
.expect("Snapshot download should succeed.");
518-
assert!(filepath.exists());
519-
let unpack_dir = filepath
520-
.parent()
521-
.expect("Test downloaded file must be in a directory.")
522-
.join("db");
523-
assert!(unpack_dir.is_dir());
547+
assert!(
548+
filepath.is_dir(),
549+
"Unpacked location must be in a directory."
550+
);
551+
assert_eq!(Some(OsStr::new("db")), filepath.file_name());
524552
}
525553

526554
#[tokio::test]
527555
async fn test_download_snapshot_invalid_digest() {
528556
let test_path = std::env::temp_dir().join("test_download_snapshot_invalid_digest");
529557
let _ = std::fs::remove_dir_all(&test_path);
530-
let mut http_client = MockAggregatorHTTPClient::new();
531-
http_client.expect_probe().returning(|_| Ok(()));
532-
http_client
533-
.expect_download()
534-
.returning(move |_, _, _| Ok(()))
535-
.times(1);
536-
http_client.expect_get_content().returning(|_| {
537-
let mut message = CertificateMessage::dummy();
538-
message.signed_message = message.protocol_message.compute_hash();
539-
let message = serde_json::to_string(&message).unwrap();
540558

541-
Ok(message)
542-
});
543-
let http_client = Arc::new(http_client);
544-
let mut dep_builder = get_dep_builder(http_client);
545-
let mut certificate_verifier = MockCertificateVerifierImpl::new();
546-
certificate_verifier
547-
.expect_verify_certificate_chain()
548-
.returning(|_, _, _| Ok(()))
549-
.times(1);
559+
let (http_client, certificate_verifier, _) =
560+
get_mocks_for_snapshot_service_configured_to_make_download_succeed();
550561
let immutable_digester = DumbImmutableDigester::new("snapshot-digest-KO", true);
562+
563+
let mut dep_builder = get_dep_builder(Arc::new(http_client));
551564
dep_builder.certificate_verifier = Some(Arc::new(certificate_verifier));
552565
dep_builder.immutable_digester = Some(Arc::new(immutable_digester));
553566
let snapshot_service = dep_builder.get_snapshot_service().await.unwrap();
567+
554568
let mut signed_entity = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
555569
signed_entity.artifact.digest = "digest-10".to_string();
556570

@@ -561,6 +575,7 @@ mod tests {
561575
"1234567890".repeat(124).as_str(),
562576
&test_path,
563577
);
578+
564579
let err = snapshot_service
565580
.download(
566581
&signed_entity,
@@ -600,14 +615,15 @@ mod tests {
600615
let test_path = std::env::temp_dir().join("test_download_snapshot_dir_already_exists");
601616
let _ = std::fs::remove_dir_all(&test_path);
602617
create_dir_all(test_path.join("db")).unwrap();
618+
603619
let http_client = MockAggregatorHTTPClient::new();
604-
let http_client = Arc::new(http_client);
605-
let mut dep_builder = get_dep_builder(http_client);
620+
let mut dep_builder = get_dep_builder(Arc::new(http_client));
606621
let snapshot_service = dep_builder.get_snapshot_service().await.unwrap();
607622

608623
let (_, verifier) = setup_genesis();
609624
let genesis_verification_key = verifier.to_verification_key();
610625
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
626+
611627
let err = snapshot_service
612628
.download(
613629
&snapshot,

0 commit comments

Comments
 (0)