Multithreaded SplitStream creation #95


Draft · wants to merge 11 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
/Cargo.lock
/target
Collaborator:

Looks like this change belongs to a different commit...

Collaborator (Author):

Thanks, moved it to a separate commit

*.tar
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -21,11 +21,13 @@ anyhow = { version = "1.0.97", default-features = false }
async-compression = { version = "0.4.22", default-features = false, features = ["tokio", "zstd", "gzip"] }
clap = { version = "4.5.32", default-features = false, features = ["std", "help", "usage", "derive"] }
containers-image-proxy = "0.7.0"
crossbeam = "0.8.4"
env_logger = "0.11.7"
hex = "0.4.3"
indicatif = { version = "0.17.11", features = ["tokio"] }
log = "0.4.27"
oci-spec = "0.7.1"
rayon = "1.10.0"
regex-automata = { version = "0.4.9", default-features = false }
rustix = { version = "1.0.3", features = ["fs", "mount", "process"] }
serde = "1.0.219"
1 change: 1 addition & 0 deletions src/lib.rs
@@ -13,6 +13,7 @@ pub mod repository;
pub mod selabel;
pub mod splitstream;
pub mod util;
pub mod zstd_encoder;

/// All files that contain 64 or fewer bytes (size <= INLINE_CONTENT_MAX) should be stored inline
/// in the erofs image (and also in splitstreams). All files with 65 or more bytes (size > MAX)
240 changes: 194 additions & 46 deletions src/oci/mod.rs
@@ -18,8 +18,12 @@ use crate::{
fsverity::Sha256HashValue,
oci::tar::{get_entry, split_async},
repository::Repository,
splitstream::DigestMap,
splitstream::{
handle_external_object, DigestMap, EnsureObjectMessages, ResultChannelReceiver,
ResultChannelSender, WriterMessages,
},
util::parse_sha256,
zstd_encoder,
};

pub fn import_layer(
@@ -83,6 +87,7 @@ impl<'repo> ImageOp<'repo> {
let proxy = containers_image_proxy::ImageProxy::new_with_config(config).await?;
let img = proxy.open_image(imgref).await.context("Opening image")?;
let progress = MultiProgress::new();

Ok(ImageOp {
repo,
proxy,
@@ -95,47 +100,49 @@ impl<'repo> ImageOp<'repo> {
&self,
layer_sha256: &Sha256HashValue,
descriptor: &Descriptor,
) -> Result<Sha256HashValue> {
layer_num: usize,
object_sender: crossbeam::channel::Sender<EnsureObjectMessages>,
) -> Result<()> {
// We need to use the per_manifest descriptor to download the compressed layer but it gets
// stored in the repository via the per_config descriptor. The fsverity digest for the
// corresponding splitstream is reported back over the result channel once the layer has
// been written.

if let Some(layer_id) = self.repo.check_stream(layer_sha256)? {
self.progress
.println(format!("Already have layer {}", hex::encode(layer_sha256)))?;
Ok(layer_id)
} else {
// Otherwise, we need to fetch it...
let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?;

// See https://github.com/containers/containers-image-proxy-rs/issues/71
let blob_reader = blob_reader.take(descriptor.size());

let bar = self.progress.add(ProgressBar::new(descriptor.size()));
bar.set_style(ProgressStyle::with_template("[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}")
.unwrap()
.progress_chars("##-"));
let progress = bar.wrap_async_read(blob_reader);
self.progress
.println(format!("Fetching layer {}", hex::encode(layer_sha256)))?;
// Otherwise, we need to fetch it...
let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?;

// See https://github.com/containers/containers-image-proxy-rs/issues/71
let blob_reader = blob_reader.take(descriptor.size());

let bar = self.progress.add(ProgressBar::new(descriptor.size()));
bar.set_style(
ProgressStyle::with_template(
"[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}",
)
.unwrap()
.progress_chars("##-"),
);
let progress = bar.wrap_async_read(blob_reader);
self.progress
.println(format!("Fetching layer {}", hex::encode(layer_sha256)))?;

let mut splitstream =
self.repo
.create_stream(Some(*layer_sha256), None, Some(object_sender));
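// Passing the sender here means external objects for this layer are hashed and
// written by the worker pool rather than inline on this task.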
match descriptor.media_type() {
MediaType::ImageLayer => {
split_async(progress, &mut splitstream, layer_num).await?;
}
MediaType::ImageLayerGzip => {
split_async(GzipDecoder::new(progress), &mut splitstream, layer_num).await?;
}
MediaType::ImageLayerZstd => {
split_async(ZstdDecoder::new(progress), &mut splitstream, layer_num).await?;
}
other => bail!("Unsupported layer media type {:?}", other),
};
driver.await?;

let mut splitstream = self.repo.create_stream(Some(*layer_sha256), None);
match descriptor.media_type() {
MediaType::ImageLayer => {
split_async(progress, &mut splitstream).await?;
}
MediaType::ImageLayerGzip => {
split_async(GzipDecoder::new(progress), &mut splitstream).await?;
}
MediaType::ImageLayerZstd => {
split_async(ZstdDecoder::new(progress), &mut splitstream).await?;
}
other => bail!("Unsupported layer media type {:?}", other),
};
let layer_id = self.repo.write_stream(splitstream, None)?;
driver.await?;
Ok(layer_id)
}
Ok(())
}

pub async fn ensure_config(
@@ -154,7 +161,6 @@ impl<'repo> ImageOp<'repo> {
} else {
// We need to add the config to the repo. We need to parse the config and make sure we
// have all of the layers first.
//
self.progress
.println(format!("Fetching config {}", hex::encode(config_sha256)))?;

@@ -169,26 +175,167 @@ impl<'repo> ImageOp<'repo> {
let raw_config = config?;
let config = ImageConfiguration::from_reader(&raw_config[..])?;

let (done_chan_sender, done_chan_recver, object_sender) =
self.spawn_threads(&config)?;

let mut config_maps = DigestMap::new();

// `idx` counts only the layers that still need to be fetched, so it lines up
// with the chunk layout computed in `spawn_threads`.
let mut idx = 0;

for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) {
let layer_sha256 = sha256_from_digest(cld)?;
let layer_id = self
.ensure_layer(&layer_sha256, mld)
.await
.with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?;

if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? {
self.progress
.println(format!("Already have layer {}", hex::encode(layer_sha256)))?;

config_maps.insert(&layer_sha256, &layer_id);
} else {
self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone())
.await
.with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?;

idx += 1;
}
}

// Drop our clone of the sender so the receive loop below terminates once
// every encoder thread has finished and dropped its own clone.
drop(done_chan_sender);

while let Ok(res) = done_chan_recver.recv() {
let (layer_sha256, layer_id) = res?;
config_maps.insert(&layer_sha256, &layer_id);
}

let mut splitstream = self
.repo
.create_stream(Some(config_sha256), Some(config_maps));
let mut splitstream =
self.repo
.create_stream(Some(config_sha256), Some(config_maps), None);
splitstream.write_inline(&raw_config);
let config_id = self.repo.write_stream(splitstream, None)?;

Ok((config_sha256, config_id))
}
}

fn spawn_threads(
&self,
config: &ImageConfiguration,
) -> Result<(
ResultChannelSender,
ResultChannelReceiver,
crossbeam::channel::Sender<EnsureObjectMessages>,
)> {
use crossbeam::channel::{unbounded, Receiver, Sender};

let mut encoder_threads = 2;
let external_object_writer_threads = 4;

let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads);

// Divide the layers into fixed-size chunks so that each worker thread is
// assigned a deterministic set of layers
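//
// For example (hypothetical numbers): an image with 5 layers and
// encoder_threads = 2 gives chunk_len = div_ceil(5, 2) = 3, so the
// not-yet-stored layers are split into runs of at most 3.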
let diff_ids: Vec<Sha256HashValue> = config
.rootfs()
.diff_ids()
.iter()
.map(|x| sha256_from_digest(x))
.collect::<Result<Vec<Sha256HashValue>, _>>()?;

let mut unhandled_layers = vec![];

// This becomes pretty unreadable as a filter/map chain
for id in diff_ids {
let layer_exists = self.repo.check_stream(&id)?;

if layer_exists.is_none() {
unhandled_layers.push(id);
}
}

let mut chunks: Vec<Vec<Sha256HashValue>> = unhandled_layers
.chunks(chunk_len)
.map(|x| x.to_vec())
.collect();

// Mapping from layer_id -> index in writer_channels
// This is to make sure that all messages relating to a particular layer
// always reach the same writer
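//
// e.g. chunks = [[a, b], [c]] (a, b, c being hypothetical layer digests)
// yields layers_to_chunks = [0, 0, 1]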
let layers_to_chunks = chunks
.iter()
.enumerate()
.flat_map(|(i, chunk)| std::iter::repeat_n(i, chunk.len()).collect::<Vec<_>>())
.collect::<Vec<_>>();

encoder_threads = encoder_threads.min(chunks.len());

let pool = rayon::ThreadPoolBuilder::new()
.num_threads(encoder_threads + external_object_writer_threads)
.build()
.unwrap();
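// The pool just built hosts both kinds of workers: `encoder_threads` zstd
// encoders plus `external_object_writer_threads` external-object writers,
// spawned below.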

// We need this as writers have internal state that can't be shared between threads
//
// We'll actually need as many writers (not writer threads, but writer instances) as there are layers.
let zstd_writer_channels: Vec<(Sender<WriterMessages>, Receiver<WriterMessages>)> =
(0..encoder_threads).map(|_| unbounded()).collect();

let (object_sender, object_receiver) = unbounded::<EnsureObjectMessages>();

// (layer_sha256, layer_id)
let (done_chan_sender, done_chan_recver) =
std::sync::mpsc::channel::<Result<(Sha256HashValue, Sha256HashValue)>>();
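// Encoder threads send one result per layer over this channel;
// `ensure_config` drains it to build the final DigestMap.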

for i in 0..encoder_threads {
let repository = self.repo.try_clone().unwrap();
let object_sender = object_sender.clone();
let done_chan_sender = done_chan_sender.clone();
let chunk = std::mem::take(&mut chunks[i]);
let receiver = zstd_writer_channels[i].1.clone();

pool.spawn({
move || {
let start = i * (chunk_len);
let end = start + chunk_len;

let enc = zstd_encoder::MultipleZstdWriters::new(
chunk,
repository,
object_sender,
done_chan_sender,
);

if let Err(e) = enc.recv_data(receiver, start, end) {
eprintln!("zstd_encoder returned with error: {}", e)
}
}
});
}

for _ in 0..external_object_writer_threads {
pool.spawn({
let repository = self.repo.try_clone().unwrap();
let zstd_writer_channels = zstd_writer_channels
.iter()
.map(|(s, _)| s.clone())
.collect::<Vec<_>>();
let layers_to_chunks = layers_to_chunks.clone();
let external_object_receiver = object_receiver.clone();

move || {
if let Err(e) = handle_external_object(
repository,
external_object_receiver,
zstd_writer_channels,
layers_to_chunks,
) {
eprintln!("handle_external_object returned with error: {}", e);
}
}
});
}

Ok((done_chan_sender, done_chan_recver, object_sender))
}

pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> {
let (_manifest_digest, raw_manifest) = self
.proxy
@@ -201,6 +348,7 @@ impl<'repo> ImageOp<'repo> {
let manifest = ImageManifest::from_reader(raw_manifest.as_slice())?;
let config_descriptor = manifest.config();
let layers = manifest.layers();

self.ensure_config(layers, config_descriptor)
.await
.with_context(|| format!("Failed to pull config {config_descriptor:?}"))
@@ -280,7 +428,7 @@ pub fn write_config(
let json = config.to_string()?;
let json_bytes = json.as_bytes();
let sha256 = hash(json_bytes);
let mut stream = repo.create_stream(Some(sha256), Some(refs));
let mut stream = repo.create_stream(Some(sha256), Some(refs), None);
stream.write_inline(json_bytes);
let id = repo.write_stream(stream, None)?;
Ok((sha256, id))
25 changes: 21 additions & 4 deletions src/oci/tar.rs
@@ -8,15 +8,17 @@ use std::{
path::PathBuf,
};

use anyhow::{bail, ensure, Result};
use anyhow::{bail, ensure, Context, Result};
use rustix::fs::makedev;
use tar::{EntryType, Header, PaxExtensions};
use tokio::io::{AsyncRead, AsyncReadExt};

use crate::{
dumpfile,
image::{LeafContent, RegularFile, Stat},
splitstream::{SplitStreamData, SplitStreamReader, SplitStreamWriter},
splitstream::{
EnsureObjectMessages, FinishMessage, SplitStreamData, SplitStreamReader, SplitStreamWriter,
},
util::{read_exactish, read_exactish_async},
INLINE_CONTENT_MAX,
};
@@ -60,7 +62,7 @@ pub fn split<R: Read>(tar_stream: &mut R, writer: &mut SplitStreamWriter) -> Res
if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX {
// non-empty regular file: store the data in the object store
let padding = buffer.split_off(actual_size);
writer.write_external(&buffer, padding)?;
writer.write_external(buffer, padding, 0, 0)?;
} else {
// else: store the data inline in the split stream
writer.write_inline(&buffer);
@@ -72,7 +74,10 @@ pub async fn split_async(
pub async fn split_async(
mut tar_stream: impl AsyncRead + Unpin,
writer: &mut SplitStreamWriter<'_>,
layer_num: usize,
) -> Result<()> {
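// seq_num counts the external (object-store) chunks of this layer; workers
// may complete out of order, and the closing FinishMessage reports the count
// as total_msgs so the writer knows when the layer is complete.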
let mut seq_num = 0;

while let Some(header) = read_header_async(&mut tar_stream).await? {
// the header always gets stored as inline data
writer.write_inline(header.as_bytes());
@@ -90,12 +95,24 @@ pub async fn split_async(
if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX {
// non-empty regular file: store the data in the object store
let padding = buffer.split_off(actual_size);
writer.write_external(&buffer, padding)?;
writer.write_external(buffer, padding, seq_num, layer_num)?;
seq_num += 1;
} else {
// else: store the data inline in the split stream
writer.write_inline(&buffer);
}
}

if let Some(sender) = &writer.object_sender {
sender
.send(EnsureObjectMessages::Finish(FinishMessage {
data: std::mem::take(&mut writer.inline_content),
total_msgs: seq_num,
layer_num,
}))
.with_context(|| format!("Failed to send final message for layer {layer_num}"))?;
}

Ok(())
}
