Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions benches/bench_channel_async.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use bmrng::channel;
use bmrng::{channel, Request};
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use tokio::sync::mpsc;

struct Req;

impl Request for Req {
type Response = ();
}

fn rt() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(6)
Expand All @@ -16,15 +22,15 @@ fn benchmark_async(c: &mut Criterion) {

group.bench_function("bmrng async, bounded, capacity = 1", move |b| {
b.to_async(rt()).iter(|| async {
let (tx, rx) = channel::<(), ()>(1);
let (tx, rx) = channel::<Req>(1);
tokio::spawn(async move {
let mut rx = rx;
let req = rx.recv().await;
if let Ok(req) = req {
let _ = req.1.respond(());
let _ = req.responder.respond(());
}
});
let _ = tx.send_receive(()).await;
let _ = tx.send_receive(Req).await;
})
});

Expand Down
14 changes: 10 additions & 4 deletions benches/bench_channel_sync.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use bmrng::channel;
use bmrng::{channel, Request};
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use tokio::sync::mpsc;

struct Req;

impl Request for Req {
type Response = ();
}

fn rt() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(6)
Expand All @@ -19,15 +25,15 @@ fn benchmark_sync(c: &mut Criterion) {
group.bench_function("bmrng, bounded, capacity = 1", |b| {
b.iter(|| {
rt.block_on(async move {
let (tx, rx) = channel::<(), ()>(1);
let (tx, rx) = channel::<Req>(1);
tokio::spawn(async move {
let mut rx = rx;
let req = rx.recv().await;
if let Ok(req) = req {
let _ = req.1.respond(());
let _ = req.responder.respond(());
}
});
let _ = tx.send_receive(()).await;
let _ = tx.send_receive(Req).await;
});
});
});
Expand Down
136 changes: 79 additions & 57 deletions src/bounded.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::error::{ReceiveError, RequestError, RespondError, SendError};
use crate::{Request, error::{ReceiveError, RequestError, RespondError, SendError}};

use tokio::sync::{mpsc, oneshot};
use tokio::time::{timeout, Duration};
Expand All @@ -8,43 +8,49 @@ use std::pin::Pin;
use std::task::{Context, Poll};

/// The internal data sent in the MPSC request channel, a tuple that contains the request and the oneshot response channel responder
pub type Payload<Req, Res> = (Req, Responder<Res>);
#[derive(Debug)]
pub struct Payload<R: Request> {
/// the request
pub request: R,
/// the responder
pub responder: Responder<R::Response>,
}

/// Send values to the associated [`RequestReceiver`].
#[derive(Debug)]
pub struct RequestSender<Req, Res> {
request_sender: mpsc::Sender<Payload<Req, Res>>,
pub struct RequestSender<R: Request> {
request_sender: mpsc::Sender<Payload<R>>,
timeout_duration: Option<Duration>,
}

/// Receive requests values from the associated [`RequestSender`]
///
/// Instances are created by the [`channel`] function.
#[derive(Debug)]
pub struct RequestReceiver<Req, Res> {
request_receiver: mpsc::Receiver<Payload<Req, Res>>,
pub struct RequestReceiver<R: Request> {
request_receiver: mpsc::Receiver<Payload<R>>,
}

/// Send values back to the [`RequestSender`] or [`RequestReceiver`]
///
/// Instances are created by calling [`RequestSender::send_receive()`] or [`RequestSender::send()`]
#[derive(Debug)]
pub struct Responder<Res> {
response_sender: oneshot::Sender<Res>,
pub struct Responder<R> {
response_sender: oneshot::Sender<R>,
}

/// Receive responses from a [`Responder`]
///
/// Instances are created by calling [`RequestSender::send_receive()`] or [`RequestSender::send()`]
#[derive(Debug)]
pub struct ResponseReceiver<Res> {
pub(crate) response_receiver: Option<oneshot::Receiver<Res>>,
pub struct ResponseReceiver<R> {
pub(crate) response_receiver: Option<oneshot::Receiver<R>>,
pub(crate) timeout_duration: Option<Duration>,
}

impl<Req, Res> RequestSender<Req, Res> {
impl<R: Request> RequestSender<R> {
fn new(
request_sender: mpsc::Sender<Payload<Req, Res>>,
request_sender: mpsc::Sender<Payload<R>>,
timeout_duration: Option<Duration>,
) -> Self {
RequestSender {
Expand All @@ -58,22 +64,22 @@ impl<Req, Res> RequestSender<Req, Res> {
/// Return the [`ResponseReceiver`] which can be used to wait for a response
///
/// This call waits if the request channel is full. It does not wait for a response
pub async fn send(&self, request: Req) -> Result<ResponseReceiver<Res>, SendError<Req>> {
let (response_sender, response_receiver) = oneshot::channel::<Res>();
pub async fn send(&self, request: R) -> Result<ResponseReceiver<R::Response>, SendError<R>> {
let (response_sender, response_receiver) = oneshot::channel::<R::Response>();
let responder = Responder::new(response_sender);
let payload = (request, responder);
let payload = Payload { request, responder };
self.request_sender
.send(payload)
.await
.map_err(|payload| SendError(payload.0 .0))?;
.map_err(|payload| SendError(payload.0.request))?;
let receiver = ResponseReceiver::new(response_receiver, self.timeout_duration);
Ok(receiver)
}

/// Send a request over the MPSC channel, wait for the response and return it
///
/// This call waits if the request channel is full, and while waiting for the response
pub async fn send_receive(&self, request: Req) -> Result<Res, RequestError<Req>> {
pub async fn send_receive(&self, request: R) -> Result<R::Response, RequestError<R>> {
let mut receiver = self.send(request).await?;
receiver.recv().await.map_err(|err| err.into())
}
Expand All @@ -84,7 +90,7 @@ impl<Req, Res> RequestSender<Req, Res> {
}
}

impl<Req, Res> Clone for RequestSender<Req, Res> {
impl<R: Request> Clone for RequestSender<R> {
fn clone(&self) -> Self {
RequestSender {
request_sender: self.request_sender.clone(),
Expand All @@ -93,15 +99,15 @@ impl<Req, Res> Clone for RequestSender<Req, Res> {
}
}

impl<Req, Res> RequestReceiver<Req, Res> {
fn new(receiver: mpsc::Receiver<Payload<Req, Res>>) -> Self {
impl<R: Request> RequestReceiver<R> {
fn new(receiver: mpsc::Receiver<Payload<R>>) -> Self {
RequestReceiver {
request_receiver: receiver,
}
}

/// Receives the next value for this receiver.
pub async fn recv(&mut self) -> Result<Payload<Req, Res>, RequestError<Req>> {
pub async fn recv(&mut self) -> Result<Payload<R>, RequestError<R>> {
match self.request_receiver.recv().await {
Some(payload) => Ok(payload),
None => Err(RequestError::RecvError),
Expand All @@ -114,15 +120,15 @@ impl<Req, Res> RequestReceiver<Req, Res> {
}

/// Converts this receiver into a stream
pub fn into_stream(self) -> impl Stream<Item = Payload<Req, Res>> {
let stream: RequestReceiverStream<Req, Res> = self.into();
pub fn into_stream(self) -> impl Stream<Item = Payload<R>> {
let stream: RequestReceiverStream<R> = self.into();
stream
}
}

impl<Res> ResponseReceiver<Res> {
impl<R> ResponseReceiver<R> {
pub(crate) fn new(
response_receiver: oneshot::Receiver<Res>,
response_receiver: oneshot::Receiver<R>,
timeout_duration: Option<Duration>,
) -> Self {
Self {
Expand All @@ -136,7 +142,7 @@ impl<Res> ResponseReceiver<Res> {
/// If there is a `timeout_duration` set, and the sender takes longer than
/// the timeout_duration to send the response, it aborts waiting and returns
/// [`ReceiveError::TimeoutError`].
pub async fn recv(&mut self) -> Result<Res, ReceiveError> {
pub async fn recv(&mut self) -> Result<R, ReceiveError> {
match self.response_receiver.take() {
Some(response_receiver) => match self.timeout_duration {
Some(duration) => match timeout(duration, response_receiver).await {
Expand All @@ -150,13 +156,13 @@ impl<Res> ResponseReceiver<Res> {
}
}

impl<Res> Responder<Res> {
pub(crate) fn new(response_sender: oneshot::Sender<Res>) -> Self {
impl<R> Responder<R> {
pub(crate) fn new(response_sender: oneshot::Sender<R>) -> Self {
Self { response_sender }
}

/// Responds a request from the [`RequestSender`] which finishes the request
pub fn respond(self, response: Res) -> Result<(), RespondError<Res>> {
pub fn respond(self, response: R) -> Result<(), RespondError<R>> {
self.response_sender.send(response).map_err(RespondError)
}

Expand All @@ -176,27 +182,35 @@ impl<Res> Responder<Res> {
/// # Examples
///
/// ```rust
/// use bmrng::{Request, Payload};
///
/// #[derive(Debug)]
/// struct Req(u32);
/// impl Request for Req {
/// type Response = u32;
/// }
///
/// #[tokio::main]
/// async fn main() {
/// let buffer_size = 100;
/// let (tx, mut rx) = bmrng::channel::<i32, i32>(buffer_size);
/// let (tx, mut rx) = bmrng::channel::<Req>(buffer_size);
/// tokio::spawn(async move {
/// while let Ok((input, mut responder)) = rx.recv().await {
/// if let Err(err) = responder.respond(input * input) {
/// while let Ok(Payload { request, mut responder }) = rx.recv().await {
/// if let Err(err) = responder.respond(request.0 * 2) {
/// println!("sender dropped the response channel");
/// }
/// }
/// });
/// for i in 1..=10 {
/// if let Ok(response) = tx.send_receive(i).await {
/// if let Ok(response) = tx.send_receive(Req(i)).await {
/// println!("Requested {}, got {}", i, response);
/// assert_eq!(response, i * i);
/// assert_eq!(response, i * 2);
/// }
/// }
/// }
/// ```
pub fn channel<Req, Res>(buffer: usize) -> (RequestSender<Req, Res>, RequestReceiver<Req, Res>) {
let (sender, receiver) = mpsc::channel::<Payload<Req, Res>>(buffer);
pub fn channel<R: Request>(buffer: usize) -> (RequestSender<R>, RequestReceiver<R>) {
let (sender, receiver) = mpsc::channel::<Payload<R>>(buffer);
let request_sender = RequestSender::new(sender, None);
let request_receiver = RequestReceiver::new(receiver);
(request_sender, request_receiver)
Expand All @@ -213,50 +227,58 @@ pub fn channel<Req, Res>(buffer: usize) -> (RequestSender<Req, Res>, RequestRece
///
/// ```rust
/// use tokio::time::{Duration, sleep};
/// use bmrng::{Request, Payload};
///
/// #[derive(Debug, PartialEq)]
/// struct Req(u32);
/// impl Request for Req {
/// type Response = u32;
/// }
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = bmrng::channel_with_timeout::<i32, i32>(100, Duration::from_millis(100));
/// let (tx, mut rx) = bmrng::channel_with_timeout::<Req>(100, Duration::from_millis(100));
/// tokio::spawn(async move {
/// match rx.recv().await {
/// Ok((input, mut responder)) => {
/// Ok(Payload { request, mut responder }) => {
/// sleep(Duration::from_millis(200)).await;
/// let res = responder.respond(input * input);
/// let res = responder.respond(request.0 * 2);
/// assert_eq!(res.is_ok(), true);
/// }
/// Err(err) => {
/// println!("all request senders dropped");
/// }
/// }
/// });
/// let response = tx.send_receive(8).await;
/// assert_eq!(response, Err(bmrng::error::RequestError::<i32>::RecvTimeoutError));
/// let response = tx.send_receive(Req(8)).await;
/// assert_eq!(response, Err(bmrng::error::RequestError::<Req>::RecvTimeoutError));
/// }
/// ```
pub fn channel_with_timeout<Req, Res>(
pub fn channel_with_timeout<R: Request>(
buffer: usize,
timeout_duration: Duration,
) -> (RequestSender<Req, Res>, RequestReceiver<Req, Res>) {
let (sender, receiver) = mpsc::channel::<Payload<Req, Res>>(buffer);
) -> (RequestSender<R>, RequestReceiver<R>) {
let (sender, receiver) = mpsc::channel::<Payload<R>>(buffer);
let request_sender = RequestSender::new(sender, Some(timeout_duration));
let request_receiver = RequestReceiver::new(receiver);
(request_sender, request_receiver)
}

/// A wrapper around [`RequestReceiver`] that implements [`Stream`].
#[derive(Debug)]
pub struct RequestReceiverStream<Req, Res> {
inner: RequestReceiver<Req, Res>,
pub struct RequestReceiverStream<R: Request> {
inner: RequestReceiver<R>,
}

impl<Req, Res> RequestReceiverStream<Req, Res> {
impl<R: Request> RequestReceiverStream<R> {
/// Create a new `RequestReceiverStream`.
pub fn new(recv: RequestReceiver<Req, Res>) -> Self {
pub fn new(recv: RequestReceiver<R>) -> Self {
Self { inner: recv }
}

/// Get back the inner `Receiver`.
#[cfg(not(tarpaulin_include))]
pub fn into_inner(self) -> RequestReceiver<Req, Res> {
pub fn into_inner(self) -> RequestReceiver<R> {
self.inner
}

Expand All @@ -267,30 +289,30 @@ impl<Req, Res> RequestReceiverStream<Req, Res> {
}
}

impl<Req, Res> Stream for RequestReceiverStream<Req, Res> {
type Item = Payload<Req, Res>;
impl<R: Request> Stream for RequestReceiverStream<R> {
type Item = Payload<R>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.request_receiver.poll_recv(cx)
}
}

impl<Req, Res> AsRef<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
impl<R: Request> AsRef<RequestReceiver<R>> for RequestReceiverStream<R> {
#[cfg(not(tarpaulin_include))]
fn as_ref(&self) -> &RequestReceiver<Req, Res> {
fn as_ref(&self) -> &RequestReceiver<R> {
&self.inner
}
}

impl<Req, Res> AsMut<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
impl<R: Request> AsMut<RequestReceiver<R>> for RequestReceiverStream<R> {
#[cfg(not(tarpaulin_include))]
fn as_mut(&mut self) -> &mut RequestReceiver<Req, Res> {
fn as_mut(&mut self) -> &mut RequestReceiver<R> {
&mut self.inner
}
}

impl<Req, Res> From<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
fn from(receiver: RequestReceiver<Req, Res>) -> Self {
impl<R: Request> From<RequestReceiver<R>> for RequestReceiverStream<R> {
fn from(receiver: RequestReceiver<R>) -> Self {
RequestReceiverStream::new(receiver)
}
}
Loading