diff --git a/.gitignore b/.gitignore index 1b72444a..e5acaa84 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /Cargo.lock /target +*.tar diff --git a/Cargo.toml b/Cargo.toml index d5680c2d..e2ede780 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,13 @@ anyhow = { version = "1.0.97", default-features = false } async-compression = { version = "0.4.22", default-features = false, features = ["tokio", "zstd", "gzip"] } clap = { version = "4.5.32", default-features = false, features = ["std", "help", "usage", "derive"] } containers-image-proxy = "0.7.0" +crossbeam = "0.8.4" env_logger = "0.11.7" hex = "0.4.3" indicatif = { version = "0.17.11", features = ["tokio"] } log = "0.4.27" oci-spec = "0.7.1" +rayon = "1.10.0" regex-automata = { version = "0.4.9", default-features = false } rustix = { version = "1.0.3", features = ["fs", "mount", "process"] } serde = "1.0.219" diff --git a/src/lib.rs b/src/lib.rs index 9edd9a12..4e1d9345 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub mod repository; pub mod selabel; pub mod splitstream; pub mod util; +pub mod zstd_encoder; /// All files that contain 64 or fewer bytes (size <= INLINE_CONTENT_MAX) should be stored inline /// in the erofs image (and also in splitstreams). 
All files with 65 or more bytes (size > MAX) diff --git a/src/oci/mod.rs b/src/oci/mod.rs index 8bc8fdc2..8d36d359 100644 --- a/src/oci/mod.rs +++ b/src/oci/mod.rs @@ -18,8 +18,12 @@ use crate::{ fsverity::Sha256HashValue, oci::tar::{get_entry, split_async}, repository::Repository, - splitstream::DigestMap, + splitstream::{ + handle_external_object, DigestMap, EnsureObjectMessages, ResultChannelReceiver, + ResultChannelSender, WriterMessages, + }, util::parse_sha256, + zstd_encoder, }; pub fn import_layer( @@ -83,6 +87,7 @@ impl<'repo> ImageOp<'repo> { let proxy = containers_image_proxy::ImageProxy::new_with_config(config).await?; let img = proxy.open_image(imgref).await.context("Opening image")?; let progress = MultiProgress::new(); + Ok(ImageOp { repo, proxy, @@ -95,47 +100,49 @@ impl<'repo> ImageOp<'repo> { &self, layer_sha256: &Sha256HashValue, descriptor: &Descriptor, - ) -> Result { + layer_num: usize, + object_sender: crossbeam::channel::Sender, + ) -> Result<()> { // We need to use the per_manifest descriptor to download the compressed layer but it gets // stored in the repository via the per_config descriptor. Our return value is the // fsverity digest for the corresponding splitstream. - if let Some(layer_id) = self.repo.check_stream(layer_sha256)? { - self.progress - .println(format!("Already have layer {}", hex::encode(layer_sha256)))?; - Ok(layer_id) - } else { - // Otherwise, we need to fetch it... 
- let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?; - - // See https://github.com/containers/containers-image-proxy-rs/issues/71 - let blob_reader = blob_reader.take(descriptor.size()); - - let bar = self.progress.add(ProgressBar::new(descriptor.size())); - bar.set_style(ProgressStyle::with_template("[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}") - .unwrap() - .progress_chars("##-")); - let progress = bar.wrap_async_read(blob_reader); - self.progress - .println(format!("Fetching layer {}", hex::encode(layer_sha256)))?; + // Otherwise, we need to fetch it... + let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?; + + // See https://github.com/containers/containers-image-proxy-rs/issues/71 + let blob_reader = blob_reader.take(descriptor.size()); + + let bar = self.progress.add(ProgressBar::new(descriptor.size())); + bar.set_style( + ProgressStyle::with_template( + "[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}", + ) + .unwrap() + .progress_chars("##-"), + ); + let progress = bar.wrap_async_read(blob_reader); + self.progress + .println(format!("Fetching layer {}", hex::encode(layer_sha256)))?; + + let mut splitstream = + self.repo + .create_stream(Some(*layer_sha256), None, Some(object_sender)); + match descriptor.media_type() { + MediaType::ImageLayer => { + split_async(progress, &mut splitstream, layer_num).await?; + } + MediaType::ImageLayerGzip => { + split_async(GzipDecoder::new(progress), &mut splitstream, layer_num).await?; + } + MediaType::ImageLayerZstd => { + split_async(ZstdDecoder::new(progress), &mut splitstream, layer_num).await?; + } + other => bail!("Unsupported layer media type {:?}", other), + }; + driver.await?; - let mut splitstream = self.repo.create_stream(Some(*layer_sha256), None); - match descriptor.media_type() { - MediaType::ImageLayer => { - split_async(progress, &mut splitstream).await?; - } - 
MediaType::ImageLayerGzip => { - split_async(GzipDecoder::new(progress), &mut splitstream).await?; - } - MediaType::ImageLayerZstd => { - split_async(ZstdDecoder::new(progress), &mut splitstream).await?; - } - other => bail!("Unsupported layer media type {:?}", other), - }; - let layer_id = self.repo.write_stream(splitstream, None)?; - driver.await?; - Ok(layer_id) - } + Ok(()) } pub async fn ensure_config( @@ -154,7 +161,6 @@ impl<'repo> ImageOp<'repo> { } else { // We need to add the config to the repo. We need to parse the config and make sure we // have all of the layers first. - // self.progress .println(format!("Fetching config {}", hex::encode(config_sha256)))?; @@ -169,19 +175,40 @@ impl<'repo> ImageOp<'repo> { let raw_config = config?; let config = ImageConfiguration::from_reader(&raw_config[..])?; + let (done_chan_sender, done_chan_recver, object_sender) = + self.spawn_threads(&config)?; + let mut config_maps = DigestMap::new(); + + let mut idx = 0; + for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) { let layer_sha256 = sha256_from_digest(cld)?; - let layer_id = self - .ensure_layer(&layer_sha256, mld) - .await - .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?; + + if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? 
{ + self.progress + .println(format!("Already have layer {}", hex::encode(layer_sha256)))?; + + config_maps.insert(&layer_sha256, &layer_id); + } else { + self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone()) + .await + .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?; + + idx += 1; + } + } + + drop(done_chan_sender); + + while let Ok(res) = done_chan_recver.recv() { + let (layer_sha256, layer_id) = res?; config_maps.insert(&layer_sha256, &layer_id); } - let mut splitstream = self - .repo - .create_stream(Some(config_sha256), Some(config_maps)); + let mut splitstream = + self.repo + .create_stream(Some(config_sha256), Some(config_maps), None); splitstream.write_inline(&raw_config); let config_id = self.repo.write_stream(splitstream, None)?; @@ -189,6 +216,126 @@ impl<'repo> ImageOp<'repo> { } } + fn spawn_threads( + &self, + config: &ImageConfiguration, + ) -> Result<( + ResultChannelSender, + ResultChannelReceiver, + crossbeam::channel::Sender, + )> { + use crossbeam::channel::{unbounded, Receiver, Sender}; + + let mut encoder_threads = 2; + let external_object_writer_threads = 4; + + let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads); + + // Divide the layers into chunks of some specific size so each worker + // thread can work on multiple deterministic layers + let diff_ids: Vec = config + .rootfs() + .diff_ids() + .iter() + .map(|x| sha256_from_digest(x)) + .collect::, _>>()?; + + let mut unhandled_layers = vec![]; + + // This becomes pretty unreadable with a filter,map chain + for id in diff_ids { + let layer_exists = self.repo.check_stream(&id)?; + + if layer_exists.is_none() { + unhandled_layers.push(id); + } + } + + let mut chunks: Vec> = unhandled_layers + .chunks(chunk_len) + .map(|x| x.to_vec()) + .collect(); + + // Mapping from layer_id -> index in writer_channels + // This is to make sure that all messages relating to a particular layer + // always reach the same writer + let layers_to_chunks = 
chunks + .iter() + .enumerate() + .flat_map(|(i, chunk)| std::iter::repeat_n(i, chunk.len()).collect::>()) + .collect::>(); + + encoder_threads = encoder_threads.min(chunks.len()); + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(encoder_threads + external_object_writer_threads) + .build() + .unwrap(); + + // We need this as writers have internal state that can't be shared between threads + // + // We'll actually need as many writers (not writer threads, but writer instances) as there are layers. + let zstd_writer_channels: Vec<(Sender, Receiver)> = + (0..encoder_threads).map(|_| unbounded()).collect(); + + let (object_sender, object_receiver) = unbounded::(); + + // (layer_sha256, layer_id) + let (done_chan_sender, done_chan_recver) = + std::sync::mpsc::channel::>(); + + for i in 0..encoder_threads { + let repository = self.repo.try_clone().unwrap(); + let object_sender = object_sender.clone(); + let done_chan_sender = done_chan_sender.clone(); + let chunk = std::mem::take(&mut chunks[i]); + let receiver = zstd_writer_channels[i].1.clone(); + + pool.spawn({ + move || { + let start = i * (chunk_len); + let end = start + chunk_len; + + let enc = zstd_encoder::MultipleZstdWriters::new( + chunk, + repository, + object_sender, + done_chan_sender, + ); + + if let Err(e) = enc.recv_data(receiver, start, end) { + eprintln!("zstd_encoder returned with error: {}", e) + } + } + }); + } + + for _ in 0..external_object_writer_threads { + pool.spawn({ + let repository = self.repo.try_clone().unwrap(); + let zstd_writer_channels = zstd_writer_channels + .iter() + .map(|(s, _)| s.clone()) + .collect::>(); + let layers_to_chunks = layers_to_chunks.clone(); + let external_object_receiver = object_receiver.clone(); + + move || { + if let Err(e) = handle_external_object( + repository, + external_object_receiver, + zstd_writer_channels, + layers_to_chunks, + ) { + eprintln!("handle_external_object returned with error: {}", e); + } + } + }); + } + + Ok((done_chan_sender, 
done_chan_recver, object_sender)) + } + pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> { let (_manifest_digest, raw_manifest) = self .proxy @@ -201,6 +348,7 @@ impl<'repo> ImageOp<'repo> { let manifest = ImageManifest::from_reader(raw_manifest.as_slice())?; let config_descriptor = manifest.config(); let layers = manifest.layers(); + self.ensure_config(layers, config_descriptor) .await .with_context(|| format!("Failed to pull config {config_descriptor:?}")) @@ -280,7 +428,7 @@ pub fn write_config( let json = config.to_string()?; let json_bytes = json.as_bytes(); let sha256 = hash(json_bytes); - let mut stream = repo.create_stream(Some(sha256), Some(refs)); + let mut stream = repo.create_stream(Some(sha256), Some(refs), None); stream.write_inline(json_bytes); let id = repo.write_stream(stream, None)?; Ok((sha256, id)) diff --git a/src/oci/tar.rs b/src/oci/tar.rs index 718a1670..ea118c6e 100644 --- a/src/oci/tar.rs +++ b/src/oci/tar.rs @@ -8,7 +8,7 @@ use std::{ path::PathBuf, }; -use anyhow::{bail, ensure, Result}; +use anyhow::{bail, ensure, Context, Result}; use rustix::fs::makedev; use tar::{EntryType, Header, PaxExtensions}; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -16,7 +16,9 @@ use tokio::io::{AsyncRead, AsyncReadExt}; use crate::{ dumpfile, image::{LeafContent, RegularFile, Stat}, - splitstream::{SplitStreamData, SplitStreamReader, SplitStreamWriter}, + splitstream::{ + EnsureObjectMessages, FinishMessage, SplitStreamData, SplitStreamReader, SplitStreamWriter, + }, util::{read_exactish, read_exactish_async}, INLINE_CONTENT_MAX, }; @@ -60,7 +62,7 @@ pub fn split(tar_stream: &mut R, writer: &mut SplitStreamWriter) -> Res if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX { // non-empty regular file: store the data in the object store let padding = buffer.split_off(actual_size); - writer.write_external(&buffer, padding)?; + writer.write_external(buffer, padding, 0, 0)?; } else { // else: store the data 
inline in the split stream writer.write_inline(&buffer); @@ -72,7 +74,10 @@ pub fn split(tar_stream: &mut R, writer: &mut SplitStreamWriter) -> Res pub async fn split_async( mut tar_stream: impl AsyncRead + Unpin, writer: &mut SplitStreamWriter<'_>, + layer_num: usize, ) -> Result<()> { + let mut seq_num = 0; + while let Some(header) = read_header_async(&mut tar_stream).await? { // the header always gets stored as inline data writer.write_inline(header.as_bytes()); @@ -90,12 +95,24 @@ pub async fn split_async( if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX { // non-empty regular file: store the data in the object store let padding = buffer.split_off(actual_size); - writer.write_external(&buffer, padding)?; + writer.write_external(buffer, padding, seq_num, layer_num)?; + seq_num += 1; } else { // else: store the data inline in the split stream writer.write_inline(&buffer); } } + + if let Some(sender) = &writer.object_sender { + sender + .send(EnsureObjectMessages::Finish(FinishMessage { + data: std::mem::take(&mut writer.inline_content), + total_msgs: seq_num, + layer_num, + })) + .with_context(|| format!("Failed to send final message for layer {layer_num}"))?; + } + Ok(()) } diff --git a/src/repository.rs b/src/repository.rs index d30fc9f3..0ea466eb 100644 --- a/src/repository.rs +++ b/src/repository.rs @@ -2,7 +2,7 @@ use std::{ collections::HashSet, ffi::CStr, fs::File, - io::{ErrorKind, Read, Write}, + io::{self, ErrorKind, Read, Write}, os::fd::{AsFd, OwnedFd}, path::{Path, PathBuf}, }; @@ -23,7 +23,7 @@ use crate::{ Sha256HashValue, }, mount::mount_composefs_at, - splitstream::{DigestMap, SplitStreamReader, SplitStreamWriter}, + splitstream::{DigestMap, EnsureObjectMessages, SplitStreamReader, SplitStreamWriter}, util::{parse_sha256, proc_self_fd}, }; @@ -46,6 +46,12 @@ impl Repository { ) } + pub fn try_clone(&self) -> io::Result { + Ok(Self { + repository: self.repository.try_clone()?, + }) + } + pub fn open_path(dirfd: impl 
AsFd, path: impl AsRef) -> Result { let path = path.as_ref(); @@ -137,12 +143,13 @@ impl Repository { /// Creates a SplitStreamWriter for writing a split stream. /// You should write the data to the returned object and then pass it to .store_stream() to /// store the result. - pub fn create_stream( + pub(crate) fn create_stream( &self, sha256: Option, maps: Option, + object_sender: Option>, ) -> SplitStreamWriter { - SplitStreamWriter::new(self, maps, sha256) + SplitStreamWriter::new(self, maps, sha256, object_sender) } fn parse_object_path(path: impl AsRef<[u8]>) -> Result { @@ -165,7 +172,7 @@ impl Repository { Ok(result) } - fn format_object_path(id: &Sha256HashValue) -> String { + pub fn format_object_path(id: &Sha256HashValue) -> String { format!("objects/{:02x}/{}", id[0], hex::encode(&id[1..])) } @@ -230,9 +237,10 @@ impl Repository { writer: SplitStreamWriter, reference: Option<&str>, ) -> Result { - let Some((.., ref sha256)) = writer.sha256 else { + let Some((.., ref sha256)) = writer.get_sha_builder() else { bail!("Writer doesn't have sha256 enabled"); }; + let stream_path = format!("streams/{}", hex::encode(sha256)); let object_id = writer.done()?; let object_path = Repository::format_object_path(&object_id); @@ -280,7 +288,7 @@ impl Repository { let object_id = match self.has_stream(sha256)? 
{ Some(id) => id, None => { - let mut writer = self.create_stream(Some(*sha256), None); + let mut writer = self.create_stream(Some(*sha256), None, None); callback(&mut writer)?; let object_id = writer.done()?; diff --git a/src/splitstream.rs b/src/splitstream.rs index 319ffe09..833540ba 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -5,14 +5,17 @@ use std::io::{BufReader, Read, Write}; -use anyhow::{bail, Result}; -use sha2::{Digest, Sha256}; -use zstd::stream::{read::Decoder, write::Encoder}; +use crossbeam::channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; + +use anyhow::{bail, Context, Result}; +use sha2::Sha256; +use zstd::stream::read::Decoder; use crate::{ fsverity::{FsVerityHashValue, Sha256HashValue}, repository::Repository, util::read_exactish, + zstd_encoder::ZstdWriter, }; #[derive(Debug)] @@ -60,9 +63,9 @@ impl DigestMap { pub struct SplitStreamWriter<'a> { repo: &'a Repository, - inline_content: Vec, - writer: Encoder<'a, Vec>, - pub sha256: Option<(Sha256, Sha256HashValue)>, + pub(crate) inline_content: Vec, + writer: ZstdWriter, + pub(crate) object_sender: Option>, } impl std::fmt::Debug for SplitStreamWriter<'_> { @@ -71,97 +74,176 @@ impl std::fmt::Debug for SplitStreamWriter<'_> { f.debug_struct("SplitStreamWriter") .field("repo", &self.repo) .field("inline_content", &self.inline_content) - .field("sha256", &self.sha256) .finish() } } +#[derive(Debug)] +pub(crate) struct FinishMessage { + pub(crate) data: Vec, + pub(crate) total_msgs: usize, + pub(crate) layer_num: usize, +} + +#[derive(Eq, Debug)] +pub(crate) struct WriterMessagesData { + pub(crate) digest: Sha256HashValue, + pub(crate) object_data: SplitStreamWriterSenderData, +} + +#[derive(Debug)] +pub(crate) enum WriterMessages { + WriteData(WriterMessagesData), + Finish(FinishMessage), +} + +impl PartialEq for WriterMessagesData { + fn eq(&self, other: &Self) -> bool { + self.object_data.seq_num.eq(&other.object_data.seq_num) + } +} + +impl PartialOrd for 
WriterMessagesData { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WriterMessagesData { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.object_data.seq_num.cmp(&other.object_data.seq_num) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct SplitStreamWriterSenderData { + pub(crate) external_data: Vec, + pub(crate) inline_content: Vec, + pub(crate) seq_num: usize, + pub(crate) layer_num: usize, +} +pub(crate) enum EnsureObjectMessages { + Data(SplitStreamWriterSenderData), + Finish(FinishMessage), +} + +pub(crate) type ResultChannelSender = + std::sync::mpsc::Sender>; +pub(crate) type ResultChannelReceiver = + std::sync::mpsc::Receiver>; + +pub(crate) fn handle_external_object( + repository: Repository, + external_object_receiver: CrossbeamReceiver, + zstd_writer_channels: Vec>, + layers_to_chunks: Vec, +) -> Result<()> { + while let Ok(data) = external_object_receiver.recv() { + match data { + EnsureObjectMessages::Data(data) => { + let digest = repository.ensure_object(&data.external_data)?; + let layer_num = data.layer_num; + let writer_chan_sender = &zstd_writer_channels[layers_to_chunks[layer_num]]; + + let msg = WriterMessagesData { + digest, + object_data: data, + }; + + // `send` only fails if all receivers are dropped + writer_chan_sender + .send(WriterMessages::WriteData(msg)) + .with_context(|| format!("Failed to send message for layer {layer_num}"))?; + } + + EnsureObjectMessages::Finish(final_msg) => { + let layer_num = final_msg.layer_num; + let writer_chan_sender = &zstd_writer_channels[layers_to_chunks[layer_num]]; + + writer_chan_sender + .send(WriterMessages::Finish(final_msg)) + .with_context(|| { + format!("Failed to send final message for layer {layer_num}") + })?; + } + } + } + + Ok(()) +} + impl SplitStreamWriter<'_> { - pub fn new( + pub(crate) fn new( repo: &Repository, refs: Option, sha256: Option, + object_sender: Option>, ) -> SplitStreamWriter { - // SAFETY: we 
surely can't get an error writing the header to a Vec - let mut writer = Encoder::new(vec![], 0).unwrap(); - - match refs { - Some(DigestMap { map }) => { - writer.write_all(&(map.len() as u64).to_le_bytes()).unwrap(); - for ref entry in map { - writer.write_all(&entry.body).unwrap(); - writer.write_all(&entry.verity).unwrap(); - } - } - None => { - writer.write_all(&0u64.to_le_bytes()).unwrap(); - } - } + let inline_content = vec![]; SplitStreamWriter { repo, - inline_content: vec![], - writer, - sha256: sha256.map(|x| (Sha256::new(), x)), + inline_content, + object_sender, + writer: ZstdWriter::new(sha256, refs, repo.try_clone().unwrap()), } } - fn write_fragment(writer: &mut impl Write, size: usize, data: &[u8]) -> Result<()> { - writer.write_all(&(size as u64).to_le_bytes())?; - Ok(writer.write_all(data)?) + pub fn get_sha_builder(&self) -> &Option<(Sha256, Sha256HashValue)> { + &self.writer.sha256_builder } /// flush any buffered inline data, taking new_value as the new value of the buffer fn flush_inline(&mut self, new_value: Vec) -> Result<()> { - if !self.inline_content.is_empty() { - SplitStreamWriter::write_fragment( - &mut self.writer, - self.inline_content.len(), - &self.inline_content, - )?; - self.inline_content = new_value; - } + self.writer.flush_inline(&self.inline_content)?; + self.inline_content = new_value; Ok(()) } /// really, "add inline content to the buffer" /// you need to call .flush_inline() later pub fn write_inline(&mut self, data: &[u8]) { - if let Some((ref mut sha256, ..)) = self.sha256 { - sha256.update(data); - } self.inline_content.extend(data); } - /// write a reference to external data to the stream. If the external data had padding in the - /// stream which is not stored in the object then pass it here as well and it will be stored - /// inline after the reference. - fn write_reference(&mut self, reference: Sha256HashValue, padding: Vec) -> Result<()> { - // Flush the inline data before we store the external reference. 
Any padding from the - // external data becomes the start of a new inline block. - self.flush_inline(padding)?; + pub fn write_external( + &mut self, + data: Vec, + padding: Vec, + seq_num: usize, + layer_num: usize, + ) -> Result<()> { + match &self.object_sender { + Some(sender) => { + let inline_content = std::mem::replace(&mut self.inline_content, padding); + + sender + .send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { + external_data: data, + inline_content, + seq_num, + layer_num, + })) + .with_context(|| { + format!("Failed to send message to writer for layer {layer_num}") + })?; + } + + None => { + self.flush_inline(padding)?; + self.writer.update_sha(&data); - SplitStreamWriter::write_fragment(&mut self.writer, 0, &reference) - } + let id = self.repo.ensure_object(&data)?; + self.writer.write_fragment(0, &id)?; + } + }; - pub fn write_external(&mut self, data: &[u8], padding: Vec) -> Result<()> { - if let Some((ref mut sha256, ..)) = self.sha256 { - sha256.update(data); - sha256.update(&padding); - } - let id = self.repo.ensure_object(data)?; - self.write_reference(id, padding) + Ok(()) } pub fn done(mut self) -> Result { self.flush_inline(vec![])?; - - if let Some((context, expected)) = self.sha256 { - if Into::::into(context.finalize()) != expected { - bail!("Content doesn't have expected SHA256 hash value!"); - } - } - + self.writer.finalize_sha256_builder()?; self.repo.ensure_object(&self.writer.finish()?) 
} } diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs new file mode 100644 index 00000000..065b5ca5 --- /dev/null +++ b/src/zstd_encoder.rs @@ -0,0 +1,375 @@ +use std::{ + cmp::Reverse, + collections::BinaryHeap, + io::{self, Write}, +}; + +use sha2::{Digest, Sha256}; + +use anyhow::{bail, Context, Result}; +use zstd::Encoder; + +use crate::{ + fsverity::Sha256HashValue, + repository::Repository, + splitstream::{ + DigestMap, EnsureObjectMessages, FinishMessage, ResultChannelSender, + SplitStreamWriterSenderData, WriterMessages, WriterMessagesData, + }, +}; + +pub(crate) struct ZstdWriter { + writer: zstd::Encoder<'static, Vec>, + repository: Repository, + pub(crate) sha256_builder: Option<(Sha256, Sha256HashValue)>, + mode: WriterMode, +} + +pub(crate) struct MultiThreadedState { + last: usize, + heap: BinaryHeap>, + final_sha: Option, + final_message: Option, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, +} + +pub(crate) enum WriterMode { + SingleThreaded, + MultiThreaded(MultiThreadedState), +} + +pub(crate) struct MultipleZstdWriters { + writers: Vec, +} + +impl MultipleZstdWriters { + pub fn new( + sha256: Vec, + repository: Repository, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, + ) -> Self { + Self { + writers: sha256 + .iter() + .map(|sha| { + ZstdWriter::new_threaded( + Some(*sha), + None, + repository.try_clone().unwrap(), + object_sender.clone(), + final_result_sender.clone(), + ) + }) + .collect(), + } + } + + pub fn recv_data( + mut self, + enc_chan_recvr: crossbeam::channel::Receiver, + layer_num_start: usize, + layer_num_end: usize, + ) -> Result<()> { + assert!(layer_num_end >= layer_num_start); + + let total_writers = self.writers.len(); + + // layers_to_writers[layer_num] = writer_idx + // Faster than a hash map + let mut layers_to_writers: Vec = vec![0; layer_num_end]; + + for (idx, i) in (layer_num_start..layer_num_end).enumerate() { + 
layers_to_writers[i] = idx + } + + let mut finished_writers = 0; + + while let Ok(data) = enc_chan_recvr.recv() { + let layer_num = match &data { + WriterMessages::WriteData(d) => d.object_data.layer_num, + WriterMessages::Finish(d) => d.layer_num, + }; + + assert!(layer_num >= layer_num_start && layer_num <= layer_num_end); + + match self.writers[layers_to_writers[layer_num]].handle_received_data(data) { + Ok(finished) => { + if finished { + finished_writers += 1 + } + } + + Err(e) => { + return Err(e); + } + } + + if finished_writers == total_writers { + break; + } + } + + Ok(()) + } +} + +impl ZstdWriter { + pub fn new_threaded( + sha256: Option, + refs: Option, + repository: Repository, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, + ) -> Self { + Self { + writer: ZstdWriter::instantiate_writer(refs), + repository, + sha256_builder: sha256.map(|x| (Sha256::new(), x)), + + mode: WriterMode::MultiThreaded(MultiThreadedState { + final_sha: None, + last: 0, + heap: BinaryHeap::new(), + final_message: None, + object_sender, + final_result_sender, + }), + } + } + + pub fn new( + sha256: Option, + refs: Option, + repository: Repository, + ) -> Self { + Self { + writer: ZstdWriter::instantiate_writer(refs), + repository, + sha256_builder: sha256.map(|x| (Sha256::new(), x)), + mode: WriterMode::SingleThreaded, + } + } + + fn get_state(&self) -> &MultiThreadedState { + let WriterMode::MultiThreaded(state) = &self.mode else { + panic!("`get_state` called on a single threaded writer") + }; + + state + } + + fn get_state_mut(&mut self) -> &mut MultiThreadedState { + let WriterMode::MultiThreaded(state) = &mut self.mode else { + panic!("`get_state_mut` called on a single threaded writer") + }; + + state + } + + fn instantiate_writer(refs: Option) -> zstd::Encoder<'static, Vec> { + let mut writer = zstd::Encoder::new(vec![], 0).unwrap(); + + match refs { + Some(DigestMap { map }) => { + writer.write_all(&(map.len() as 
u64).to_le_bytes()).unwrap(); + + for ref entry in map { + writer.write_all(&entry.body).unwrap(); + writer.write_all(&entry.verity).unwrap(); + } + } + + None => { + writer.write_all(&0u64.to_le_bytes()).unwrap(); + } + } + + writer + } + + pub(crate) fn write_fragment(&mut self, size: usize, data: &[u8]) -> Result<()> { + self.writer.write_all(&(size as u64).to_le_bytes())?; + Ok(self.writer.write_all(data)?) + } + + pub(crate) fn update_sha(&mut self, data: &[u8]) { + if let Some((sha256, ..)) = &mut self.sha256_builder { + sha256.update(data); + } + } + + /// Writes all the data in `inline_content`, updating the internal SHA + pub(crate) fn flush_inline(&mut self, inline_content: &[u8]) -> Result<()> { + if inline_content.is_empty() { + return Ok(()); + } + + self.update_sha(inline_content); + + self.write_fragment(inline_content.len(), inline_content)?; + + Ok(()) + } + + /// Keeps popping from the heap until it reaches the message with the largest seq_num, n, + /// given we have every message with seq_num < n + fn write_message(&mut self) -> Result<()> { + loop { + // Gotta keep lifetime of the destructuring inside the loop + let state = self.get_state_mut(); + + let Some(data) = state.heap.peek() else { + break; + }; + + if data.0.object_data.seq_num != state.last { + break; + } + + let data = state.heap.pop().unwrap(); + state.last += 1; + + self.flush_inline(&data.0.object_data.inline_content)?; + + if let Some((sha256, ..)) = &mut self.sha256_builder { + sha256.update(data.0.object_data.external_data); + } + + self.write_fragment(0, &data.0.digest)?; + } + + let final_msg = self.get_state_mut().final_message.take(); + + if let Some(final_msg) = final_msg { + // Haven't received all the messages so we reset the final_message field + if self.get_state().last < final_msg.total_msgs { + self.get_state_mut().final_message = Some(final_msg); + return Ok(()); + } + + let sha = self.handle_final_message(final_msg).unwrap(); + self.get_state_mut().final_sha = 
Some(sha); + } + + Ok(()) + } + + fn add_message_to_heap(&mut self, recv_data: WriterMessagesData) { + self.get_state_mut().heap.push(Reverse(recv_data)); + } + + pub(crate) fn finalize_sha256_builder(&mut self) -> Result { + let sha256_builder = self.sha256_builder.take(); + + if let Some((context, expected)) = sha256_builder { + let final_sha = Into::::into(context.finalize()); + + if final_sha != expected { + bail!( + "Content doesn't have expected SHA256 hash value!\nExpected: {}, final: {}", + hex::encode(expected), + hex::encode(final_sha) + ); + } + + return Ok(final_sha); + } + + bail!("SHA not enabled for writer"); + } + + /// Calls `finish` on the internal writer + pub(crate) fn finish(self) -> io::Result> { + self.writer.finish() + } + + fn handle_final_message(&mut self, final_message: FinishMessage) -> Result { + self.flush_inline(&final_message.data)?; + + let writer = std::mem::replace(&mut self.writer, Encoder::new(vec![], 0).unwrap()); + let finished = writer.finish()?; + + let sha = self.finalize_sha256_builder()?; + + self.get_state() + .object_sender + .send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { + external_data: finished, + inline_content: vec![], + seq_num: final_message.total_msgs, + layer_num: final_message.layer_num, + })) + .context("Failed to send object finalize message")?; + + Ok(sha) + } + + // Cannot `take` ownership of self, as we'll need it later + // + /// Returns whether we have finished writing all the data or not + fn handle_received_data(&mut self, data: WriterMessages) -> Result { + match data { + WriterMessages::WriteData(recv_data) => { + if let Some(final_sha) = self.get_state().final_sha { + // We've already received the final message + let stream_path = format!("streams/{}", hex::encode(final_sha)); + + let object_path = Repository::format_object_path(&recv_data.digest); + self.repository.ensure_symlink(&stream_path, &object_path)?; + + self.get_state() + .final_result_sender + .send(Ok((final_sha, 
recv_data.digest))) + .with_context(|| { + format!("Failed to send result for layer {final_sha:?}") + })?; + + return Ok(true); + } + + let seq_num = recv_data.object_data.seq_num; + + self.add_message_to_heap(recv_data); + + if seq_num != self.get_state().last { + return Ok(false); + } + + self.write_message()?; + } + + WriterMessages::Finish(final_msg) => { + if self.get_state().final_message.is_some() { + panic!( + "Received two finalize messages for layer {}. Previous final message {:?}", + final_msg.layer_num, + self.get_state().final_message + ); + } + + // write all pending messages + if !self.get_state().heap.is_empty() { + self.write_message()?; + } + + let total_msgs = final_msg.total_msgs; + + if self.get_state().last >= total_msgs { + // We have received all the messages + // Finalize + let final_sha = self.handle_final_message(final_msg).unwrap(); + self.get_state_mut().final_sha = Some(final_sha); + } else { + // Haven't received all messages. Store the final message until we have + // received all + let state = self.get_state_mut(); + state.final_message = Some(final_msg); + } + } + } + + Ok(false) + } +}