
Commit 88a104c

Fix deadlock when we have more threads than layers

If we have more threads than unprocessed layers, some of the cloned senders are never dropped, and the main thread hangs on the result-receiving loop. Make sure here not to spawn more threads than the number of unhandled layers.

Signed-off-by: Pragyan Poudyal <[email protected]>

1 parent 766cf98 commit 88a104c
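The hang generalizes beyond this commit: a std::sync::mpsc receiver loop only terminates once every Sender clone has been dropped. Below is a minimal sketch of the failure mode, not taken from the patch (the names done_tx/done_rx and the two-thread setup are hypothetical; only the channel semantics match what spawn_threads relies on). Run as written, it prints one result and then blocks forever:

    use std::sync::mpsc::channel;
    use std::thread;
    use std::time::Duration;

    fn main() {
        let (done_tx, done_rx) = channel::<usize>();

        // Two "encoder" threads, but only one layer's worth of work.
        for i in 0..2 {
            let tx = done_tx.clone();
            thread::spawn(move || {
                if i == 0 {
                    tx.send(i).unwrap(); // finishes, dropping its sender clone
                } else {
                    // No layer assigned: this thread idles forever, so its
                    // cloned sender is never dropped.
                    thread::sleep(Duration::from_secs(3600));
                }
            });
        }
        drop(done_tx); // main's copy is gone, but the idle thread holds one

        // This loop ends only when all senders are dropped: after the first
        // result it blocks forever -- the deadlock the commit fixes.
        for id in done_rx.iter() {
            println!("layer {id} done");
        }
    }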

File tree

1 file changed (+92, -76)

src/oci/mod.rs

Lines changed: 92 additions & 76 deletions
@@ -166,11 +166,14 @@ impl<'repo> ImageOp<'repo> {
         let raw_config = self.proxy.fetch_config_raw(&self.img).await?;
         let config = ImageConfiguration::from_reader(raw_config.as_slice())?;

-        let (done_chan_sender, done_chan_recver, object_sender) = self.spawn_threads(&config);
+        let (done_chan_sender, done_chan_recver, object_sender) =
+            self.spawn_threads(&config)?;

         let mut config_maps = DigestMap::new();

-        for (idx, (mld, cld)) in zip(manifest_layers, config.rootfs().diff_ids()).enumerate() {
+        let mut idx = 0;
+
+        for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) {
             let layer_sha256 = sha256_from_digest(cld)?;

             if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? {
@@ -182,6 +185,8 @@ impl<'repo> ImageOp<'repo> {
                 self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone())
                     .await
                     .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?;
+
+                idx += 1;
             }
         }

@@ -205,43 +210,39 @@ impl<'repo> ImageOp<'repo> {
     fn spawn_threads(
         &self,
         config: &ImageConfiguration,
-    ) -> (
+    ) -> Result<(
         ResultChannelSender,
         ResultChannelReceiver,
         crossbeam::channel::Sender<EnsureObjectMessages>,
-    ) {
+    )> {
         use crossbeam::channel::{unbounded, Receiver, Sender};

-        let encoder_threads = 2;
+        let mut encoder_threads = 2;
         let external_object_writer_threads = 4;

-        let pool = rayon::ThreadPoolBuilder::new()
-            .num_threads(encoder_threads + external_object_writer_threads)
-            .build()
-            .unwrap();
-
-        // We need this as writers have internal state that can't be shared between threads
-        //
-        // We'll actually need as many writers (not writer threads, but writer instances) as there are layers.
-        let zstd_writer_channels: Vec<(Sender<WriterMessages>, Receiver<WriterMessages>)> =
-            (0..encoder_threads).map(|_| unbounded()).collect();
-
-        let (object_sender, object_receiver) = unbounded::<EnsureObjectMessages>();
-
-        // (layer_sha256, layer_id)
-        let (done_chan_sender, done_chan_recver) =
-            std::sync::mpsc::channel::<Result<(Sha256HashValue, Sha256HashValue)>>();
-
         let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads);

         // Divide the layers into chunks of some specific size so each worker
         // thread can work on multiple deterministic layers
-        let mut chunks: Vec<Vec<Sha256HashValue>> = config
+        let diff_ids: Vec<Sha256HashValue> = config
             .rootfs()
             .diff_ids()
             .iter()
-            .map(|x| sha256_from_digest(x).unwrap())
-            .collect::<Vec<Sha256HashValue>>()
+            .map(|x| sha256_from_digest(x))
+            .collect::<Result<Vec<Sha256HashValue>, _>>()?;
+
+        let mut unhandled_layers = vec![];
+
+        // This becomes pretty unreadable with a filter,map chain
+        for id in diff_ids {
+            let layer_exists = self.repo.check_stream(&id)?;
+
+            if layer_exists.is_none() {
+                unhandled_layers.push(id);
+            }
+        }
+
+        let mut chunks: Vec<Vec<Sha256HashValue>> = unhandled_layers
             .chunks(chunk_len)
             .map(|x| x.to_vec())
             .collect();
@@ -255,60 +256,75 @@ impl<'repo> ImageOp<'repo> {
             .flat_map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::<Vec<_>>())
             .collect::<Vec<_>>();

-        let _ = (0..encoder_threads)
-            .map(|i| {
-                let repository = self.repo.try_clone().unwrap();
-                let object_sender = object_sender.clone();
-                let done_chan_sender = done_chan_sender.clone();
-                let chunk = std::mem::take(&mut chunks[i]);
-                let receiver = zstd_writer_channels[i].1.clone();
-
-                pool.spawn({
-                    move || {
-                        let start = i * (chunk_len);
-                        let end = start + chunk_len;
-
-                        let enc = zstd_encoder::MultipleZstdWriters::new(
-                            chunk,
-                            repository,
-                            object_sender,
-                            done_chan_sender,
-                        );
-
-                        if let Err(e) = enc.recv_data(receiver, start, end) {
-                            eprintln!("zstd_encoder returned with error: {}", e)
-                        }
+        encoder_threads = encoder_threads.min(chunks.len());
+
+        let pool = rayon::ThreadPoolBuilder::new()
+            .num_threads(encoder_threads + external_object_writer_threads)
+            .build()
+            .unwrap();
+
+        // We need this as writers have internal state that can't be shared between threads
+        //
+        // We'll actually need as many writers (not writer threads, but writer instances) as there are layers.
+        let zstd_writer_channels: Vec<(Sender<WriterMessages>, Receiver<WriterMessages>)> =
+            (0..encoder_threads).map(|_| unbounded()).collect();
+
+        let (object_sender, object_receiver) = unbounded::<EnsureObjectMessages>();
+
+        // (layer_sha256, layer_id)
+        let (done_chan_sender, done_chan_recver) =
+            std::sync::mpsc::channel::<Result<(Sha256HashValue, Sha256HashValue)>>();
+
+        for i in 0..encoder_threads {
+            let repository = self.repo.try_clone().unwrap();
+            let object_sender = object_sender.clone();
+            let done_chan_sender = done_chan_sender.clone();
+            let chunk = std::mem::take(&mut chunks[i]);
+            let receiver = zstd_writer_channels[i].1.clone();
+
+            pool.spawn({
+                move || {
+                    let start = i * (chunk_len);
+                    let end = start + chunk_len;
+
+                    let enc = zstd_encoder::MultipleZstdWriters::new(
+                        chunk,
+                        repository,
+                        object_sender,
+                        done_chan_sender,
+                    );
+
+                    if let Err(e) = enc.recv_data(receiver, start, end) {
+                        eprintln!("zstd_encoder returned with error: {}", e)
                     }
-                });
-            })
-            .collect::<Vec<()>>();
-
-        let _ = (0..external_object_writer_threads)
-            .map(|_| {
-                pool.spawn({
-                    let repository = self.repo.try_clone().unwrap();
-                    let zstd_writer_channels = zstd_writer_channels
-                        .iter()
-                        .map(|(s, _)| s.clone())
-                        .collect::<Vec<_>>();
-                    let layers_to_chunks = layers_to_chunks.clone();
-                    let external_object_receiver = object_receiver.clone();
-
-                    move || {
-                        if let Err(e) = handle_external_object(
-                            repository,
-                            external_object_receiver,
-                            zstd_writer_channels,
-                            layers_to_chunks,
-                        ) {
-                            eprintln!("handle_external_object returned with error: {}", e);
-                        }
+                }
+            });
+        }
+
+        for _ in 0..external_object_writer_threads {
+            pool.spawn({
+                let repository = self.repo.try_clone().unwrap();
+                let zstd_writer_channels = zstd_writer_channels
+                    .iter()
+                    .map(|(s, _)| s.clone())
+                    .collect::<Vec<_>>();
+                let layers_to_chunks = layers_to_chunks.clone();
+                let external_object_receiver = object_receiver.clone();
+
+                move || {
+                    if let Err(e) = handle_external_object(
+                        repository,
+                        external_object_receiver,
+                        zstd_writer_channels,
+                        layers_to_chunks,
+                    ) {
+                        eprintln!("handle_external_object returned with error: {}", e);
                     }
-                });
-            })
-            .collect::<Vec<_>>();
+                }
+            });
+        }

-        (done_chan_sender, done_chan_recver, object_sender)
+        Ok((done_chan_sender, done_chan_recver, object_sender))
     }

     pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> {
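The core of the fix is the clamp encoder_threads = encoder_threads.min(chunks.len()), applied after already-present layers are filtered out, so every spawned encoder owns a non-empty chunk and eventually drops its done_chan_sender clone. A small standalone sketch of the arithmetic (the numbers are made up; div_ceil and chunks behave as in the patch):

    fn main() {
        let encoder_threads = 2usize;

        // Seven layers in the image, but only one still needs processing.
        let total_layers = 7usize;
        let unhandled_layers = vec!["layer-a"];

        let chunk_len = total_layers.div_ceil(encoder_threads); // 4
        let chunks: Vec<&[&str]> = unhandled_layers.chunks(chunk_len).collect();

        // Without the clamp, two encoder threads would be spawned for a
        // single chunk; the idle one would keep its cloned sender alive
        // and the result-receiving loop would never see a disconnect.
        let encoder_threads = encoder_threads.min(chunks.len());
        assert_eq!(encoder_threads, 1);
    }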
