
Commit 44e1802

Fix deadlock when we have more threads than layers
If we have more threads than unprocessed layers, some of the cloned senders are never dropped and the main thread hangs on the result-receiving loop. We make sure here not to spawn more threads than there are unhandled layers.

Signed-off-by: Pragyan Poudyal <[email protected]>
1 parent bc09a83 commit 44e1802
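
As an aside, here is a minimal, self-contained sketch of the failure mode described above (illustrative only, not code from this repository): a std::sync::mpsc receive loop only terminates once every clone of the Sender has been dropped, so a sender clone held by a worker thread that never exits keeps the receiving thread blocked forever.

use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel::<u32>();

    // A worker with real work: it sends one result and drops its sender when it returns.
    let busy_tx = tx.clone();
    thread::spawn(move || {
        busy_tx.send(1).unwrap();
    });

    // A worker with nothing to do: it keeps its sender clone alive and parks forever,
    // loosely analogous to an encoder thread spawned with no layers to process.
    let idle_tx = tx.clone();
    thread::spawn(move || {
        let _keep_alive = idle_tx;
        loop {
            thread::park();
        }
    });

    // Drop the original sender so only the worker clones keep the channel open.
    drop(tx);

    // Receives the single result, then blocks forever waiting for the idle worker's
    // sender to be dropped: the hang this commit avoids by not spawning more encoder
    // threads than there are unhandled layers.
    for value in rx {
        println!("got {value}");
    }
}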

File tree

1 file changed: +92 -76 lines changed


src/oci/mod.rs

Lines changed: 92 additions & 76 deletions
@@ -175,11 +175,14 @@ impl<'repo> ImageOp<'repo> {
         let raw_config = config?;
         let config = ImageConfiguration::from_reader(&raw_config[..])?;

-        let (done_chan_sender, done_chan_recver, object_sender) = self.spawn_threads(&config);
+        let (done_chan_sender, done_chan_recver, object_sender) =
+            self.spawn_threads(&config)?;

         let mut config_maps = DigestMap::new();

-        for (idx, (mld, cld)) in zip(manifest_layers, config.rootfs().diff_ids()).enumerate() {
+        let mut idx = 0;
+
+        for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) {
             let layer_sha256 = sha256_from_digest(cld)?;

             if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? {
@@ -191,6 +194,8 @@ impl<'repo> ImageOp<'repo> {
                 self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone())
                     .await
                     .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?;
+
+                idx += 1;
             }
         }

@@ -214,43 +219,39 @@ impl<'repo> ImageOp<'repo> {
     fn spawn_threads(
         &self,
         config: &ImageConfiguration,
-    ) -> (
+    ) -> Result<(
         ResultChannelSender,
         ResultChannelReceiver,
         crossbeam::channel::Sender<EnsureObjectMessages>,
-    ) {
+    )> {
         use crossbeam::channel::{unbounded, Receiver, Sender};

-        let encoder_threads = 2;
+        let mut encoder_threads = 2;
         let external_object_writer_threads = 4;

-        let pool = rayon::ThreadPoolBuilder::new()
-            .num_threads(encoder_threads + external_object_writer_threads)
-            .build()
-            .unwrap();
-
-        // We need this as writers have internal state that can't be shared between threads
-        //
-        // We'll actually need as many writers (not writer threads, but writer instances) as there are layers.
-        let zstd_writer_channels: Vec<(Sender<WriterMessages>, Receiver<WriterMessages>)> =
-            (0..encoder_threads).map(|_| unbounded()).collect();
-
-        let (object_sender, object_receiver) = unbounded::<EnsureObjectMessages>();
-
-        // (layer_sha256, layer_id)
-        let (done_chan_sender, done_chan_recver) =
-            std::sync::mpsc::channel::<Result<(Sha256HashValue, Sha256HashValue)>>();
-
         let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads);

         // Divide the layers into chunks of some specific size so each worker
         // thread can work on multiple deterministic layers
-        let mut chunks: Vec<Vec<Sha256HashValue>> = config
+        let diff_ids: Vec<Sha256HashValue> = config
            .rootfs()
            .diff_ids()
            .iter()
-            .map(|x| sha256_from_digest(x).unwrap())
-            .collect::<Vec<Sha256HashValue>>()
+            .map(|x| sha256_from_digest(x))
+            .collect::<Result<Vec<Sha256HashValue>, _>>()?;
+
+        let mut unhandled_layers = vec![];
+
+        // This becomes pretty unreadable with a filter,map chain
+        for id in diff_ids {
+            let layer_exists = self.repo.check_stream(&id)?;
+
+            if layer_exists.is_none() {
+                unhandled_layers.push(id);
+            }
+        }
+
+        let mut chunks: Vec<Vec<Sha256HashValue>> = unhandled_layers
            .chunks(chunk_len)
            .map(|x| x.to_vec())
            .collect();
@@ -264,60 +265,75 @@ impl<'repo> ImageOp<'repo> {
            .flat_map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::<Vec<_>>())
            .collect::<Vec<_>>();

-        let _ = (0..encoder_threads)
-            .map(|i| {
-                let repository = self.repo.try_clone().unwrap();
-                let object_sender = object_sender.clone();
-                let done_chan_sender = done_chan_sender.clone();
-                let chunk = std::mem::take(&mut chunks[i]);
-                let receiver = zstd_writer_channels[i].1.clone();
-
-                pool.spawn({
-                    move || {
-                        let start = i * (chunk_len);
-                        let end = start + chunk_len;
-
-                        let enc = zstd_encoder::MultipleZstdWriters::new(
-                            chunk,
-                            repository,
-                            object_sender,
-                            done_chan_sender,
-                        );
-
-                        if let Err(e) = enc.recv_data(receiver, start, end) {
-                            eprintln!("zstd_encoder returned with error: {}", e)
-                        }
+        encoder_threads = encoder_threads.min(chunks.len());
+
+        let pool = rayon::ThreadPoolBuilder::new()
+            .num_threads(encoder_threads + external_object_writer_threads)
+            .build()
+            .unwrap();
+
+        // We need this as writers have internal state that can't be shared between threads
+        //
+        // We'll actually need as many writers (not writer threads, but writer instances) as there are layers.
+        let zstd_writer_channels: Vec<(Sender<WriterMessages>, Receiver<WriterMessages>)> =
+            (0..encoder_threads).map(|_| unbounded()).collect();
+
+        let (object_sender, object_receiver) = unbounded::<EnsureObjectMessages>();
+
+        // (layer_sha256, layer_id)
+        let (done_chan_sender, done_chan_recver) =
+            std::sync::mpsc::channel::<Result<(Sha256HashValue, Sha256HashValue)>>();
+
+        for i in 0..encoder_threads {
+            let repository = self.repo.try_clone().unwrap();
+            let object_sender = object_sender.clone();
+            let done_chan_sender = done_chan_sender.clone();
+            let chunk = std::mem::take(&mut chunks[i]);
+            let receiver = zstd_writer_channels[i].1.clone();
+
+            pool.spawn({
+                move || {
+                    let start = i * (chunk_len);
+                    let end = start + chunk_len;
+
+                    let enc = zstd_encoder::MultipleZstdWriters::new(
+                        chunk,
+                        repository,
+                        object_sender,
+                        done_chan_sender,
+                    );
+
+                    if let Err(e) = enc.recv_data(receiver, start, end) {
+                        eprintln!("zstd_encoder returned with error: {}", e)
                     }
-                });
-            })
-            .collect::<Vec<()>>();
-
-        let _ = (0..external_object_writer_threads)
-            .map(|_| {
-                pool.spawn({
-                    let repository = self.repo.try_clone().unwrap();
-                    let zstd_writer_channels = zstd_writer_channels
-                        .iter()
-                        .map(|(s, _)| s.clone())
-                        .collect::<Vec<_>>();
-                    let layers_to_chunks = layers_to_chunks.clone();
-                    let external_object_receiver = object_receiver.clone();
-
-                    move || {
-                        if let Err(e) = handle_external_object(
-                            repository,
-                            external_object_receiver,
-                            zstd_writer_channels,
-                            layers_to_chunks,
-                        ) {
-                            eprintln!("handle_external_object returned with error: {}", e);
-                        }
+                }
+            });
+        }
+
+        for _ in 0..external_object_writer_threads {
+            pool.spawn({
+                let repository = self.repo.try_clone().unwrap();
+                let zstd_writer_channels = zstd_writer_channels
+                    .iter()
+                    .map(|(s, _)| s.clone())
+                    .collect::<Vec<_>>();
+                let layers_to_chunks = layers_to_chunks.clone();
+                let external_object_receiver = object_receiver.clone();

+                move || {
+                    if let Err(e) = handle_external_object(
+                        repository,
+                        external_object_receiver,
+                        zstd_writer_channels,
+                        layers_to_chunks,
+                    ) {
+                        eprintln!("handle_external_object returned with error: {}", e);
                    }
-                });
-            })
-            .collect::<Vec<_>>();
+                }
+            });
+        }

-        (done_chan_sender, done_chan_recver, object_sender)
+        Ok((done_chan_sender, done_chan_recver, object_sender))
     }

     pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> {
