Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions lychee-bin/src/commands/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio_stream::wrappers::ReceiverStream;
use lychee_lib::InputSource;
use lychee_lib::RequestError;
use lychee_lib::archive::Archive;
use lychee_lib::waiter::{WaitGroup, WaitGuard};
use lychee_lib::{Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{ResponseBody, Status};

Expand All @@ -35,6 +36,7 @@ where
let (send_req, recv_req) = mpsc::channel(params.cfg.max_concurrency);
let (send_resp, recv_resp) = mpsc::channel(params.cfg.max_concurrency);
let max_concurrency = params.cfg.max_concurrency;
let (waiter, wait_guard) = WaitGroup::new();

// Measure check time
let start = std::time::Instant::now();
Expand Down Expand Up @@ -73,10 +75,16 @@ where
let level = params.cfg.verbose.log_level();

let progress = Progress::new("Extracting links", hide_bar, level, &params.cfg.mode);
let show_results_task = tokio::spawn(collect_responses(recv_resp, progress.clone(), stats));
let show_results_task = tokio::spawn(collect_responses(
recv_resp,
send_req.clone(),
waiter,
progress.clone(),
stats,
));

// Wait until all requests are sent
send_requests(params.requests, send_req, &progress).await?;
send_requests(params.requests, wait_guard, send_req, &progress).await?;
let (cache, client) = handle.await?;

// Wait until all responses are received
Expand Down Expand Up @@ -153,7 +161,8 @@ async fn suggest_archived_links(
// the show_results_task to finish
async fn send_requests<S>(
requests: S,
send_req: mpsc::Sender<Result<Request, RequestError>>,
guard: WaitGuard,
send_req: mpsc::Sender<(WaitGuard, Result<Request, RequestError>)>,
progress: &Progress,
) -> Result<(), ErrorKind>
where
Expand All @@ -162,28 +171,47 @@ where
tokio::pin!(requests);
while let Some(request) = requests.next().await {
progress.inc_length(1);
send_req.send(request).await.expect("Cannot send request");
send_req
.send((guard.clone(), request))
.await
.expect("Cannot send request");
}
Ok(())
}

/// Reads from the request channel and updates the progress bar status
async fn collect_responses(
mut recv_resp: mpsc::Receiver<Result<Response, ErrorKind>>,
recv_resp: mpsc::Receiver<(WaitGuard, Result<Response, ErrorKind>)>,
send_req: mpsc::Sender<(WaitGuard, Result<Request, RequestError>)>,
waiter: WaitGroup,
progress: Progress,
mut stats: ResponseStats,
) -> Result<ResponseStats, ErrorKind> {
while let Some(response) = recv_resp.recv().await {
// Wrap recv_resp until the WaitGroup finishes, at which time the
// recv_resp_until_done stream will be closed. The correctness of
// WaitGroup guarantees that if the waiter finishes, every channel
// with a WaitGuard must be empty.
let mut recv_resp_until_done = ReceiverStream::new(recv_resp)
.take_until(waiter.wait())
.boxed();

while let Some((_guard, response)) = recv_resp_until_done.next().await {
let response = response?;
progress.update(Some(response.body()));
stats.add(response);
}

// unused for now, but will be used for recursion eventually. by holding
// an extra `send_req` endpoint, we prevent the natural termination when
// each channel finishes and closes. instead, we rely on the WaitGroup to
// break the cyclic channels.
let _ = send_req;
Ok(stats)
}

async fn request_channel_task(
recv_req: mpsc::Receiver<Result<Request, RequestError>>,
send_resp: mpsc::Sender<Result<Response, ErrorKind>>,
recv_req: mpsc::Receiver<(WaitGuard, Result<Request, RequestError>)>,
send_resp: mpsc::Sender<(WaitGuard, Result<Response, ErrorKind>)>,
max_concurrency: usize,
client: Client,
cache: Cache,
Expand All @@ -193,7 +221,7 @@ async fn request_channel_task(
StreamExt::for_each_concurrent(
ReceiverStream::new(recv_req),
max_concurrency,
|request: Result<Request, RequestError>| async {
|(guard, request): (WaitGuard, Result<Request, RequestError>)| async {
let response = handle(
&client,
&cache,
Expand All @@ -204,7 +232,7 @@ async fn request_channel_task(
.await;

send_resp
.send(response)
.send((guard, response))
.await
.expect("cannot send response to queue");
},
Expand Down
2 changes: 2 additions & 0 deletions lychee-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ pub mod ratelimit;
/// local IPs or e-mail addresses
pub mod filter;

pub mod waiter;

#[cfg(test)]
use doc_comment as _; // required for doctest
use ring as _; // required for apple silicon
Expand Down
60 changes: 60 additions & 0 deletions lychee-lib/src/waiter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//! Facility to wait for a dynamic set of tasks to complete, with a single
//! waiter and multiple waitees (things that are waited for). Notably, each
//! waitee can also start more work to be waited for.
//!
//! # Implementation Details
//!
//! The implementation of waiting in this module is just a wrapper around
//! [`tokio::sync::mpsc::channel`]. A [`WaitGroup`] holds the unique
//! [`tokio::sync::mpsc::Receiver`] and each [`WaitGuard`] holds a
//! [`tokio::sync::mpsc::Sender`]. Despite this simple implementation, the
//! [`WaitGroup`] and [`WaitGuard`] wrappers are valuable: they name the
//! mechanism explicitly and make its intent discoverable to readers.

use std::convert::Infallible;
use tokio::sync::mpsc::{Receiver, Sender, channel};

/// Manager for a particular wait group. Each group is associated with a set
/// of [`WaitGuard`]s (the first is returned by [`WaitGroup::new`]; further
/// guards are created by cloning it), and the group can wait for all of them
/// to complete.
///
/// Each [`WaitGroup`] is single-use&mdash;calling [`WaitGroup::wait`] to start
/// waiting consumes the [`WaitGroup`]. Additionally, once all [`WaitGuard`]s
/// have been dropped, it is not possible to create any more [`WaitGuard`]s.
#[derive(Debug)]
pub struct WaitGroup {
    /// [`Receiver`] is held to wait for multiple [`Sender`]s and detect
    /// when they have all closed. The [`Infallible`] type means no value
    /// can/will ever be received through the channel.
    recv: Receiver<Infallible>,
}

/// RAII guard held by a task which is being waited for.
///
/// The existence of values of this type represents outstanding work for
/// its corresponding [`WaitGroup`]: the group's [`WaitGroup::wait`] call
/// completes only once every associated guard has been dropped.
///
/// A [`WaitGuard`] can be cloned using [`WaitGuard::clone`]. This allows
/// a task to spawn additional tasks, recursively.
#[derive(Clone, Debug)]
pub struct WaitGuard {
    /// [`Sender`] is held to keep the [`Receiver`] end open (stored in [`WaitGroup`]).
    /// The dropping of all senders will cause the receiver to detect and close.
    /// The [`Infallible`] type means no value can/will ever be sent through the channel.
    _send: Sender<Infallible>,
}

impl WaitGroup {
    /// Creates a new [`WaitGroup`] together with its first associated
    /// [`WaitGuard`].
    ///
    /// Note that [`WaitGroup`] itself has no ability to create new guards.
    /// If needed, new guards should be created by cloning the returned [`WaitGuard`].
    #[must_use]
    pub fn new() -> (Self, WaitGuard) {
        // The capacity of 1 is arbitrary: nothing is ever sent through this
        // channel; it exists solely so the receiver can observe the moment
        // all senders have been dropped.
        let (send, recv) = channel(1);
        let group = Self { recv };
        let first_guard = WaitGuard { _send: send };
        (group, first_guard)
    }

    /// Waits, asynchronously, until all the associated [`WaitGuard`]s have finished.
    pub async fn wait(mut self) {
        // `recv()` resolves to `None` exactly when every `Sender` (i.e. every
        // `WaitGuard`) has been dropped. Because the payload type is
        // `Infallible`, a `Some` value can never actually be produced; the
        // empty match below makes that impossibility explicit.
        if let Some(never) = self.recv.recv().await {
            match never {}
        }
    }
}