diff --git a/Cargo.toml b/Cargo.toml index 459b5e9c..a463d8f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,13 @@ members = [ "wrappers/shuttle_rand_0.8", "wrappers/shuttle_sync", "wrappers/lazy_static", + "wrappers/tokio/impls/tokio", + "wrappers/tokio/impls/tokio-macros", + "wrappers/tokio/impls/tokio-stream", + "wrappers/tokio/impls/tokio-test", + "wrappers/tokio/impls/tokio-util", + "wrappers/tokio/wrappers/shuttle-tokio", + "wrappers/tokio/wrappers/shuttle-tokio-stream", ] resolver = "2" diff --git a/wrappers/tokio/README.md b/wrappers/tokio/README.md new file mode 100644 index 00000000..9689607a --- /dev/null +++ b/wrappers/tokio/README.md @@ -0,0 +1,25 @@ +# Shuttle support for `tokio` + +This folder contains the implementation and wrapper that enables testing of [tokio](https://crates.io/crates/tokio) applications with Shuttle. + +## How to use + +To use it, add the following in your Cargo.toml: + +``` +[features] +shuttle = [ + "tokio/shuttle", +] + +[dependencies] +tokio = { package = "shuttle-tokio", version = "VERSION_NUMBER" } +``` + +The code will then behave as before when the `shuttle` feature flag is not provided, and will run with Shuttle-compatible primitives when the `shuttle` feature flag is provided. + +## Limitations + +Shuttle's tokio support does not currently model all tokio functionality. Some parts of tokio have not been implemented or may not be modeled faithfully. Keep this in mind when using Shuttle with tokio, as you may encounter missing functionality that needs to be added. If you encounter missing features, please file an issue or, better yet, open a PR to contribute the functionality. 
+ +The list of constructs not supported by Shuttle is in [Issue 241](https://github.com/awslabs/shuttle/issues/241) diff --git a/wrappers/tokio/impls/tokio-macros/Cargo.toml b/wrappers/tokio/impls/tokio-macros/Cargo.toml new file mode 100644 index 00000000..fff0a97e --- /dev/null +++ b/wrappers/tokio/impls/tokio-macros/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "shuttle-tokio-macros-impl" +version = "0.1.0" # Forked from "2.2.0" +edition = "2021" + +[lib] +proc-macro = true + +[features] + +[dependencies] +proc-macro2 = "1.0.60" +quote = "1" +syn = { version = "2.0", features = ["full"] } +shuttle = { path = "../../../../shuttle" } +proc-macro-crate = "3.1.0" + +[dev-dependencies] +shuttle-tokio = { package = "shuttle-tokio-impl", path = "../tokio", version = "*" } diff --git a/wrappers/tokio/impls/tokio-macros/LICENSE b/wrappers/tokio/impls/tokio-macros/LICENSE new file mode 100644 index 00000000..daf719b1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-macros/LICENSE @@ -0,0 +1,47 @@ +Copyright (c) 2023 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + +The MIT License (MIT) + +Copyright (c) 2019 Yoshua Wuyts + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/wrappers/tokio/impls/tokio-macros/src/entry.rs b/wrappers/tokio/impls/tokio-macros/src/entry.rs new file mode 100644 index 00000000..7292e7ac --- /dev/null +++ b/wrappers/tokio/impls/tokio-macros/src/entry.rs @@ -0,0 +1,394 @@ +//! This file is based on [tokio-macros/src/entry.rs](https://github.com/tokio-rs/tokio/blob/master/tokio-macros/src/entry.rs), +//! but has had the following changes applied. All changes are labeled with `SHUTTLE_CHANGES` markers: +//! 1. Unsupported features (RuntimeFlavor, UnhandledPanic, etc.) have been removed +//! 2. 
Crate name resolution has been improved to support package renaming and the wrapper scheme +//! 3. The macro output wraps the function body in a `shuttle_tokio::check()` call instead of setting up a tokio runtime + +use proc_macro2::{Span, TokenStream, TokenTree}; +use proc_macro_crate::{crate_name, FoundCrate}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::{braced, Attribute, Ident, ReturnType, Signature, Visibility}; + +// syn::AttributeArgs does not implement syn::Parse +type AttributeArgs = syn::punctuated::Punctuated; + +struct Configuration { + is_test: bool, +} + +impl Configuration { + fn new(is_test: bool) -> Self { + Configuration { is_test } + } + + fn macro_name(&self) -> &'static str { + if self.is_test { + "shuttle_tokio::test" + } else { + "shuttle_tokio::main" + } + } +} + +fn build_config(input: &ItemFn, args: AttributeArgs, is_test: bool) -> Result<(), syn::Error> { + if input.sig.asyncness.is_none() { + let msg = "the `async` keyword is missing from the function declaration"; + return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); + } + + let config = Configuration::new(is_test); + let macro_name = config.macro_name(); + + for arg in args { + match arg { + syn::Meta::NameValue(namevalue) => { + let ident = namevalue + .path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&namevalue, "Must have specified ident"))? + .to_string() + .to_lowercase(); + let _lit = match &namevalue.value { + syn::Expr::Lit(syn::ExprLit { lit, .. 
}) => lit, + expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")), + }; + match ident.as_str() { + "worker_threads" => {} + "flavor" => {} + "start_paused" => {} + "core_threads" => { + let msg = "Attribute `core_threads` is renamed to `worker_threads`"; + return Err(syn::Error::new_spanned(namevalue, msg)); + } + "crate" => {} + name => { + let msg = format!( + "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", + ); + return Err(syn::Error::new_spanned(namevalue, msg)); + } + } + } + syn::Meta::Path(path) => { + let name = path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? + .to_string() + .to_lowercase(); + let msg = match name.as_str() { + "threaded_scheduler" | "multi_thread" => { + format!("Set the runtime flavor with #[{macro_name}(flavor = \"multi_thread\")].") + } + "basic_scheduler" | "current_thread" | "single_threaded" => { + format!("Set the runtime flavor with #[{macro_name}(flavor = \"current_thread\")].") + } + "flavor" | "worker_threads" | "start_paused" => { + format!("The `{name}` attribute requires an argument.") + } + name => { + format!( + "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`" + ) + } + }; + return Err(syn::Error::new_spanned(path, msg)); + } + other => { + return Err(syn::Error::new_spanned(other, "Unknown attribute inside the macro")); + } + } + } + + Ok(()) +} + +fn get_env(s: &str, default: T) -> T +where + T: std::str::FromStr, +{ + std::env::var(s) + .map(|v| { + v.parse::() + .unwrap_or_else(|_| panic!("cannot parse env var {}={} as {}", s, v, std::any::type_name::())) + }) + .unwrap_or(default) +} + +fn get_env_usize(s: &str, default: usize) -> usize { + get_env(s, default) +} + +// SHUTTLE_CHANGES +// Slightly modified from the version in Tokio. Part which sets up the runtime is removed. +// The one in Tokio follows a `let body = quote! 
{}` scheme, this one wraps the body directly. +fn parse_knobs(mut input: ItemFn, is_test: bool) -> TokenStream { + input.sig.asyncness = None; + + // If type mismatch occurs, the current rustc points to the last statement. + let (_last_stmt_start_span, last_stmt_end_span) = { + let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter(); + + // `Span` on stable Rust has a limitation that only points to the first + // token, not the whole tokens. We can work around this limitation by + // using the first/last span of the tokens like + // `syn::Error::new_spanned` does. + let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); + let end = last_stmt.last().map_or(start, |t| t.span()); + (start, end) + }; + + // SHUTTLE_CHANGES + // Changed from how Tokio does it. Tokio's approach does not allow package renaming. + // This checks for an import of one of `shuttle-tokio`, `shuttle-tokio-impl`, or `shuttle-tokio-impl-inner`, and uses the first one found. + let found_crate = crate_name("shuttle-tokio").unwrap_or_else(|_| { + crate_name("shuttle-tokio-impl").unwrap_or_else(|_|{ + crate_name("shuttle-tokio-impl-inner") + .expect("Could not find an import for \"shuttle-tokio\", \"shuttle-tokio-impl\", or \"shuttle-tokio-impl-inner\".") + }) + }); + + let crate_path = match found_crate { + FoundCrate::Itself => quote!(crate), + FoundCrate::Name(name) => { + let ident = Ident::new(&name, Span::call_site()); + quote!( #ident ) + } + }; + + let header = if is_test { + quote! { + #[::core::prelude::v1::test] + } + } else { + quote! {} + }; + + // TODO: Enable setting of these, scheduler, etc + let config = quote! { + #crate_path::__default_shuttle_config() + }; + let num_iterations = get_env_usize("SHUTTLE_ITERATIONS", 100); + + let body = input.body(); + let body = match input.sig.output { + syn::ReturnType::Default => { + quote_spanned! {last_stmt_end_span=> + // From Tokio. 
Not sure if this scheme needs it as well, but doesn't hurt to have it. Can figure out whether it can be removed later. + #[allow(clippy::expect_used, clippy::diverging_sub_expression)] + { + #crate_path::__check( + move || { + #crate_path::runtime::Handle::current().block_on(async #body) + }, + #config, + #num_iterations, + ); + }; + } + } + // Tests with a return type becomes an unwrap on the return value, then have their return type wiped. + syn::ReturnType::Type(_, ref ty) => { + quote_spanned! {last_stmt_end_span=> + #[allow(clippy::expect_used, clippy::diverging_sub_expression)] + { + async fn __function_under_test() -> #ty { + #body + } + + #crate_path::__check( + move || { + #crate_path::runtime::Handle::current().block_on(async { __function_under_test().await.unwrap_or_else(|e| panic!("Test failed with error: {e:?}")); }) + }, + #config, + #num_iterations, + ); + }; + } + } + }; + + // Wipe the return type + input.sig.output = ReturnType::Default; + + let last_block = quote! {}; + + input.into_tokens(header, body, last_block) +} + +fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(error.into_compile_error()); + tokens +} + +#[cfg(not(test))] // Work around for rust-lang/rust#62127 +pub(crate) fn main(args: TokenStream, item: TokenStream) -> TokenStream { + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. 
+ let input: ItemFn = match syn::parse2(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; + + let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { + let msg = "the main function cannot accept arguments"; + Err(syn::Error::new_spanned(&input.sig.ident, msg)) + } else { + AttributeArgs::parse_terminated + .parse2(args) + .and_then(|args| build_config(&input, args, false)) + }; + + match config { + Ok(()) => parse_knobs(input, false), + Err(e) => token_stream_with_error(parse_knobs(input, false), e), + } +} + +pub(crate) fn test(args: TokenStream, item: TokenStream) -> TokenStream { + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. + let input: ItemFn = match syn::parse2(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; + let config = if let Some(attr) = input.attrs().find(|attr| attr.meta.path().is_ident("test")) { + let msg = "second test attribute is supplied"; + Err(syn::Error::new_spanned(attr, msg)) + } else { + AttributeArgs::parse_terminated + .parse2(args) + .and_then(|args| build_config(&input, args, true)) + }; + + match config { + Ok(()) => parse_knobs(input, true), + Err(e) => token_stream_with_error(parse_knobs(input, true), e), + } +} + +struct ItemFn { + outer_attrs: Vec, + vis: Visibility, + sig: Signature, + brace_token: syn::token::Brace, + inner_attrs: Vec, + stmts: Vec, +} + +impl ItemFn { + /// Access all attributes of the function item. + fn attrs(&self) -> impl Iterator { + self.outer_attrs.iter().chain(self.inner_attrs.iter()) + } + + /// Get the body of the function item in a manner so that it can be + /// conveniently used with the `quote!` macro. 
+ fn body(&self) -> Body<'_> { + Body { + brace_token: self.brace_token, + stmts: &self.stmts, + } + } + + /// Convert our local function item into a token stream. + fn into_tokens( + self, + header: proc_macro2::TokenStream, + body: proc_macro2::TokenStream, + last_block: proc_macro2::TokenStream, + ) -> TokenStream { + let mut tokens = proc_macro2::TokenStream::new(); + header.to_tokens(&mut tokens); + + // Outer attributes are simply streamed as-is. + for attr in self.outer_attrs { + attr.to_tokens(&mut tokens); + } + + // Inner attributes require extra care, since they're not supported on + // blocks (which is what we're expanded into) we instead lift them + // outside of the function. This matches the behavior of `syn`. + for mut attr in self.inner_attrs { + attr.style = syn::AttrStyle::Outer; + attr.to_tokens(&mut tokens); + } + + self.vis.to_tokens(&mut tokens); + self.sig.to_tokens(&mut tokens); + + self.brace_token.surround(&mut tokens, |tokens| { + body.to_tokens(tokens); + last_block.to_tokens(tokens); + }); + + tokens + } +} + +impl Parse for ItemFn { + #[inline] + fn parse(input: ParseStream<'_>) -> syn::Result { + // This parse implementation has been largely lifted from `syn`, with + // the exception of: + // * We don't have access to the plumbing necessary to parse inner + // attributes in-place. + // * We do our own statements parsing to avoid recursively parsing + // entire statements and only look for the parts we're interested in. + + let outer_attrs = input.call(Attribute::parse_outer)?; + let vis: Visibility = input.parse()?; + let sig: Signature = input.parse()?; + + let content; + let brace_token = braced!(content in input); + let inner_attrs = Attribute::parse_inner(&content)?; + + let mut buf = proc_macro2::TokenStream::new(); + let mut stmts = Vec::new(); + + while !content.is_empty() { + if let Some(semi) = content.parse::>()? 
{ + semi.to_tokens(&mut buf); + stmts.push(buf); + buf = proc_macro2::TokenStream::new(); + continue; + } + + // Parse a single token tree and extend our current buffer with it. + // This avoids parsing the entire content of the sub-tree. + buf.extend([content.parse::()?]); + } + + if !buf.is_empty() { + stmts.push(buf); + } + + Ok(Self { + outer_attrs, + vis, + sig, + brace_token, + inner_attrs, + stmts, + }) + } +} + +struct Body<'a> { + brace_token: syn::token::Brace, + // Statements, with terminating `;`. + stmts: &'a [TokenStream], +} + +impl ToTokens for Body<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.brace_token.surround(tokens, |tokens| { + for stmt in self.stmts { + stmt.to_tokens(tokens); + } + }); + } +} diff --git a/wrappers/tokio/impls/tokio-macros/src/lib.rs b/wrappers/tokio/impls/tokio-macros/src/lib.rs new file mode 100644 index 00000000..77ddd239 --- /dev/null +++ b/wrappers/tokio/impls/tokio-macros/src/lib.rs @@ -0,0 +1,121 @@ +#![allow(clippy::needless_doctest_main)] +#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] + +//! Macros for use when testing with shuttle-tokio +//! This package is not intended to be depended on directly, and should instead be used via `shuttle-tokio`. +//! Put up an issue if you have a use case for using it directly and we'll add a `shuttle-tokio-macros` crate. +// This file is code-wise a verbatim copy of [tokio-macros/src/lib.rs](https://github.com/tokio-rs/tokio/blob/9e94fa7e15cfe6ebbd06e9ebad4642896620d924/tokio-macros/src/lib.rs), +// but the examples in the comments have been removed. + +// This `extern` is required for older `rustc` versions but newer `rustc` +// versions warn about the unused `extern crate`. 
+#[allow(unused_extern_crates)] +extern crate proc_macro; + +mod entry; +mod select; + +use proc_macro::TokenStream; + +/// Marks async function to be executed by the selected runtime. This macro +/// helps set up a `Runtime` without requiring the user to use +/// [Runtime](../../shuttle-tokio/runtime/struct.Runtime.html) or +/// [Builder](../../shuttle-tokio/runtime/struct.Builder.html) directly. +/// +/// Note: This macro is designed to be simplistic and targets applications that +/// do not require a complex setup. If the provided functionality is not +/// sufficient, you may be interested in using +/// [Builder](../../shuttle-tokio/runtime/struct.Builder.html), which provides a more +/// powerful interface. +#[proc_macro_attribute] +#[cfg(not(test))] // Work around for rust-lang/rust#62127 +pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { + entry::main(args.into(), item.into()).into() +} + +/// Marks async function to be executed by selected runtime. This macro helps set up a `Runtime` +/// without requiring the user to use [Runtime](../../shuttle-tokio/runtime/struct.Runtime.html) or +/// [Builder](../../shuttle-tokio/runtime/struct.Builder.html) directly. +#[proc_macro_attribute] +#[cfg(not(test))] // Work around for rust-lang/rust#62127 +pub fn main_rt(args: TokenStream, item: TokenStream) -> TokenStream { + entry::main(args.into(), item.into()).into() +} + +/// Marks async function to be executed by runtime, suitable to test environment. +/// This macro helps set up a `Runtime` without requiring the user to use +/// [Runtime](../../shuttle-tokio/runtime/struct.Runtime.html) or +/// [Builder](../../shuttle-tokio/runtime/struct.Builder.html) directly. +/// +/// Note: This macro is designed to be simplistic and targets applications that +/// do not require a complex setup. 
If the provided functionality is not +/// sufficient, you may be interested in using +/// [Builder](../../shuttle-tokio/runtime/struct.Builder.html), which provides a more +/// powerful interface. +#[proc_macro_attribute] +pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { + entry::test(args.into(), item.into()).into() +} + +/// Marks async function to be executed by runtime, suitable to test environment +/// +/// ## Usage +/// +/// ```no_run +/// #[shuttle_tokio::test] +/// async fn my_test() { +/// assert!(true); +/// } +/// ``` +#[proc_macro_attribute] +pub fn test_rt(args: TokenStream, item: TokenStream) -> TokenStream { + entry::test(args.into(), item.into()).into() +} + +/// Always fails with the error message below. +/// ```text +/// The #[shuttle_tokio::main] macro requires rt or rt-multi-thread. +/// ``` +#[proc_macro_attribute] +pub fn main_fail(_args: TokenStream, _item: TokenStream) -> TokenStream { + syn::Error::new( + proc_macro2::Span::call_site(), + "The #[shuttle_tokio::main] macro requires rt or rt-multi-thread.", + ) + .to_compile_error() + .into() +} + +/// Always fails with the error message below. +/// ```text +/// The #[shuttle_tokio::test] macro requires rt or rt-multi-thread. +/// ``` +#[proc_macro_attribute] +pub fn test_fail(_args: TokenStream, _item: TokenStream) -> TokenStream { + syn::Error::new( + proc_macro2::Span::call_site(), + "The #[shuttle_tokio::test] macro requires rt or rt-multi-thread.", + ) + .to_compile_error() + .into() +} + +/// Implementation detail of the `select!` macro. This macro is **not** intended +/// to be used as part of the public API and is permitted to change. +#[proc_macro] +#[doc(hidden)] +pub fn select_priv_declare_output_enum(input: TokenStream) -> TokenStream { + select::declare_output_enum(input) +} + +/// Implementation detail of the `select!` macro. This macro is **not** intended +/// to be used as part of the public API and is permitted to change. 
+#[proc_macro] +#[doc(hidden)] +pub fn select_priv_clean_pattern(input: TokenStream) -> TokenStream { + select::clean_pattern_macro(input) +} diff --git a/wrappers/tokio/impls/tokio-macros/src/select.rs b/wrappers/tokio/impls/tokio-macros/src/select.rs new file mode 100644 index 00000000..b2bc3164 --- /dev/null +++ b/wrappers/tokio/impls/tokio-macros/src/select.rs @@ -0,0 +1,110 @@ +// This file is lifted verbatim from [tokio-macros/src/select.rs](https://github.com/tokio-rs/tokio/blob/9e94fa7e15cfe6ebbd06e9ebad4642896620d924/tokio-macros/src/select.rs) +use proc_macro::{TokenStream, TokenTree}; +use proc_macro2::Span; +use quote::quote; +use syn::{parse::Parser, Ident}; + +pub(crate) fn declare_output_enum(input: TokenStream) -> TokenStream { + // passed in is: `(_ _ _)` with one `_` per branch + let branches = match input.into_iter().next() { + Some(TokenTree::Group(group)) => group.stream().into_iter().count(), + _ => panic!("unexpected macro input"), + }; + + let variants = (0..branches) + .map(|num| Ident::new(&format!("_{num}"), Span::call_site())) + .collect::>(); + + // Use a bitfield to track which futures completed + let mask = Ident::new( + if branches <= 8 { + "u8" + } else if branches <= 16 { + "u16" + } else if branches <= 32 { + "u32" + } else if branches <= 64 { + "u64" + } else { + panic!("up to 64 branches supported"); + }, + Span::call_site(), + ); + + TokenStream::from(quote! { + pub(super) enum Out<#( #variants ),*> { + #( #variants(#variants), )* + // Include a `Disabled` variant signifying that all select branches + // failed to resolve. + Disabled, + } + + pub(super) type Mask = #mask; + }) +} + +pub(crate) fn clean_pattern_macro(input: TokenStream) -> TokenStream { + // If this isn't a pattern, we return the token stream as-is. The select! + // macro is using it in a location requiring a pattern, so an error will be + // emitted there. 
+ let mut input: syn::Pat = match syn::Pat::parse_single.parse(input.clone()) { + Ok(it) => it, + Err(_) => return input, + }; + + clean_pattern(&mut input); + quote::ToTokens::into_token_stream(input).into() +} + +// Removes any occurrences of ref or mut in the provided pattern. +fn clean_pattern(pat: &mut syn::Pat) { + match pat { + syn::Pat::Lit(_literal) => {} + syn::Pat::Macro(_macro) => {} + syn::Pat::Path(_path) => {} + syn::Pat::Range(_range) => {} + syn::Pat::Rest(_rest) => {} + syn::Pat::Verbatim(_tokens) => {} + syn::Pat::Wild(_underscore) => {} + syn::Pat::Ident(ident) => { + ident.by_ref = None; + ident.mutability = None; + if let Some((_at, pat)) = &mut ident.subpat { + clean_pattern(&mut *pat); + } + } + syn::Pat::Or(or) => { + for case in &mut or.cases { + clean_pattern(case); + } + } + syn::Pat::Slice(slice) => { + for elem in &mut slice.elems { + clean_pattern(elem); + } + } + syn::Pat::Struct(struct_pat) => { + for field in &mut struct_pat.fields { + clean_pattern(&mut field.pat); + } + } + syn::Pat::Tuple(tuple) => { + for elem in &mut tuple.elems { + clean_pattern(elem); + } + } + syn::Pat::TupleStruct(tuple) => { + for elem in &mut tuple.elems { + clean_pattern(elem); + } + } + syn::Pat::Reference(reference) => { + reference.mutability = None; + clean_pattern(&mut reference.pat); + } + syn::Pat::Type(type_pat) => { + clean_pattern(&mut type_pat.pat); + } + _ => {} + } +} diff --git a/wrappers/tokio/impls/tokio-stream/Cargo.toml b/wrappers/tokio/impls/tokio-stream/Cargo.toml new file mode 100644 index 00000000..082c9435 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "shuttle-tokio-stream-impl" +version = "0.1.0" # Forked from "0.1.14" +edition = "2021" +rust-version = "1.63" + +[features] +default = ["time"] + +full = [ + "time", + "net", + "io-util", + "fs", + "sync", + "signal" +] + +time = [] +net = [] +io-util = [] +fs = [] +sync = [] +signal = [] + +[dependencies] +tokio = { path = 
"../tokio", package = "shuttle-tokio-impl", version = "*" } +shuttle = { version = "*", path = "../../../../shuttle" } +futures-core = { version = "0.3.0" } +pin-project-lite = "0.2.11" +tokio-util = "0.7.0" + +[dev-dependencies] +async-stream = "0.3" +tokio-test = { package = "shuttle-tokio-test-impl", path = "../tokio-test" } +futures = { version = "0.3", default-features = false } diff --git a/wrappers/tokio/impls/tokio-stream/LICENSE b/wrappers/tokio/impls/tokio-stream/LICENSE new file mode 100644 index 00000000..8bdf6bd6 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2023 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/wrappers/tokio/impls/tokio-stream/README.md b/wrappers/tokio/impls/tokio-stream/README.md new file mode 100644 index 00000000..b7b96cd8 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/README.md @@ -0,0 +1,6 @@ +# Shuttle support for tokio-stream + +This package was derived from tokio-stream by: +1. Copying the source files from the original tokio-stream crate +2. Replacing the tokio dependency with Shuttle's tokio implementation in Cargo.toml +3. Removing extraneous files and dependencies (docs, fuzz tests, CHANGELOG.md, etc.) diff --git a/wrappers/tokio/impls/tokio-stream/src/empty.rs b/wrappers/tokio/impls/tokio-stream/src/empty.rs new file mode 100644 index 00000000..03363b91 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/empty.rs @@ -0,0 +1,36 @@ +use crate::Stream; + +use core::marker::PhantomData; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// Stream for the [`empty`](fn@empty) function. +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct Empty(PhantomData); + +impl Unpin for Empty {} +unsafe impl Send for Empty {} +unsafe impl Sync for Empty {} + +/// Creates a stream that yields nothing. +/// +/// The returned stream is immediately ready and returns `None`. Use +/// [`stream::pending()`](super::pending()) to obtain a stream that is never +/// ready. +/// +pub const fn empty() -> Empty { + Empty(PhantomData) +} + +impl Stream for Empty { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(0)) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/iter.rs b/wrappers/tokio/impls/tokio-stream/src/iter.rs new file mode 100644 index 00000000..63d4984b --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/iter.rs @@ -0,0 +1,55 @@ +use crate::Stream; + +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// Stream for the [`iter`](fn@iter) function. 
+#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct Iter { + iter: I, + yield_amt: usize, +} + +impl Unpin for Iter {} + +/// Converts an `Iterator` into a `Stream` which is always ready +/// to yield the next value. +/// +/// Iterators in Rust don't express the ability to block, so this adapter +/// simply always calls `iter.next()` and returns that. +pub fn iter(i: I) -> Iter +where + I: IntoIterator, +{ + Iter { + iter: i.into_iter(), + yield_amt: 0, + } +} + +impl Stream for Iter +where + I: Iterator, +{ + type Item = I::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // TODO: add coop back + if self.yield_amt >= 32 { + self.yield_amt = 0; + + cx.waker().wake_by_ref(); + + Poll::Pending + } else { + self.yield_amt += 1; + + Poll::Ready(self.iter.next()) + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/lib.rs b/wrappers/tokio/impls/tokio-stream/src/lib.rs new file mode 100644 index 00000000..7c10c299 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/lib.rs @@ -0,0 +1,52 @@ +#![allow( + clippy::cognitive_complexity, + clippy::large_enum_variant, + clippy::needless_doctest_main +)] +#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] + +//! This crate contains [`Shuttle`] internal implementations of the [`tokio-stream`] crate. +//! Do not depend on this crate directly. Use the `shuttle-tokio-stream` crate instead, which +//! exposes these Shuttle-compatible implementations when the `shuttle` feature is enabled, +//! or the original tokio-stream functionality when the feature is disabled. +//! +//! [`Shuttle`]: +//! +//! 
[`tokio-stream`]: + +#[macro_use] +mod macros; + +pub mod wrappers; + +mod stream_ext; +pub use stream_ext::{collect::FromStream, StreamExt}; +cfg_time! { + pub use stream_ext::timeout::{Elapsed, Timeout}; +} + +mod empty; +pub use empty::{empty, Empty}; + +mod iter; +pub use iter::{iter, Iter}; + +mod once; +pub use once::{once, Once}; + +mod pending; +pub use pending::{pending, Pending}; + +mod stream_map; +pub use stream_map::StreamMap; + +mod stream_close; +pub use stream_close::StreamNotifyClose; + +#[doc(no_inline)] +pub use futures_core::Stream; diff --git a/wrappers/tokio/impls/tokio-stream/src/macros.rs b/wrappers/tokio/impls/tokio-stream/src/macros.rs new file mode 100644 index 00000000..fbc848d6 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/macros.rs @@ -0,0 +1,70 @@ +#![allow(unused)] // TODO: Remove. A few of these are unused as the functionality they are gating is unused. + +macro_rules! cfg_fs { + ($($item:item)*) => { + $( + #[cfg(feature = "fs")] + #[cfg_attr(docsrs, doc(cfg(feature = "fs")))] + $item + )* + } +} + +macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(feature = "net")] + #[cfg_attr(docsrs, doc(cfg(feature = "net")))] + $item + )* + } +} + +macro_rules! cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} + +macro_rules! cfg_sync { + ($($item:item)*) => { + $( + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] + $item + )* + } +} + +macro_rules! cfg_signal { + ($($item:item)*) => { + $( + #[cfg(feature = "signal")] + #[cfg_attr(docsrs, doc(cfg(feature = "signal")))] + $item + )* + } +} + +macro_rules! ready { + ($e:expr $(,)?) 
=> { + match $e { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} diff --git a/wrappers/tokio/impls/tokio-stream/src/once.rs b/wrappers/tokio/impls/tokio-stream/src/once.rs new file mode 100644 index 00000000..a3824651 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/once.rs @@ -0,0 +1,35 @@ +use crate::{Iter, Stream}; + +use core::option; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// Stream for the [`once`](fn@once) function. +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct Once { + iter: Iter>, +} + +impl Unpin for Once {} + +/// Creates a stream that emits an element exactly once. +/// +/// The returned stream is immediately ready and emits the provided value once. +pub fn once(value: T) -> Once { + Once { + iter: crate::iter(Some(value)), + } +} + +impl Stream for Once { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.iter).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/pending.rs b/wrappers/tokio/impls/tokio-stream/src/pending.rs new file mode 100644 index 00000000..afe44340 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/pending.rs @@ -0,0 +1,36 @@ +use crate::Stream; + +use core::marker::PhantomData; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// Stream for the [`pending`](fn@pending) function. +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct Pending(PhantomData); + +impl Unpin for Pending {} +unsafe impl Send for Pending {} +unsafe impl Sync for Pending {} + +/// Creates a stream that is never ready +/// +/// The returned stream is never ready. Attempting to call +/// [`next()`](crate::StreamExt::next) will never complete. 
Use +/// [`stream::empty()`](super::empty()) to obtain a stream that is is +/// immediately empty but returns no values. +pub const fn pending() -> Pending { + Pending(PhantomData) +} + +impl Stream for Pending { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Pending + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_close.rs b/wrappers/tokio/impls/tokio-stream/src/stream_close.rs new file mode 100644 index 00000000..21a3ec54 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_close.rs @@ -0,0 +1,68 @@ +use crate::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A `Stream` that wraps the values in an `Option`. + /// + /// Whenever the wrapped stream yields an item, this stream yields that item + /// wrapped in `Some`. When the inner stream ends, then this stream first + /// yields a `None` item, and then this stream will also end. + #[must_use = "streams do nothing unless polled"] + pub struct StreamNotifyClose { + #[pin] + inner: Option, + } +} + +impl StreamNotifyClose { + /// Create a new `StreamNotifyClose`. + pub fn new(stream: S) -> Self { + Self { inner: Some(stream) } + } + + /// Get back the inner `Stream`. + /// + /// Returns `None` if the stream has reached its end. + pub fn into_inner(self) -> Option { + self.inner + } +} + +impl Stream for StreamNotifyClose +where + S: Stream, +{ + type Item = Option; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We can't invoke poll_next after it ended, so we unset the inner stream as a marker. 
+ match self + .as_mut() + .project() + .inner + .as_pin_mut() + .map(|stream| S::poll_next(stream, cx)) + { + Some(Poll::Ready(Some(item))) => Poll::Ready(Some(Some(item))), + Some(Poll::Ready(None)) => { + self.project().inner.set(None); + Poll::Ready(Some(None)) + } + Some(Poll::Pending) => Poll::Pending, + None => Poll::Ready(None), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + if let Some(inner) = &self.inner { + // We always return +1 because when there's stream there's atleast one more item. + let (l, u) = inner.size_hint(); + (l.saturating_add(1), u.and_then(|u| u.checked_add(1))) + } else { + (0, Some(0)) + } + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext.rs new file mode 100644 index 00000000..19c22a6d --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext.rs @@ -0,0 +1,1205 @@ +use core::future::Future; +use futures_core::Stream; + +mod all; +use all::AllFuture; + +mod any; +use any::AnyFuture; + +mod chain; +use chain::Chain; + +pub(crate) mod collect; +use collect::{Collect, FromStream}; + +mod filter; +use filter::Filter; + +mod filter_map; +use filter_map::FilterMap; + +mod fold; +use fold::FoldFuture; + +mod fuse; +use fuse::Fuse; + +mod map; +use map::Map; + +mod map_while; +use map_while::MapWhile; + +mod merge; +use merge::Merge; + +mod next; +use next::Next; + +mod skip; +use skip::Skip; + +mod skip_while; +use skip_while::SkipWhile; + +mod take; +use take::Take; + +mod take_while; +use take_while::TakeWhile; + +mod then; +use then::Then; + +mod try_next; +use try_next::TryNext; + +mod peekable; +use peekable::Peekable; + +cfg_time! 
{ + pub(crate) mod timeout; + pub(crate) mod timeout_repeating; + use timeout::Timeout; + use timeout_repeating::TimeoutRepeating; + use tokio::time::{Duration, Interval}; + mod throttle; + use throttle::{throttle, Throttle}; + mod chunks_timeout; + use chunks_timeout::ChunksTimeout; +} + +/// An extension trait for the [`Stream`] trait that provides a variety of +/// convenient combinator functions. +/// +/// Be aware that the `Stream` trait in Tokio is a re-export of the trait found +/// in the [futures] crate, however both Tokio and futures provide separate +/// `StreamExt` utility traits, and some utilities are only available on one of +/// these traits. Click [here][futures-StreamExt] to see the other `StreamExt` +/// trait in the futures crate. +/// +/// If you need utilities from both `StreamExt` traits, you should prefer to +/// import one of them, and use the other through the fully qualified call +/// syntax. +/// +/// [`Stream`]: crate::Stream +/// [futures]: https://docs.rs/futures +/// [futures-StreamExt]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html +pub trait StreamExt: Stream { + /// Consumes and returns the next value in the stream or `None` if the + /// stream is finished. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn next(&mut self) -> Option; + /// ``` + /// + /// Note that because `next` doesn't take ownership over the stream, + /// the [`Stream`] type must be [`Unpin`]. If you want to use `next` with a + /// [`!Unpin`](Unpin) stream, you'll first have to pin the stream. This can + /// be done by boxing the stream using [`Box::pin`] or + /// pinning it to the stack using the `pin_mut!` macro from the `pin_utils` + /// crate. + /// + /// # Cancel safety + /// + /// This method is cancel safe. The returned future only + /// holds onto a reference to the underlying stream, + /// so dropping it will never lose a value. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let mut stream = stream::iter(1..=3); + /// + /// assert_eq!(stream.next().await, Some(1)); + /// assert_eq!(stream.next().await, Some(2)); + /// assert_eq!(stream.next().await, Some(3)); + /// assert_eq!(stream.next().await, None); + /// # } + /// ``` + fn next(&mut self) -> Next<'_, Self> + where + Self: Unpin, + { + Next::new(self) + } + + /// Consumes and returns the next item in the stream. If an error is + /// encountered before the next item, the error is returned instead. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn try_next(&mut self) -> Result, E>; + /// ``` + /// + /// This is similar to the [`next`](StreamExt::next) combinator, + /// but returns a [`Result, E>`](Result) rather than + /// an [`Option>`](Option), making for easy use + /// with the [`?`](std::ops::Try) operator. + /// + /// # Cancel safety + /// + /// This method is cancel safe. The returned future only + /// holds onto a reference to the underlying stream, + /// so dropping it will never lose a value. + /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let mut stream = stream::iter(vec![Ok(1), Ok(2), Err("nope")]); + /// + /// assert_eq!(stream.try_next().await, Ok(Some(1))); + /// assert_eq!(stream.try_next().await, Ok(Some(2))); + /// assert_eq!(stream.try_next().await, Err("nope")); + /// # } + /// ``` + fn try_next(&mut self) -> TryNext<'_, Self> + where + Self: Stream> + Unpin, + { + TryNext::new(self) + } + + /// Maps this stream's items to a different type, returning a new stream of + /// the resulting type. + /// + /// The provided closure is executed over all elements of this stream as + /// they are made available. 
It is executed inline with calls to + /// [`poll_next`](Stream::poll_next). + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to the existing `map` methods in the + /// standard library. + /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let stream = stream::iter(1..=3); + /// let mut stream = stream.map(|x| x + 3); + /// + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, Some(5)); + /// assert_eq!(stream.next().await, Some(6)); + /// # } + /// ``` + fn map(self, f: F) -> Map + where + F: FnMut(Self::Item) -> T, + Self: Sized, + { + Map::new(self, f) + } + + /// Map this stream's items to a different type for as long as determined by + /// the provided closure. A stream of the target type will be returned, + /// which will yield elements until the closure returns `None`. + /// + /// The provided closure is executed over all elements of this stream as + /// they are made available, until it returns `None`. It is executed inline + /// with calls to [`poll_next`](Stream::poll_next). Once `None` is returned, + /// the underlying stream will not be polled again. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to the [`Iterator::map_while`] method in the + /// standard library. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let stream = stream::iter(1..=10); + /// let mut stream = stream.map_while(|x| { + /// if x < 4 { + /// Some(x + 3) + /// } else { + /// None + /// } + /// }); + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, Some(5)); + /// assert_eq!(stream.next().await, Some(6)); + /// assert_eq!(stream.next().await, None); + /// # } + /// ``` + fn map_while(self, f: F) -> MapWhile + where + F: FnMut(Self::Item) -> Option, + Self: Sized, + { + MapWhile::new(self, f) + } + + /// Maps this stream's items asynchronously to a different type, returning a + /// new stream of the resulting type. + /// + /// The provided closure is executed over all elements of this stream as + /// they are made available, and the returned future is executed. Only one + /// future is executed at the time. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to the existing `then` methods in the + /// standard library. + /// + /// Be aware that if the future is not `Unpin`, then neither is the `Stream` + /// returned by this method. To handle this, you can use `tokio::pin!` as in + /// the example below or put the stream in a `Box` with `Box::pin(stream)`. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// async fn do_async_work(value: i32) -> i32 { + /// value + 3 + /// } + /// + /// let stream = stream::iter(1..=3); + /// let stream = stream.then(do_async_work); + /// + /// tokio::pin!(stream); + /// + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, Some(5)); + /// assert_eq!(stream.next().await, Some(6)); + /// # } + /// ``` + fn then(self, f: F) -> Then + where + F: FnMut(Self::Item) -> Fut, + Fut: Future, + Self: Sized, + { + Then::new(self, f) + } + + /// Combine two streams into one by interleaving the output of both as it + /// is produced. + /// + /// Values are produced from the merged stream in the order they arrive from + /// the two source streams. If both source streams provide values + /// simultaneously, the merge stream alternates between them. This provides + /// some level of fairness. You should not chain calls to `merge`, as this + /// will break the fairness of the merging. + /// + /// The merged stream completes once **both** source streams complete. When + /// one source stream completes before the other, the merge stream + /// exclusively polls the remaining stream. + /// + /// For merging multiple streams, consider using [`StreamMap`] instead. + /// + /// [`StreamMap`]: crate::StreamMap + /// + /// # Examples + /// + /// ```ignore + /// use shuttle_tokio_stream_impl::{StreamExt, Stream}; + /// use tokio::sync::mpsc; + /// use tokio::time; + /// + /// use std::time::Duration; + /// use std::pin::Pin; + /// + /// # /* + /// #[tokio::main] + /// # */ + /// # #[tokio::main(flavor = "current_thread")] + /// async fn main() { + /// # time::pause(); + /// let (tx1, mut rx1) = mpsc::channel::(10); + /// let (tx2, mut rx2) = mpsc::channel::(10); + /// + /// // Convert the channels to a `Stream`. + /// let rx1 = Box::pin(async_stream::stream! 
{ + /// while let Some(item) = rx1.recv().await { + /// yield item; + /// } + /// }) as Pin + Send>>; + /// + /// let rx2 = Box::pin(async_stream::stream! { + /// while let Some(item) = rx2.recv().await { + /// yield item; + /// } + /// }) as Pin + Send>>; + /// + /// let mut rx = rx1.merge(rx2); + /// + /// tokio::spawn(async move { + /// // Send some values immediately + /// tx1.send(1).await.unwrap(); + /// tx1.send(2).await.unwrap(); + /// + /// // Let the other task send values + /// time::sleep(Duration::from_millis(20)).await; + /// + /// tx1.send(4).await.unwrap(); + /// }); + /// + /// tokio::spawn(async move { + /// // Wait for the first task to send values + /// time::sleep(Duration::from_millis(5)).await; + /// + /// tx2.send(3).await.unwrap(); + /// + /// time::sleep(Duration::from_millis(25)).await; + /// + /// // Send the final value + /// tx2.send(5).await.unwrap(); + /// }); + /// + /// assert_eq!(1, rx.next().await.unwrap()); + /// assert_eq!(2, rx.next().await.unwrap()); + /// assert_eq!(3, rx.next().await.unwrap()); + /// assert_eq!(4, rx.next().await.unwrap()); + /// assert_eq!(5, rx.next().await.unwrap()); + /// + /// // The merged stream is consumed + /// assert!(rx.next().await.is_none()); + /// } + /// ``` + fn merge(self, other: U) -> Merge + where + U: Stream, + Self: Sized, + { + Merge::new(self, other) + } + + /// Filters the values produced by this stream according to the provided + /// predicate. + /// + /// As values of this stream are made available, the provided predicate `f` + /// will be run against them. If the predicate + /// resolves to `true`, then the stream will yield the value, but if the + /// predicate resolves to `false`, then the value + /// will be discarded and the next value will be produced. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to [`Iterator::filter`] method in the + /// standard library. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let stream = stream::iter(1..=8); + /// let mut evens = stream.filter(|x| x % 2 == 0); + /// + /// assert_eq!(Some(2), evens.next().await); + /// assert_eq!(Some(4), evens.next().await); + /// assert_eq!(Some(6), evens.next().await); + /// assert_eq!(Some(8), evens.next().await); + /// assert_eq!(None, evens.next().await); + /// # } + /// ``` + fn filter(self, f: F) -> Filter + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + Filter::new(self, f) + } + + /// Filters the values produced by this stream while simultaneously mapping + /// them to a different type according to the provided closure. + /// + /// As values of this stream are made available, the provided function will + /// be run on them. If the predicate `f` resolves to + /// [`Some(item)`](Some) then the stream will yield the value `item`, but if + /// it resolves to [`None`], then the value will be skipped. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to [`Iterator::filter_map`] method in the + /// standard library. 
+ /// + /// # Examples + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let stream = stream::iter(1..=8); + /// let mut evens = stream.filter_map(|x| { + /// if x % 2 == 0 { Some(x + 1) } else { None } + /// }); + /// + /// assert_eq!(Some(3), evens.next().await); + /// assert_eq!(Some(5), evens.next().await); + /// assert_eq!(Some(7), evens.next().await); + /// assert_eq!(Some(9), evens.next().await); + /// assert_eq!(None, evens.next().await); + /// # } + /// ``` + fn filter_map(self, f: F) -> FilterMap + where + F: FnMut(Self::Item) -> Option, + Self: Sized, + { + FilterMap::new(self, f) + } + + /// Creates a stream which ends after the first `None`. + /// + /// After a stream returns `None`, behavior is undefined. Future calls to + /// `poll_next` may or may not return `Some(T)` again or they may panic. + /// `fuse()` adapts a stream, ensuring that after `None` is given, it will + /// return `None` forever. 
+ /// + /// # Examples + /// + /// ```ignore + /// use shuttle_tokio_stream_impl::{Stream, StreamExt}; + /// + /// use std::pin::Pin; + /// use std::task::{Context, Poll}; + /// + /// // a stream which alternates between Some and None + /// struct Alternate { + /// state: i32, + /// } + /// + /// impl Stream for Alternate { + /// type Item = i32; + /// + /// fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + /// let val = self.state; + /// self.state = self.state + 1; + /// + /// // if it's even, Some(i32), else None + /// if val % 2 == 0 { + /// Poll::Ready(Some(val)) + /// } else { + /// Poll::Ready(None) + /// } + /// } + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let mut stream = Alternate { state: 0 }; + /// + /// // the stream goes back and forth + /// assert_eq!(stream.next().await, Some(0)); + /// assert_eq!(stream.next().await, None); + /// assert_eq!(stream.next().await, Some(2)); + /// assert_eq!(stream.next().await, None); + /// + /// // however, once it is fused + /// let mut stream = stream.fuse(); + /// + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, None); + /// + /// // it will always return `None` after the first time. + /// assert_eq!(stream.next().await, None); + /// assert_eq!(stream.next().await, None); + /// assert_eq!(stream.next().await, None); + /// } + /// ``` + fn fuse(self) -> Fuse + where + Self: Sized, + { + Fuse::new(self) + } + + /// Creates a new stream of at most `n` items of the underlying stream. + /// + /// Once `n` items have been yielded from this stream then it will always + /// return that the stream is done. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let mut stream = stream::iter(1..=10).take(3); + /// + /// assert_eq!(Some(1), stream.next().await); + /// assert_eq!(Some(2), stream.next().await); + /// assert_eq!(Some(3), stream.next().await); + /// assert_eq!(None, stream.next().await); + /// # } + /// ``` + fn take(self, n: usize) -> Take + where + Self: Sized, + { + Take::new(self, n) + } + + /// Take elements from this stream while the provided predicate + /// resolves to `true`. + /// + /// This function, like `Iterator::take_while`, will take elements from the + /// stream until the predicate `f` resolves to `false`. Once one element + /// returns false it will always return that the stream is done. + /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let mut stream = stream::iter(1..=10).take_while(|x| *x <= 3); + /// + /// assert_eq!(Some(1), stream.next().await); + /// assert_eq!(Some(2), stream.next().await); + /// assert_eq!(Some(3), stream.next().await); + /// assert_eq!(None, stream.next().await); + /// # } + /// ``` + fn take_while(self, f: F) -> TakeWhile + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + TakeWhile::new(self, f) + } + + /// Creates a new stream that will skip the `n` first items of the + /// underlying stream. 
+ /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let mut stream = stream::iter(1..=10).skip(7); + /// + /// assert_eq!(Some(8), stream.next().await); + /// assert_eq!(Some(9), stream.next().await); + /// assert_eq!(Some(10), stream.next().await); + /// assert_eq!(None, stream.next().await); + /// # } + /// ``` + fn skip(self, n: usize) -> Skip + where + Self: Sized, + { + Skip::new(self, n) + } + + /// Skip elements from the underlying stream while the provided predicate + /// resolves to `true`. + /// + /// This function, like [`Iterator::skip_while`], will ignore elements from the + /// stream until the predicate `f` resolves to `false`. Once one element + /// returns false, the rest of the elements will be yielded. + /// + /// [`Iterator::skip_while`]: std::iter::Iterator::skip_while() + /// + /// # Examples + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// let mut stream = stream::iter(vec![1,2,3,4,1]).skip_while(|x| *x < 3); + /// + /// assert_eq!(Some(3), stream.next().await); + /// assert_eq!(Some(4), stream.next().await); + /// assert_eq!(Some(1), stream.next().await); + /// assert_eq!(None, stream.next().await); + /// # } + /// ``` + fn skip_while(self, f: F) -> SkipWhile + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + SkipWhile::new(self, f) + } + + /// Tests if every element of the stream matches a predicate. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn all(&mut self, f: F) -> bool; + /// ``` + /// + /// `all()` takes a closure that returns `true` or `false`. It applies + /// this closure to each element of the stream, and if they all return + /// `true`, then so does `all`. If any of them return `false`, it + /// returns `false`. An empty stream returns `true`. 
+ /// + /// `all()` is short-circuiting; in other words, it will stop processing + /// as soon as it finds a `false`, given that no matter what else happens, + /// the result will also be `false`. + /// + /// An empty stream returns `true`. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let a = [1, 2, 3]; + /// + /// assert!(stream::iter(&a).all(|&x| x > 0).await); + /// + /// assert!(!stream::iter(&a).all(|&x| x > 2).await); + /// # } + /// ``` + /// + /// Stopping at the first `false`: + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let a = [1, 2, 3]; + /// + /// let mut iter = stream::iter(&a); + /// + /// assert!(!iter.all(|&x| x != 2).await); + /// + /// // we can still use `iter`, as there are more elements. + /// assert_eq!(iter.next().await, Some(&3)); + /// # } + /// ``` + fn all(&mut self, f: F) -> AllFuture<'_, Self, F> + where + Self: Unpin, + F: FnMut(Self::Item) -> bool, + { + AllFuture::new(self, f) + } + + /// Tests if any element of the stream matches a predicate. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn any(&mut self, f: F) -> bool; + /// ``` + /// + /// `any()` takes a closure that returns `true` or `false`. It applies + /// this closure to each element of the stream, and if any of them return + /// `true`, then so does `any()`. If they all return `false`, it + /// returns `false`. + /// + /// `any()` is short-circuiting; in other words, it will stop processing + /// as soon as it finds a `true`, given that no matter what else happens, + /// the result will also be `true`. + /// + /// An empty stream returns `false`. 
+ /// + /// Basic usage: + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let a = [1, 2, 3]; + /// + /// assert!(stream::iter(&a).any(|&x| x > 0).await); + /// + /// assert!(!stream::iter(&a).any(|&x| x > 5).await); + /// # } + /// ``` + /// + /// Stopping at the first `true`: + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// let a = [1, 2, 3]; + /// + /// let mut iter = stream::iter(&a); + /// + /// assert!(iter.any(|&x| x != 2).await); + /// + /// // we can still use `iter`, as there are more elements. + /// assert_eq!(iter.next().await, Some(&2)); + /// # } + /// ``` + fn any(&mut self, f: F) -> AnyFuture<'_, Self, F> + where + Self: Unpin, + F: FnMut(Self::Item) -> bool, + { + AnyFuture::new(self, f) + } + + /// Combine two streams into one by first returning all values from the + /// first stream then all values from the second stream. + /// + /// As long as `self` still has values to emit, no values from `other` are + /// emitted, even if some are ready. 
+ /// + /// # Examples + /// + /// ```ignore + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let one = stream::iter(vec![1, 2, 3]); + /// let two = stream::iter(vec![4, 5, 6]); + /// + /// let mut stream = one.chain(two); + /// + /// assert_eq!(stream.next().await, Some(1)); + /// assert_eq!(stream.next().await, Some(2)); + /// assert_eq!(stream.next().await, Some(3)); + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, Some(5)); + /// assert_eq!(stream.next().await, Some(6)); + /// assert_eq!(stream.next().await, None); + /// } + /// ``` + fn chain(self, other: U) -> Chain + where + U: Stream, + Self: Sized, + { + Chain::new(self, other) + } + + /// A combinator that applies a function to every element in a stream + /// producing a single, final value. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn fold(self, init: B, f: F) -> B; + /// ``` + /// + /// # Examples + /// Basic usage: + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, *}; + /// + /// let s = stream::iter(vec![1u8, 2, 3]); + /// let sum = s.fold(0, |acc, x| acc + x).await; + /// + /// assert_eq!(sum, 6); + /// # } + /// ``` + fn fold(self, init: B, f: F) -> FoldFuture + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + FoldFuture::new(self, init, f) + } + + /// Drain stream pushing all emitted values into a collection. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn collect(self) -> T; + /// ``` + /// + /// `collect` streams all values, awaiting as needed. Values are pushed into + /// a collection. A number of different target collection types are + /// supported, including [`Vec`], [`String`], and [`Bytes`]. 
+ /// + /// [`Bytes`]: https://docs.rs/bytes/0.6.0/bytes/struct.Bytes.html + /// + /// # `Result` + /// + /// `collect()` can also be used with streams of type `Result` where + /// `T: FromStream<_>`. In this case, `collect()` will stream as long as + /// values yielded from the stream are `Ok(_)`. If `Err(_)` is encountered, + /// streaming is terminated and `collect()` returns the `Err`. + /// + /// # Notes + /// + /// `FromStream` is currently a sealed trait. Stabilization is pending + /// enhancements to the Rust language. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```ignore + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let doubled: Vec = + /// stream::iter(vec![1, 2, 3]) + /// .map(|x| x * 2) + /// .collect() + /// .await; + /// + /// assert_eq!(vec![2, 4, 6], doubled); + /// } + /// ``` + /// + /// Collecting a stream of `Result` values + /// + /// ```ignore + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// // A stream containing only `Ok` values will be collected + /// let values: Result, &str> = + /// stream::iter(vec![Ok(1), Ok(2), Ok(3)]) + /// .collect() + /// .await; + /// + /// assert_eq!(Ok(vec![1, 2, 3]), values); + /// + /// // A stream containing `Err` values will return the first error. + /// let results = vec![Ok(1), Err("no"), Ok(2), Ok(3), Err("nein")]; + /// + /// let values: Result, &str> = + /// stream::iter(results) + /// .collect() + /// .await; + /// + /// assert_eq!(Err("no"), values); + /// } + /// ``` + fn collect(self) -> Collect + where + T: FromStream, + Self: Sized, + { + Collect::new(self) + } + + /// Applies a per-item timeout to the passed stream. + /// + /// `timeout()` takes a `Duration` that represents the maximum amount of + /// time each element of the stream has to complete before timing out. 
+ /// + /// If the wrapped stream yields a value before the deadline is reached, the + /// value is returned. Otherwise, an error is returned. The caller may decide + /// to continue consuming the stream and will eventually get the next source + /// stream value once it becomes available. See + /// [`timeout_repeating`](StreamExt::timeout_repeating) for an alternative + /// where the timeouts will repeat. + /// + /// # Notes + /// + /// This function consumes the stream passed into it and returns a + /// wrapped version of it. + /// + /// Polling the returned stream will continue to poll the inner stream even + /// if one or more items time out. + /// + /// # Examples + /// + /// Suppose we have a stream `int_stream` that yields 3 numbers (1, 2, 3): + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// use std::time::Duration; + /// # let int_stream = stream::iter(1..=3); + /// + /// let int_stream = int_stream.timeout(Duration::from_secs(1)); + /// tokio::pin!(int_stream); + /// + /// // When no items time out, we get the 3 elements in succession: + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If the second item times out, we get an error and continue polling the stream: + /// # let mut int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert!(int_stream.try_next().await.is_err()); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If we want to stop consuming the source stream the first time an + /// // element times out, we can use the `take_while` operator: 
+ /// # let int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// let mut int_stream = int_stream.take_while(Result::is_ok); + /// + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// # } + /// ``` + /// + /// Once a timeout error is received, no further events will be received + /// unless the wrapped stream yields a value (timeouts do not repeat). + /// + /// ```ignore + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{StreamExt, wrappers::IntervalStream}; + /// use std::time::Duration; + /// let interval_stream = IntervalStream::new(tokio::time::interval(Duration::from_millis(100))); + /// let timeout_stream = interval_stream.timeout(Duration::from_millis(10)); + /// tokio::pin!(timeout_stream); + /// + /// // Only one timeout will be received between values in the source stream. + /// assert!(timeout_stream.try_next().await.is_ok()); + /// assert!(timeout_stream.try_next().await.is_err(), "expected one timeout"); + /// assert!(timeout_stream.try_next().await.is_ok(), "expected no more timeouts"); + /// # } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn timeout(self, duration: Duration) -> Timeout + where + Self: Sized, + { + Timeout::new(self, duration) + } + + /// Applies a per-item timeout to the passed stream. + /// + /// `timeout_repeating()` takes an [`Interval`] that controls the time each + /// element of the stream has to complete before timing out. + /// + /// If the wrapped stream yields a value before the deadline is reached, the + /// value is returned. Otherwise, an error is returned. The caller may decide + /// to continue consuming the stream and will eventually get the next source + /// stream value once it becomes available. 
Unlike `timeout()`, if no value + /// becomes available before the deadline is reached, additional errors are + /// returned at the specified interval. See [`timeout`](StreamExt::timeout) + /// for an alternative where the timeouts do not repeat. + /// + /// # Notes + /// + /// This function consumes the stream passed into it and returns a + /// wrapped version of it. + /// + /// Polling the returned stream will continue to poll the inner stream even + /// if one or more items time out. + /// + /// # Examples + /// + /// Suppose we have a stream `int_stream` that yields 3 numbers (1, 2, 3): + /// + /// ```ignore + /// # #[tokio::main] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// use std::time::Duration; + /// # let int_stream = stream::iter(1..=3); + /// + /// let int_stream = int_stream.timeout_repeating(tokio::time::interval(Duration::from_secs(1))); + /// tokio::pin!(int_stream); + /// + /// // When no items time out, we get the 3 elements in succession: + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If the second item times out, we get an error and continue polling the stream: + /// # let mut int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert!(int_stream.try_next().await.is_err()); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If we want to stop consuming the source stream the first time an + /// // element times out, we can use the `take_while` operator: + /// # let int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// let mut int_stream = 
int_stream.take_while(Result::is_ok); + /// + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// # } + /// ``` + /// + /// Timeout errors will be continuously produced at the specified interval + /// until the wrapped stream yields a value. + /// + /// ```ignore + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// # async fn main() { + /// use shuttle_tokio_stream_impl::{StreamExt, wrappers::IntervalStream}; + /// use std::time::Duration; + /// let interval_stream = IntervalStream::new(tokio::time::interval(Duration::from_millis(23))); + /// let timeout_stream = interval_stream.timeout_repeating(tokio::time::interval(Duration::from_millis(9))); + /// tokio::pin!(timeout_stream); + /// + /// // Multiple timeouts will be received between values in the source stream. + /// assert!(timeout_stream.try_next().await.is_ok()); + /// assert!(timeout_stream.try_next().await.is_err(), "expected one timeout"); + /// assert!(timeout_stream.try_next().await.is_err(), "expected a second timeout"); + /// // Will eventually receive another value from the source stream... + /// assert!(timeout_stream.try_next().await.is_ok(), "expected non-timeout"); + /// # } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn timeout_repeating(self, interval: Interval) -> TimeoutRepeating + where + Self: Sized, + { + TimeoutRepeating::new(self, interval) + } + + /// Slows down a stream by enforcing a delay between items. + /// + /// The underlying timer behind this utility has a granularity of one millisecond. + /// + /// # Example + /// + /// Create a throttled stream. 
+ /// ```rust,no_run + /// use std::time::Duration; + /// use shuttle_tokio_stream_impl::StreamExt; + /// + /// # async fn dox() { + /// let item_stream = futures::stream::repeat("one").throttle(Duration::from_secs(2)); + /// tokio::pin!(item_stream); + /// + /// loop { + /// // The string will be produced at most every 2 seconds + /// println!("{:?}", item_stream.next().await); + /// } + /// # } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn throttle(self, duration: Duration) -> Throttle + where + Self: Sized, + { + throttle(duration, self) + } + + /// Batches the items in the given stream using a maximum duration and size for each batch. + /// + /// This stream returns the next batch of items in the following situations: + /// 1. The inner stream has returned at least `max_size` many items since the last batch. + /// 2. The time since the first item of a batch is greater than the given duration. + /// 3. The end of the stream is reached. + /// + /// The length of the returned vector is never empty or greater than the maximum size. Empty batches + /// will not be emitted if no items are received upstream. 
+ /// + /// # Panics + /// + /// This function panics if `max_size` is zero + /// + /// # Example + /// + /// ```ignore + /// use std::time::Duration; + /// use tokio::time; + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// use futures::FutureExt; + /// + /// #[tokio::main] + /// # async fn _unused() {} + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// let iter = vec![1, 2, 3, 4].into_iter(); + /// let stream0 = stream::iter(iter); + /// + /// let iter = vec![5].into_iter(); + /// let stream1 = stream::iter(iter) + /// .then(move |n| time::sleep(Duration::from_secs(5)).map(move |_| n)); + /// + /// let chunk_stream = stream0 + /// .chain(stream1) + /// .chunks_timeout(3, Duration::from_secs(2)); + /// tokio::pin!(chunk_stream); + /// + /// // a full batch was received + /// assert_eq!(chunk_stream.next().await, Some(vec![1,2,3])); + /// // deadline was reached before max_size was reached + /// assert_eq!(chunk_stream.next().await, Some(vec![4])); + /// // last element in the stream + /// assert_eq!(chunk_stream.next().await, Some(vec![5])); + /// } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + #[track_caller] + fn chunks_timeout(self, max_size: usize, duration: Duration) -> ChunksTimeout + where + Self: Sized, + { + assert!(max_size > 0, "`max_size` must be non-zero."); + ChunksTimeout::new(self, max_size, duration) + } + + /// Turns the stream into a peekable stream, whose next element can be peeked at without being + /// consumed. 
+ /// ```ignore + /// use shuttle_tokio_stream_impl::{self as stream, StreamExt}; + /// + /// #[tokio::main] + /// # async fn _unused() {} + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// let iter = vec![1, 2, 3, 4].into_iter(); + /// let mut stream = stream::iter(iter).peekable(); + /// + /// assert_eq!(*stream.peek().await.unwrap(), 1); + /// assert_eq!(*stream.peek().await.unwrap(), 1); + /// assert_eq!(stream.next().await.unwrap(), 1); + /// assert_eq!(*stream.peek().await.unwrap(), 2); + /// } + /// ``` + fn peekable(self) -> Peekable + where + Self: Sized, + { + Peekable::new(self) + } +} + +impl StreamExt for St where St: Stream {} + +/// Merge the size hints from two streams. +fn merge_size_hints( + (left_low, left_high): (usize, Option), + (right_low, right_high): (usize, Option), +) -> (usize, Option) { + let low = left_low.saturating_add(right_low); + let high = match (left_high, right_high) { + (Some(h1), Some(h2)) => h1.checked_add(h2), + _ => None, + }; + (low, high) +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/all.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/all.rs new file mode 100644 index 00000000..b4dbc1e9 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/all.rs @@ -0,0 +1,58 @@ +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`all`](super::StreamExt::all) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct AllFuture<'a, St: ?Sized, F> { + stream: &'a mut St, + f: F, + // Make this future `!Unpin` for compatibility with async trait methods. 
+ #[pin] + _pin: PhantomPinned, + } +} + +impl<'a, St: ?Sized, F> AllFuture<'a, St, F> { + pub(super) fn new(stream: &'a mut St, f: F) -> Self { + Self { + stream, + f, + _pin: PhantomPinned, + } + } +} + +impl Future for AllFuture<'_, St, F> +where + St: ?Sized + Stream + Unpin, + F: FnMut(St::Item) -> bool, +{ + type Output = bool; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + let mut stream = Pin::new(me.stream); + + // Take a maximum of 32 items from the stream before yielding. + for _ in 0..32 { + match futures_core::ready!(stream.as_mut().poll_next(cx)) { + Some(v) => { + if !(me.f)(v) { + return Poll::Ready(false); + } + } + None => return Poll::Ready(true), + } + } + + cx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/any.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/any.rs new file mode 100644 index 00000000..31394f24 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/any.rs @@ -0,0 +1,58 @@ +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`any`](super::StreamExt::any) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct AnyFuture<'a, St: ?Sized, F> { + stream: &'a mut St, + f: F, + // Make this future `!Unpin` for compatibility with async trait methods. 
+ #[pin] + _pin: PhantomPinned, + } +} + +impl<'a, St: ?Sized, F> AnyFuture<'a, St, F> { + pub(super) fn new(stream: &'a mut St, f: F) -> Self { + Self { + stream, + f, + _pin: PhantomPinned, + } + } +} + +impl Future for AnyFuture<'_, St, F> +where + St: ?Sized + Stream + Unpin, + F: FnMut(St::Item) -> bool, +{ + type Output = bool; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + let mut stream = Pin::new(me.stream); + + // Take a maximum of 32 items from the stream before yielding. + for _ in 0..32 { + match futures_core::ready!(stream.as_mut().poll_next(cx)) { + Some(v) => { + if (me.f)(v) { + return Poll::Ready(true); + } + } + None => return Poll::Ready(false), + } + } + + cx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/chain.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/chain.rs new file mode 100644 index 00000000..bd64f33c --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/chain.rs @@ -0,0 +1,50 @@ +use crate::stream_ext::Fuse; +use crate::Stream; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`chain`](super::StreamExt::chain) method. 
+ pub struct Chain { + #[pin] + a: Fuse, + #[pin] + b: U, + } +} + +impl Chain { + pub(super) fn new(a: T, b: U) -> Chain + where + T: Stream, + U: Stream, + { + Chain { a: Fuse::new(a), b } + } +} + +impl Stream for Chain +where + T: Stream, + U: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use Poll::Ready; + + let me = self.project(); + + if let Some(v) = ready!(me.a.poll_next(cx)) { + return Ready(Some(v)); + } + + me.b.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/chunks_timeout.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/chunks_timeout.rs new file mode 100644 index 00000000..48acd932 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/chunks_timeout.rs @@ -0,0 +1,86 @@ +use crate::stream_ext::Fuse; +use crate::Stream; +use tokio::time::{sleep, Sleep}; + +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; +use std::time::Duration; + +pin_project! { + /// Stream returned by the [`chunks_timeout`](super::StreamExt::chunks_timeout) method. 
+ #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct ChunksTimeout { + #[pin] + stream: Fuse, + #[pin] + deadline: Option, + duration: Duration, + items: Vec, + cap: usize, // https://github.com/rust-lang/futures-rs/issues/1475 + } +} + +impl ChunksTimeout { + pub(super) fn new(stream: S, max_size: usize, duration: Duration) -> Self { + ChunksTimeout { + stream: Fuse::new(stream), + deadline: None, + duration, + items: Vec::with_capacity(max_size), + cap: max_size, + } + } +} + +impl Stream for ChunksTimeout { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.as_mut().project(); + loop { + match me.stream.as_mut().poll_next(cx) { + Poll::Pending => break, + Poll::Ready(Some(item)) => { + if me.items.is_empty() { + me.deadline.set(Some(sleep(*me.duration))); + me.items.reserve_exact(*me.cap); + } + me.items.push(item); + if me.items.len() >= *me.cap { + return Poll::Ready(Some(std::mem::take(me.items))); + } + } + Poll::Ready(None) => { + // Returning Some here is only correct because we fuse the inner stream. 
+ let last = if me.items.is_empty() { + None + } else { + Some(std::mem::take(me.items)) + }; + + return Poll::Ready(last); + } + } + } + + if !me.items.is_empty() { + if let Some(deadline) = me.deadline.as_pin_mut() { + ready!(deadline.poll(cx)); + } + return Poll::Ready(Some(std::mem::take(me.items))); + } + + Poll::Pending + } + + fn size_hint(&self) -> (usize, Option) { + let chunk_len = if self.items.is_empty() { 0 } else { 1 }; + let (lower, upper) = self.stream.size_hint(); + let lower = (lower / self.cap).saturating_add(chunk_len); + let upper = upper.and_then(|x| x.checked_add(chunk_len)); + (lower, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/collect.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/collect.rs new file mode 100644 index 00000000..d02e62d9 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/collect.rs @@ -0,0 +1,216 @@ +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +// Do not export this struct until `FromStream` can be unsealed. +pin_project! { + /// Future returned by the [`collect`](super::StreamExt::collect) method. + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[derive(Debug)] + pub struct Collect + where + T: Stream, + U: FromStream, + { + #[pin] + stream: T, + collection: U::InternalCollection, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +/// Convert from a [`Stream`]. +/// +/// This trait is not intended to be used directly. Instead, call +/// [`StreamExt::collect()`](super::StreamExt::collect). +/// +/// # Implementing +/// +/// Currently, this trait may not be implemented by third parties. The trait is +/// sealed in order to make changes in the future. Stabilization is pending +/// enhancements to the Rust language. 
+pub trait FromStream: sealed::FromStreamPriv {} + +impl Collect +where + T: Stream, + U: FromStream, +{ + pub(super) fn new(stream: T) -> Collect { + let (lower, upper) = stream.size_hint(); + let collection = U::initialize(sealed::Internal, lower, upper); + + Collect { + stream, + collection, + _pin: PhantomPinned, + } + } +} + +impl Future for Collect +where + T: Stream, + U: FromStream, +{ + type Output = U; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use Poll::Ready; + + loop { + let me = self.as_mut().project(); + + let item = match ready!(me.stream.poll_next(cx)) { + Some(item) => item, + None => { + return Ready(U::finalize(sealed::Internal, me.collection)); + } + }; + + if !U::extend(sealed::Internal, me.collection, item) { + return Ready(U::finalize(sealed::Internal, me.collection)); + } + } + } +} + +// ===== FromStream implementations + +impl FromStream<()> for () {} + +impl sealed::FromStreamPriv<()> for () { + type InternalCollection = (); + + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option) {} + + fn extend(_: sealed::Internal, _collection: &mut (), _item: ()) -> bool { + true + } + + fn finalize(_: sealed::Internal, _collection: &mut ()) {} +} + +impl> FromStream for String {} + +impl> sealed::FromStreamPriv for String { + type InternalCollection = String; + + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option) -> String { + String::new() + } + + fn extend(_: sealed::Internal, collection: &mut String, item: T) -> bool { + collection.push_str(item.as_ref()); + true + } + + fn finalize(_: sealed::Internal, collection: &mut String) -> String { + mem::take(collection) + } +} + +impl FromStream for Vec {} + +impl sealed::FromStreamPriv for Vec { + type InternalCollection = Vec; + + fn initialize(_: sealed::Internal, lower: usize, _upper: Option) -> Vec { + Vec::with_capacity(lower) + } + + fn extend(_: sealed::Internal, collection: &mut Vec, item: T) -> bool { + collection.push(item); + true + } 
+ + fn finalize(_: sealed::Internal, collection: &mut Vec) -> Vec { + mem::take(collection) + } +} + +impl FromStream for Box<[T]> {} + +impl sealed::FromStreamPriv for Box<[T]> { + type InternalCollection = Vec; + + fn initialize(_: sealed::Internal, lower: usize, upper: Option) -> Vec { + as sealed::FromStreamPriv>::initialize(sealed::Internal, lower, upper) + } + + fn extend(_: sealed::Internal, collection: &mut Vec, item: T) -> bool { + as sealed::FromStreamPriv>::extend(sealed::Internal, collection, item) + } + + fn finalize(_: sealed::Internal, collection: &mut Vec) -> Box<[T]> { + as sealed::FromStreamPriv>::finalize(sealed::Internal, collection).into_boxed_slice() + } +} + +impl FromStream> for Result where U: FromStream {} + +impl sealed::FromStreamPriv> for Result +where + U: FromStream, +{ + type InternalCollection = Result; + + fn initialize(_: sealed::Internal, lower: usize, upper: Option) -> Result { + Ok(U::initialize(sealed::Internal, lower, upper)) + } + + fn extend(_: sealed::Internal, collection: &mut Self::InternalCollection, item: Result) -> bool { + assert!(collection.is_ok()); + match item { + Ok(item) => { + let collection = collection.as_mut().ok().expect("invalid state"); + U::extend(sealed::Internal, collection, item) + } + Err(err) => { + *collection = Err(err); + false + } + } + } + + fn finalize(_: sealed::Internal, collection: &mut Self::InternalCollection) -> Result { + if let Ok(collection) = collection.as_mut() { + Ok(U::finalize(sealed::Internal, collection)) + } else { + let res = mem::replace(collection, Ok(U::initialize(sealed::Internal, 0, Some(0)))); + + Err(res.map(drop).unwrap_err()) + } + } +} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait FromStreamPriv { + /// Intermediate type used during collection process + /// + /// The name of this type is internal and cannot be relied upon. 
+ type InternalCollection; + + /// Initialize the collection + fn initialize(internal: Internal, lower: usize, upper: Option) -> Self::InternalCollection; + + /// Extend the collection with the received item + /// + /// Return `true` to continue streaming, `false` complete collection. + fn extend(internal: Internal, collection: &mut Self::InternalCollection, item: T) -> bool; + + /// Finalize collection into target type. + fn finalize(internal: Internal, collection: &mut Self::InternalCollection) -> Self; + } + + #[allow(missing_debug_implementations)] + pub struct Internal; +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter.rs new file mode 100644 index 00000000..3af4fa40 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter.rs @@ -0,0 +1,56 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`filter`](super::StreamExt::filter) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct Filter { + #[pin] + stream: St, + f: F, + } +} + +impl fmt::Debug for Filter +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Filter").field("stream", &self.stream).finish() + } +} + +impl Filter { + pub(super) fn new(stream: St, f: F) -> Self { + Self { stream, f } + } +} + +impl Stream for Filter +where + St: Stream, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.as_mut().project().stream.poll_next(cx)) { + Some(e) => { + if (self.as_mut().project().f)(&e) { + return Poll::Ready(Some(e)); + } + } + None => return Poll::Ready(None), + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, self.stream.size_hint().1) // can't know a lower bound, due to the predicate + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter_map.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter_map.rs new file mode 100644 index 00000000..7123a40a --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/filter_map.rs @@ -0,0 +1,56 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`filter_map`](super::StreamExt::filter_map) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct FilterMap { + #[pin] + stream: St, + f: F, + } +} + +impl fmt::Debug for FilterMap +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FilterMap").field("stream", &self.stream).finish() + } +} + +impl FilterMap { + pub(super) fn new(stream: St, f: F) -> Self { + Self { stream, f } + } +} + +impl Stream for FilterMap +where + St: Stream, + F: FnMut(St::Item) -> Option, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.as_mut().project().stream.poll_next(cx)) { + Some(e) => { + if let Some(e) = (self.as_mut().project().f)(e) { + return Poll::Ready(Some(e)); + } + } + None => return Poll::Ready(None), + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, self.stream.size_hint().1) // can't know a lower bound, due to the predicate + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/fold.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/fold.rs new file mode 100644 index 00000000..e2e97d8f --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/fold.rs @@ -0,0 +1,57 @@ +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future returned by the [`fold`](super::StreamExt::fold) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct FoldFuture { + #[pin] + stream: St, + acc: Option, + f: F, + // Make this future `!Unpin` for compatibility with async trait methods. 
+ #[pin] + _pin: PhantomPinned, + } +} + +impl FoldFuture { + pub(super) fn new(stream: St, init: B, f: F) -> Self { + Self { + stream, + acc: Some(init), + f, + _pin: PhantomPinned, + } + } +} + +impl Future for FoldFuture +where + St: Stream, + F: FnMut(B, St::Item) -> B, +{ + type Output = B; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut me = self.project(); + loop { + let next = ready!(me.stream.as_mut().poll_next(cx)); + + match next { + Some(v) => { + let old = me.acc.take().unwrap(); + let new = (me.f)(old, v); + *me.acc = Some(new); + } + None => return Poll::Ready(me.acc.take().unwrap()), + } + } + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/fuse.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/fuse.rs new file mode 100644 index 00000000..80240b19 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/fuse.rs @@ -0,0 +1,51 @@ +use crate::Stream; + +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Stream returned by [`fuse()`][super::StreamExt::fuse]. 
+ #[derive(Debug)] + pub struct Fuse { + #[pin] + stream: Option, + } +} + +impl Fuse +where + T: Stream, +{ + pub(crate) fn new(stream: T) -> Fuse { + Fuse { stream: Some(stream) } + } +} + +impl Stream for Fuse +where + T: Stream, +{ + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = match Option::as_pin_mut(self.as_mut().project().stream) { + Some(stream) => ready!(stream.poll_next(cx)), + None => return Poll::Ready(None), + }; + + if res.is_none() { + // Do not poll the stream anymore + self.as_mut().project().stream.set(None); + } + + Poll::Ready(res) + } + + fn size_hint(&self) -> (usize, Option) { + match self.stream { + Some(ref stream) => stream.size_hint(), + None => (0, Some(0)), + } + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/map.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/map.rs new file mode 100644 index 00000000..e6b47cd2 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/map.rs @@ -0,0 +1,51 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`map`](super::StreamExt::map) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct Map { + #[pin] + stream: St, + f: F, + } +} + +impl fmt::Debug for Map +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Map").field("stream", &self.stream).finish() + } +} + +impl Map { + pub(super) fn new(stream: St, f: F) -> Self { + Map { stream, f } + } +} + +impl Stream for Map +where + St: Stream, + F: FnMut(St::Item) -> T, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut() + .project() + .stream + .poll_next(cx) + .map(|opt| opt.map(|x| (self.as_mut().project().f)(x))) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/map_while.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/map_while.rs new file mode 100644 index 00000000..347b80ba --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/map_while.rs @@ -0,0 +1,50 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`map_while`](super::StreamExt::map_while) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct MapWhile { + #[pin] + stream: St, + f: F, + } +} + +impl fmt::Debug for MapWhile +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapWhile").field("stream", &self.stream).finish() + } +} + +impl MapWhile { + pub(super) fn new(stream: St, f: F) -> Self { + MapWhile { stream, f } + } +} + +impl Stream for MapWhile +where + St: Stream, + F: FnMut(St::Item) -> Option, +{ + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + let f = me.f; + me.stream.poll_next(cx).map(|opt| opt.and_then(f)) + } + + fn size_hint(&self) -> (usize, Option) { + let (_, upper) = self.stream.size_hint(); + (0, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/merge.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/merge.rs new file mode 100644 index 00000000..1145b88f --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/merge.rs @@ -0,0 +1,84 @@ +use crate::stream_ext::Fuse; +use crate::Stream; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`merge`](super::StreamExt::merge) method. + pub struct Merge { + #[pin] + a: Fuse, + #[pin] + b: Fuse, + // When `true`, poll `a` first, otherwise, `poll` b`. 
+ a_first: bool, + } +} + +impl Merge { + pub(super) fn new(a: T, b: U) -> Merge + where + T: Stream, + U: Stream, + { + Merge { + a: Fuse::new(a), + b: Fuse::new(b), + a_first: true, + } + } +} + +impl Stream for Merge +where + T: Stream, + U: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + let a_first = *me.a_first; + + // Toggle the flag + *me.a_first = !a_first; + + if a_first { + poll_next(me.a, me.b, cx) + } else { + poll_next(me.b, me.a, cx) + } + } + + fn size_hint(&self) -> (usize, Option) { + super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) + } +} + +fn poll_next(first: Pin<&mut T>, second: Pin<&mut U>, cx: &mut Context<'_>) -> Poll> +where + T: Stream, + U: Stream, +{ + let mut done = true; + + match first.poll_next(cx) { + Poll::Ready(Some(val)) => return Poll::Ready(Some(val)), + Poll::Ready(None) => {} + Poll::Pending => done = false, + } + + match second.poll_next(cx) { + Poll::Ready(Some(val)) => return Poll::Ready(Some(val)), + Poll::Ready(None) => {} + Poll::Pending => done = false, + } + + if done { + Poll::Ready(None) + } else { + Poll::Pending + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/next.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/next.rs new file mode 100644 index 00000000..706069fa --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/next.rs @@ -0,0 +1,44 @@ +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`next`](super::StreamExt::next) method. + /// + /// # Cancel safety + /// + /// This method is cancel safe. It only + /// holds onto a reference to the underlying stream, + /// so dropping it will never lose a value. 
+ /// + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Next<'a, St: ?Sized> { + stream: &'a mut St, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +impl<'a, St: ?Sized> Next<'a, St> { + pub(super) fn new(stream: &'a mut St) -> Self { + Next { + stream, + _pin: PhantomPinned, + } + } +} + +impl Future for Next<'_, St> { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + Pin::new(me.stream).poll_next(cx) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/peekable.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/peekable.rs new file mode 100644 index 00000000..7545ee40 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/peekable.rs @@ -0,0 +1,50 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_core::Stream; +use pin_project_lite::pin_project; + +use crate::stream_ext::Fuse; +use crate::StreamExt; + +pin_project! { + /// Stream returned by the [`chain`](super::StreamExt::peekable) method. + pub struct Peekable { + peek: Option, + #[pin] + stream: Fuse, + } +} + +impl Peekable { + pub(crate) fn new(stream: T) -> Self { + let stream = stream.fuse(); + Self { peek: None, stream } + } + + /// Peek at the next item in the stream. 
+ pub async fn peek(&mut self) -> Option<&T::Item> + where + T: Unpin, + { + if let Some(ref it) = self.peek { + Some(it) + } else { + self.peek = self.next().await; + self.peek.as_ref() + } + } +} + +impl Stream for Peekable { + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if let Some(it) = this.peek.take() { + Poll::Ready(Some(it)) + } else { + this.stream.poll_next(cx) + } + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip.rs new file mode 100644 index 00000000..429a8553 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip.rs @@ -0,0 +1,61 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`skip`](super::StreamExt::skip) method. + #[must_use = "streams do nothing unless polled"] + pub struct Skip { + #[pin] + stream: St, + remaining: usize, + } +} + +impl fmt::Debug for Skip +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Skip").field("stream", &self.stream).finish() + } +} + +impl Skip { + pub(super) fn new(stream: St, remaining: usize) -> Self { + Self { stream, remaining } + } +} + +impl Stream for Skip +where + St: Stream, +{ + type Item = St::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.as_mut().project().stream.poll_next(cx)) { + Some(e) => { + if self.remaining == 0 { + return Poll::Ready(Some(e)); + } + *self.as_mut().project().remaining -= 1; + } + None => return Poll::Ready(None), + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.stream.size_hint(); + + let lower = lower.saturating_sub(self.remaining); + let upper = upper.map(|x| x.saturating_sub(self.remaining)); + + (lower, upper) + } +} diff 
--git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip_while.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip_while.rs new file mode 100644 index 00000000..c4f55f71 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/skip_while.rs @@ -0,0 +1,71 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`skip_while`](super::StreamExt::skip_while) method. + #[must_use = "streams do nothing unless polled"] + pub struct SkipWhile { + #[pin] + stream: St, + predicate: Option, + } +} + +impl fmt::Debug for SkipWhile +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SkipWhile").field("stream", &self.stream).finish() + } +} + +impl SkipWhile { + pub(super) fn new(stream: St, predicate: F) -> Self { + Self { + stream, + predicate: Some(predicate), + } + } +} + +impl Stream for SkipWhile +where + St: Stream, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if let Some(predicate) = this.predicate { + loop { + match ready!(this.stream.as_mut().poll_next(cx)) { + Some(item) => { + if !(predicate)(&item) { + *this.predicate = None; + return Poll::Ready(Some(item)); + } + } + None => return Poll::Ready(None), + } + } + } else { + this.stream.poll_next(cx) + } + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.stream.size_hint(); + + if self.predicate.is_some() { + return (0, upper); + } + + (lower, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/take.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/take.rs new file mode 100644 index 00000000..a90800cb --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/take.rs @@ -0,0 +1,74 @@ +use crate::Stream; + +use core::cmp; +use core::fmt; +use 
core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`take`](super::StreamExt::take) method. + #[must_use = "streams do nothing unless polled"] + pub struct Take { + #[pin] + stream: St, + remaining: usize, + } +} + +impl fmt::Debug for Take +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Take").field("stream", &self.stream).finish() + } +} + +impl Take { + pub(super) fn new(stream: St, remaining: usize) -> Self { + Self { stream, remaining } + } +} + +impl Stream for Take +where + St: Stream, +{ + type Item = St::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if *self.as_mut().project().remaining > 0 { + self.as_mut().project().stream.poll_next(cx).map(|ready| { + match &ready { + Some(_) => { + *self.as_mut().project().remaining -= 1; + } + None => { + *self.as_mut().project().remaining = 0; + } + } + ready + }) + } else { + Poll::Ready(None) + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.remaining == 0 { + return (0, Some(0)); + } + + let (lower, upper) = self.stream.size_hint(); + + let lower = cmp::min(lower, self.remaining); + + let upper = match upper { + Some(x) if x < self.remaining => Some(x), + _ => Some(self.remaining), + }; + + (lower, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/take_while.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/take_while.rs new file mode 100644 index 00000000..23d4110b --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/take_while.rs @@ -0,0 +1,73 @@ +use crate::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`take_while`](super::StreamExt::take_while) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct TakeWhile { + #[pin] + stream: St, + predicate: F, + done: bool, + } +} + +impl fmt::Debug for TakeWhile +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TakeWhile") + .field("stream", &self.stream) + .field("done", &self.done) + .finish() + } +} + +impl TakeWhile { + pub(super) fn new(stream: St, predicate: F) -> Self { + Self { + stream, + predicate, + done: false, + } + } +} + +impl Stream for TakeWhile +where + St: Stream, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if *self.as_mut().project().done { + Poll::Ready(None) + } else { + self.as_mut().project().stream.poll_next(cx).map(|ready| { + let ready = ready.and_then(|item| (self.as_mut().project().predicate)(&item).then_some(item)); + + if ready.is_none() { + *self.as_mut().project().done = true; + } + + ready + }) + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.done { + return (0, Some(0)); + } + + let (_, upper) = self.stream.size_hint(); + + (0, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/then.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/then.rs new file mode 100644 index 00000000..5997ecd1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/then.rs @@ -0,0 +1,81 @@ +use crate::Stream; + +use core::fmt; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`then`](super::StreamExt::then) method. 
+ #[must_use = "streams do nothing unless polled"] + pub struct Then { + #[pin] + stream: St, + #[pin] + future: Option, + f: F, + } +} + +impl fmt::Debug for Then +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Then").field("stream", &self.stream).finish() + } +} + +impl Then { + pub(super) fn new(stream: St, f: F) -> Self { + Then { + stream, + future: None, + f, + } + } +} + +impl Stream for Then +where + St: Stream, + Fut: Future, + F: FnMut(St::Item) -> Fut, +{ + type Item = Fut::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + + loop { + if let Some(future) = me.future.as_mut().as_pin_mut() { + match future.poll(cx) { + Poll::Ready(item) => { + me.future.set(None); + return Poll::Ready(Some(item)); + } + Poll::Pending => return Poll::Pending, + } + } + + match me.stream.as_mut().poll_next(cx) { + Poll::Ready(Some(item)) => { + me.future.set(Some((me.f)(item))); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let future_len = usize::from(self.future.is_some()); + let (lower, upper) = self.stream.size_hint(); + + let lower = lower.saturating_add(future_len); + let upper = upper.and_then(|upper| upper.checked_add(future_len)); + + (lower, upper) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/throttle.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/throttle.rs new file mode 100644 index 00000000..50001392 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/throttle.rs @@ -0,0 +1,96 @@ +//! Slow down a stream by enforcing a delay between items. 
+ +use crate::Stream; +use tokio::time::{Duration, Instant, Sleep}; + +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; + +use pin_project_lite::pin_project; + +pub(super) fn throttle(duration: Duration, stream: T) -> Throttle +where + T: Stream, +{ + Throttle { + delay: tokio::time::sleep_until(Instant::now() + duration), + duration, + has_delayed: true, + stream, + } +} + +pin_project! { + /// Stream for the [`throttle`](throttle) function. This object is `!Unpin`. If you need it to + /// implement `Unpin` you can pin your throttle like this: `Box::pin(your_throttle)`. + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + pub struct Throttle { + #[pin] + delay: Sleep, + duration: Duration, + + // Set to true when `delay` has returned ready, but `stream` hasn't. + has_delayed: bool, + + // The stream to throttle + #[pin] + stream: T, + } +} + +impl Throttle { + /// Acquires a reference to the underlying stream that this combinator is + /// pulling from. + pub fn get_ref(&self) -> &T { + &self.stream + } + + /// Acquires a mutable reference to the underlying stream that this combinator + /// is pulling from. + /// + /// Note that care must be taken to avoid tampering with the state of the stream + /// which may otherwise confuse this combinator. + pub fn get_mut(&mut self) -> &mut T { + &mut self.stream + } + + /// Consumes this combinator, returning the underlying stream. + /// + /// Note that this may discard intermediate state of this combinator, so care + /// should be taken to avoid losing resources when this is called. 
+ pub fn into_inner(self) -> T { + self.stream + } +} + +impl Stream for Throttle { + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let mut me = self.project(); + let dur = *me.duration; + + if !*me.has_delayed && !is_zero(dur) { + ready!(me.delay.as_mut().poll(cx)); + *me.has_delayed = true; + } + + let value = ready!(me.stream.poll_next(cx)); + + if value.is_some() { + if !is_zero(dur) { + me.delay.reset(Instant::now() + dur); + } + + *me.has_delayed = false; + } + + Poll::Ready(value) + } +} + +fn is_zero(dur: Duration) -> bool { + dur == Duration::from_millis(0) +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout.rs new file mode 100644 index 00000000..95fba4e5 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout.rs @@ -0,0 +1,107 @@ +use crate::stream_ext::Fuse; +use crate::Stream; +use tokio::time::{Instant, Sleep}; + +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; +use std::fmt; +use std::time::Duration; + +pin_project! { + /// Stream returned by the [`timeout`](super::StreamExt::timeout) method. + #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct Timeout { + #[pin] + stream: Fuse, + #[pin] + deadline: Sleep, + duration: Duration, + poll_deadline: bool, + } +} + +/// Error returned by `Timeout` and `TimeoutRepeating`. 
+#[derive(Debug, PartialEq, Eq)] +pub struct Elapsed(()); + +impl Timeout { + pub(super) fn new(stream: S, duration: Duration) -> Self { + let next = Instant::now() + duration; + let deadline = tokio::time::sleep_until(next); + + Timeout { + stream: Fuse::new(stream), + deadline, + duration, + poll_deadline: true, + } + } +} + +impl Stream for Timeout { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + + match me.stream.poll_next(cx) { + Poll::Ready(v) => { + if v.is_some() { + let next = Instant::now() + *me.duration; + me.deadline.reset(next); + *me.poll_deadline = true; + } + return Poll::Ready(v.map(Ok)); + } + Poll::Pending => {} + } + + if *me.poll_deadline { + ready!(me.deadline.poll(cx)); + *me.poll_deadline = false; + return Poll::Ready(Some(Err(Elapsed::new()))); + } + + Poll::Pending + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.stream.size_hint(); + + // The timeout stream may insert an error before and after each message + // from the underlying stream, but no more than one error between each + // message. Hence the upper bound is computed as 2x+1. + + // Using a helper function to enable use of question mark operator. 
+ fn twice_plus_one(value: Option) -> Option { + value?.checked_mul(2)?.checked_add(1) + } + + (lower, twice_plus_one(upper)) + } +} + +// ===== impl Elapsed ===== + +impl Elapsed { + pub(crate) fn new() -> Self { + Elapsed(()) + } +} + +impl fmt::Display for Elapsed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "deadline has elapsed".fmt(fmt) + } +} + +impl std::error::Error for Elapsed {} + +impl From for std::io::Error { + fn from(_err: Elapsed) -> std::io::Error { + std::io::ErrorKind::TimedOut.into() + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout_repeating.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout_repeating.rs new file mode 100644 index 00000000..bc42bff9 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/timeout_repeating.rs @@ -0,0 +1,56 @@ +use crate::stream_ext::Fuse; +use crate::{Elapsed, Stream}; +use tokio::time::Interval; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`timeout_repeating`](super::StreamExt::timeout_repeating) method. 
+ #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct TimeoutRepeating { + #[pin] + stream: Fuse, + #[pin] + interval: Interval, + } +} + +impl TimeoutRepeating { + pub(super) fn new(stream: S, interval: Interval) -> Self { + TimeoutRepeating { + stream: Fuse::new(stream), + interval, + } + } +} + +impl Stream for TimeoutRepeating { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + + match me.stream.poll_next(cx) { + Poll::Ready(v) => { + if v.is_some() { + me.interval.reset(); + } + return Poll::Ready(v.map(Ok)); + } + Poll::Pending => {} + } + + ready!(me.interval.poll_tick(cx)); + Poll::Ready(Some(Err(Elapsed::new()))) + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, _) = self.stream.size_hint(); + + // The timeout stream may insert an error an infinite number of times. + (lower, None) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_ext/try_next.rs b/wrappers/tokio/impls/tokio-stream/src/stream_ext/try_next.rs new file mode 100644 index 00000000..93aa3bc1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_ext/try_next.rs @@ -0,0 +1,45 @@ +use crate::stream_ext::Next; +use crate::Stream; + +use core::future::Future; +use core::marker::PhantomPinned; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`try_next`](super::StreamExt::try_next) method. + /// + /// # Cancel safety + /// + /// This method is cancel safe. It only + /// holds onto a reference to the underlying stream, + /// so dropping it will never lose a value. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct TryNext<'a, St: ?Sized> { + #[pin] + inner: Next<'a, St>, + // Make this future `!Unpin` for compatibility with async trait methods. 
+ #[pin] + _pin: PhantomPinned, + } +} + +impl<'a, St: ?Sized> TryNext<'a, St> { + pub(super) fn new(stream: &'a mut St) -> Self { + Self { + inner: Next::new(stream), + _pin: PhantomPinned, + } + } +} + +impl> + Unpin> Future for TryNext<'_, St> { + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + me.inner.poll(cx).map(Option::transpose) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/stream_map.rs b/wrappers/tokio/impls/tokio-stream/src/stream_map.rs new file mode 100644 index 00000000..2f2c2c43 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/stream_map.rs @@ -0,0 +1,314 @@ +// SHUTTLE_CHANGES +//! This file is derived from [tokio-stream/src/stream_map.rs](https://github.com/tokio-rs/tokio/blob/9e94fa7e15cfe6ebbd06e9ebad4642896620d924/tokio-stream/src/stream_map.rs), and has had the following changes applied to it: +//! 1. Examples removed. +//! 2. Custom rand implementation removed. See CHANGED below. +use crate::Stream; + +use shuttle::rand::Rng; +use std::borrow::Borrow; +use std::hash::Hash; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Combine many streams into one, indexing each source stream with a unique +/// key. +/// +/// `StreamMap` is similar to [`StreamExt::merge`] in that it combines source +/// streams into a single merged stream that yields values in the order that +/// they arrive from the source streams. However, `StreamMap` has a lot more +/// flexibility in usage patterns. +/// +/// `StreamMap` can: +/// +/// * Merge an arbitrary number of streams. +/// * Track which source stream the value was received from. +/// * Handle inserting and removing streams from the set of managed streams at +/// any point during iteration. +/// +/// All source streams held by `StreamMap` are indexed using a key. This key is +/// included with the value when a source stream yields a value. 
The key is also +/// used to remove the stream from the `StreamMap` before the stream has +/// completed streaming. +/// +/// # `Unpin` +/// +/// Because the `StreamMap` API moves streams during runtime, both streams and +/// keys must be `Unpin`. In order to insert a `!Unpin` stream into a +/// `StreamMap`, use [`pin!`] to pin the stream to the stack or [`Box::pin`] to +/// pin the stream in the heap. +/// +/// # Implementation +/// +/// `StreamMap` is backed by a `Vec<(K, V)>`. There is no guarantee that this +/// internal implementation detail will persist in future versions, but it is +/// important to know the runtime implications. In general, `StreamMap` works +/// best with a "smallish" number of streams as all entries are scanned on +/// insert, remove, and polling. In cases where a large number of streams need +/// to be merged, it may be advisable to use tasks sending values on a shared +/// [`mpsc`] channel. +/// +/// # Notes +/// +/// `StreamMap` removes finished streams automatically, without alerting the user. +/// In some scenarios, the caller would want to know on closed streams. +/// To do this, use [`StreamNotifyClose`] as a wrapper to your stream. +/// It will return None when the stream is closed. +/// +/// [`StreamExt::merge`]: crate::StreamExt::merge +/// [`mpsc`]: https://docs.rs/tokio/1.0/tokio/sync/mpsc/index.html +/// [`pin!`]: https://docs.rs/tokio/1.0/tokio/macro.pin.html +/// [`Box::pin`]: std::boxed::Box::pin +/// [`StreamNotifyClose`]: crate::StreamNotifyClose + +#[derive(Debug)] +pub struct StreamMap { + /// Streams stored in the map + entries: Vec<(K, V)>, +} + +impl StreamMap { + /// An iterator visiting all key-value pairs in arbitrary order. + /// + /// The iterator element type is `&'a (K, V)`. + pub fn iter(&self) -> impl Iterator { + self.entries.iter() + } + + /// An iterator visiting all key-value pairs mutably in arbitrary order. + /// + /// The iterator element type is `&'a mut (K, V)`. 
+ pub fn iter_mut(&mut self) -> impl Iterator { + self.entries.iter_mut() + } + + /// Creates an empty `StreamMap`. + /// + /// The stream map is initially created with a capacity of `0`, so it will + /// not allocate until it is first inserted into. + pub fn new() -> StreamMap { + StreamMap { entries: vec![] } + } + + /// Creates an empty `StreamMap` with the specified capacity. + /// + /// The stream map will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the stream map will not allocate. + pub fn with_capacity(capacity: usize) -> StreamMap { + StreamMap { + entries: Vec::with_capacity(capacity), + } + } + + /// Returns an iterator visiting all keys in arbitrary order. + /// + /// The iterator element type is `&'a K`. + pub fn keys(&self) -> impl Iterator { + self.iter().map(|(k, _)| k) + } + + /// An iterator visiting all values in arbitrary order. + /// + /// The iterator element type is `&'a V`. + pub fn values(&self) -> impl Iterator { + self.iter().map(|(_, v)| v) + } + + /// An iterator visiting all values mutably in arbitrary order. + /// + /// The iterator element type is `&'a mut V`. + pub fn values_mut(&mut self) -> impl Iterator { + self.iter_mut().map(|(_, v)| v) + } + + /// Returns the number of streams the map can hold without reallocating. + /// + /// This number is a lower bound; the `StreamMap` might be able to hold + /// more, but is guaranteed to be able to hold at least this many. + pub fn capacity(&self) -> usize { + self.entries.capacity() + } + + /// Returns the number of streams in the map. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Returns `true` if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + /// Clears the map, removing all key-stream pairs. Keeps the allocated + /// memory for reuse. + pub fn clear(&mut self) { + self.entries.clear(); + } + + /// Insert a key-stream pair into the map. 
+ /// + /// If the map did not have this key present, `None` is returned. + /// + /// If the map did have this key present, the new `stream` replaces the old + /// one and the old stream is returned. + pub fn insert(&mut self, k: K, stream: V) -> Option + where + K: Hash + Eq, + { + let ret = self.remove(&k); + self.entries.push((k, stream)); + + ret + } + + /// Removes a key from the map, returning the stream at the key if the key was previously in the map. + /// + /// The key may be any borrowed form of the map's key type, but `Hash` and + /// `Eq` on the borrowed form must match those for the key type. + pub fn remove(&mut self, k: &Q) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + for i in 0..self.entries.len() { + if self.entries[i].0.borrow() == k { + return Some(self.entries.swap_remove(i).1); + } + } + + None + } + + /// Returns `true` if the map contains a stream for the specified key. + /// + /// The key may be any borrowed form of the map's key type, but `Hash` and + /// `Eq` on the borrowed form must match those for the key type. 
+ pub fn contains_key(&self, k: &Q) -> bool + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + for i in 0..self.entries.len() { + if self.entries[i].0.borrow() == k { + return true; + } + } + + false + } +} + +impl StreamMap +where + K: Unpin, + V: Stream + Unpin, +{ + /// Polls the next value, includes the vec entry index + fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll> { + // SHUTTLE_CHANGES: Uses Shuttle's controlled `thread_rng` for deterministic replay + let start = if self.entries.is_empty() { + 0 + } else { + shuttle::rand::thread_rng().gen::() % self.entries.len() + }; + let mut idx = start; + + for _ in 0..self.entries.len() { + let (_, stream) = &mut self.entries[idx]; + + match Pin::new(stream).poll_next(cx) { + Poll::Ready(Some(val)) => return Poll::Ready(Some((idx, val))), + Poll::Ready(None) => { + // Remove the entry + self.entries.swap_remove(idx); + + // Check if this was the last entry, if so the cursor needs + // to wrap + if idx == self.entries.len() { + idx = 0; + } else if idx < start && start <= self.entries.len() { + // The stream being swapped into the current index has + // already been polled, so skip it. + idx = idx.wrapping_add(1) % self.entries.len(); + } + } + Poll::Pending => { + idx = idx.wrapping_add(1) % self.entries.len(); + } + } + } + + // If the map is empty, then the stream is complete. 
+ if self.entries.is_empty() { + Poll::Ready(None) + } else { + Poll::Pending + } + } +} + +impl Default for StreamMap { + fn default() -> Self { + Self::new() + } +} + +impl Stream for StreamMap +where + K: Clone + Unpin, + V: Stream + Unpin, +{ + type Item = (K, V::Item); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some((idx, val)) = ready!(self.poll_next_entry(cx)) { + let key = self.entries[idx].0.clone(); + Poll::Ready(Some((key, val))) + } else { + Poll::Ready(None) + } + } + + fn size_hint(&self) -> (usize, Option) { + let mut ret = (0, Some(0)); + + for (_, stream) in &self.entries { + let hint = stream.size_hint(); + + ret.0 += hint.0; + + match (ret.1, hint.1) { + (Some(a), Some(b)) => ret.1 = Some(a + b), + (Some(_), None) => ret.1 = None, + _ => {} + } + } + + ret + } +} + +impl FromIterator<(K, V)> for StreamMap +where + K: Hash + Eq, +{ + fn from_iter>(iter: T) -> Self { + let iterator = iter.into_iter(); + let (lower_bound, _) = iterator.size_hint(); + let mut stream_map = Self::with_capacity(lower_bound); + + for (key, value) in iterator { + stream_map.insert(key, value); + } + + stream_map + } +} + +impl Extend<(K, V)> for StreamMap { + fn extend(&mut self, iter: T) + where + T: IntoIterator, + { + self.entries.extend(iter); + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers.rs new file mode 100644 index 00000000..b2bdbc7a --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers.rs @@ -0,0 +1,80 @@ +//! Wrappers for Tokio types that implement `Stream`. + +// TODO: Implement `broadcast` and uncomment +/* +/// Error types for the wrappers. +pub mod errors { + cfg_sync! { + pub use crate::wrappers::broadcast::BroadcastStreamRecvError; + } +} +*/ + +mod mpsc_bounded; +pub use mpsc_bounded::ReceiverStream; + +mod mpsc_unbounded; +pub use mpsc_unbounded::UnboundedReceiverStream; + +cfg_sync! 
{ + // TODO: Implement `broadcast` and uncomment + /* + mod broadcast; + pub use broadcast::BroadcastStream; + */ + + mod watch; + pub use watch::WatchStream; +} + +// TODO: Implement `signal` and uncomment +/* +cfg_signal! { + #[cfg(unix)] + mod signal_unix; + #[cfg(unix)] + pub use signal_unix::SignalStream; + + #[cfg(any(windows, docsrs))] + mod signal_windows; + #[cfg(any(windows, docsrs))] + pub use signal_windows::{CtrlCStream, CtrlBreakStream}; +} +*/ + +cfg_time! { + mod interval; + pub use interval::IntervalStream; +} + +// TODO: Implement `net` and uncomment +/* +cfg_net! { + mod tcp_listener; + pub use tcp_listener::TcpListenerStream; + + #[cfg(unix)] + mod unix_listener; + #[cfg(unix)] + pub use unix_listener::UnixListenerStream; +} +*/ + +// TODO: Implement `io` and uncomment +/* +cfg_io_util! { + mod split; + pub use split::SplitStream; + + mod lines; + pub use lines::LinesStream; +} +*/ + +// TODO: Implement `fs` and uncomment +/* +cfg_fs! { + mod read_dir; + pub use read_dir::ReadDirStream; +} +*/ diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/broadcast.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/broadcast.rs new file mode 100644 index 00000000..71106646 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/broadcast.rs @@ -0,0 +1,79 @@ +use std::pin::Pin; +use tokio::sync::broadcast::error::RecvError; +use tokio::sync::broadcast::Receiver; + +use futures_core::Stream; +use tokio_util::sync::ReusableBoxFuture; + +use std::fmt; +use std::task::{Context, Poll}; + +/// A wrapper around [`tokio::sync::broadcast::Receiver`] that implements [`Stream`]. +/// +/// [`tokio::sync::broadcast::Receiver`]: struct@tokio::sync::broadcast::Receiver +/// [`Stream`]: trait@crate::Stream +#[cfg_attr(docsrs, doc(cfg(feature = "sync")))] +pub struct BroadcastStream { + inner: ReusableBoxFuture<'static, (Result, Receiver)>, +} + +/// An error returned from the inner stream of a [`BroadcastStream`]. 
+#[derive(Debug, PartialEq, Eq, Clone)] +pub enum BroadcastStreamRecvError { + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), +} + +impl fmt::Display for BroadcastStreamRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BroadcastStreamRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } +} + +impl std::error::Error for BroadcastStreamRecvError {} + +async fn make_future(mut rx: Receiver) -> (Result, Receiver) { + let result = rx.recv().await; + (result, rx) +} + +impl BroadcastStream { + /// Create a new `BroadcastStream`. + pub fn new(rx: Receiver) -> Self { + Self { + inner: ReusableBoxFuture::new(make_future(rx)), + } + } +} + +impl Stream for BroadcastStream { + type Item = Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (result, rx) = ready!(self.inner.poll(cx)); + self.inner.set(make_future(rx)); + match result { + Ok(item) => Poll::Ready(Some(Ok(item))), + Err(RecvError::Closed) => Poll::Ready(None), + Err(RecvError::Lagged(n)) => { + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) + } + } + } +} + +impl fmt::Debug for BroadcastStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BroadcastStream").finish() + } +} + +impl From> for BroadcastStream { + fn from(recv: Receiver) -> Self { + Self::new(recv) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/interval.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/interval.rs new file mode 100644 index 00000000..c7a0b1f1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/interval.rs @@ -0,0 +1,50 @@ +use crate::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::time::{Instant, Interval}; + +/// A wrapper around [`Interval`] that implements [`Stream`]. 
+/// +/// [`Interval`]: struct@tokio::time::Interval +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "time")))] +pub struct IntervalStream { + inner: Interval, +} + +impl IntervalStream { + /// Create a new `IntervalStream`. + pub fn new(interval: Interval) -> Self { + Self { inner: interval } + } + + /// Get back the inner `Interval`. + pub fn into_inner(self) -> Interval { + self.inner + } +} + +impl Stream for IntervalStream { + type Item = Instant; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_tick(cx).map(Some) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, None) + } +} + +impl AsRef for IntervalStream { + fn as_ref(&self) -> &Interval { + &self.inner + } +} + +impl AsMut for IntervalStream { + fn as_mut(&mut self) -> &mut Interval { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/lines.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/lines.rs new file mode 100644 index 00000000..4850429a --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/lines.rs @@ -0,0 +1,59 @@ +use crate::Stream; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, Lines}; + +pin_project! { + /// A wrapper around [`tokio::io::Lines`] that implements [`Stream`]. + /// + /// [`tokio::io::Lines`]: struct@tokio::io::Lines + /// [`Stream`]: trait@crate::Stream + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct LinesStream { + #[pin] + inner: Lines, + } +} + +impl LinesStream { + /// Create a new `LinesStream`. + pub fn new(lines: Lines) -> Self { + Self { inner: lines } + } + + /// Get back the inner `Lines`. + pub fn into_inner(self) -> Lines { + self.inner + } + + /// Obtain a pinned reference to the inner `Lines`. 
+ pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Lines> { + self.project().inner + } +} + +impl Stream for LinesStream { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .inner + .poll_next_line(cx) + .map(Result::transpose) + } +} + +impl AsRef> for LinesStream { + fn as_ref(&self) -> &Lines { + &self.inner + } +} + +impl AsMut> for LinesStream { + fn as_mut(&mut self) -> &mut Lines { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_bounded.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_bounded.rs new file mode 100644 index 00000000..18d799e9 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_bounded.rs @@ -0,0 +1,65 @@ +use crate::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::Receiver; + +/// A wrapper around [`tokio::sync::mpsc::Receiver`] that implements [`Stream`]. +/// +/// [`tokio::sync::mpsc::Receiver`]: struct@tokio::sync::mpsc::Receiver +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +pub struct ReceiverStream { + inner: Receiver, +} + +impl ReceiverStream { + /// Create a new `ReceiverStream`. + pub fn new(recv: Receiver) -> Self { + Self { inner: recv } + } + + /// Get back the inner `Receiver`. + pub fn into_inner(self) -> Receiver { + self.inner + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// To guarantee no messages are dropped, after calling `close()`, you must + /// receive all items from the stream until `None` is returned. 
+ /// + /// [`Permit`]: struct@tokio::sync::mpsc::Permit + pub fn close(&mut self) { + self.inner.close(); + } +} + +impl Stream for ReceiverStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +impl AsRef> for ReceiverStream { + fn as_ref(&self) -> &Receiver { + &self.inner + } +} + +impl AsMut> for ReceiverStream { + fn as_mut(&mut self) -> &mut Receiver { + &mut self.inner + } +} + +impl From> for ReceiverStream { + fn from(recv: Receiver) -> Self { + Self::new(recv) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_unbounded.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_unbounded.rs new file mode 100644 index 00000000..6945b087 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/mpsc_unbounded.rs @@ -0,0 +1,59 @@ +use crate::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::UnboundedReceiver; + +/// A wrapper around [`tokio::sync::mpsc::UnboundedReceiver`] that implements [`Stream`]. +/// +/// [`tokio::sync::mpsc::UnboundedReceiver`]: struct@tokio::sync::mpsc::UnboundedReceiver +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +pub struct UnboundedReceiverStream { + inner: UnboundedReceiver, +} + +impl UnboundedReceiverStream { + /// Create a new `UnboundedReceiverStream`. + pub fn new(recv: UnboundedReceiver) -> Self { + Self { inner: recv } + } + + /// Get back the inner `UnboundedReceiver`. + pub fn into_inner(self) -> UnboundedReceiver { + self.inner + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. 
+ pub fn close(&mut self) { + self.inner.close(); + } +} + +impl Stream for UnboundedReceiverStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +impl AsRef> for UnboundedReceiverStream { + fn as_ref(&self) -> &UnboundedReceiver { + &self.inner + } +} + +impl AsMut> for UnboundedReceiverStream { + fn as_mut(&mut self) -> &mut UnboundedReceiver { + &mut self.inner + } +} + +impl From> for UnboundedReceiverStream { + fn from(recv: UnboundedReceiver) -> Self { + Self::new(recv) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/read_dir.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/read_dir.rs new file mode 100644 index 00000000..b5cf54f7 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/read_dir.rs @@ -0,0 +1,47 @@ +use crate::Stream; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::fs::{DirEntry, ReadDir}; + +/// A wrapper around [`tokio::fs::ReadDir`] that implements [`Stream`]. +/// +/// [`tokio::fs::ReadDir`]: struct@tokio::fs::ReadDir +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "fs")))] +pub struct ReadDirStream { + inner: ReadDir, +} + +impl ReadDirStream { + /// Create a new `ReadDirStream`. + pub fn new(read_dir: ReadDir) -> Self { + Self { inner: read_dir } + } + + /// Get back the inner `ReadDir`. 
+ pub fn into_inner(self) -> ReadDir { + self.inner + } +} + +impl Stream for ReadDirStream { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_next_entry(cx).map(Result::transpose) + } +} + +impl AsRef for ReadDirStream { + fn as_ref(&self) -> &ReadDir { + &self.inner + } +} + +impl AsMut for ReadDirStream { + fn as_mut(&mut self) -> &mut ReadDir { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_unix.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_unix.rs new file mode 100644 index 00000000..2f74e7d1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_unix.rs @@ -0,0 +1,46 @@ +use crate::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::signal::unix::Signal; + +/// A wrapper around [`Signal`] that implements [`Stream`]. +/// +/// [`Signal`]: struct@tokio::signal::unix::Signal +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "signal"))))] +pub struct SignalStream { + inner: Signal, +} + +impl SignalStream { + /// Create a new `SignalStream`. + pub fn new(interval: Signal) -> Self { + Self { inner: interval } + } + + /// Get back the inner `Signal`. 
+ pub fn into_inner(self) -> Signal { + self.inner + } +} + +impl Stream for SignalStream { + type Item = (); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +impl AsRef for SignalStream { + fn as_ref(&self) -> &Signal { + &self.inner + } +} + +impl AsMut for SignalStream { + fn as_mut(&mut self) -> &mut Signal { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_windows.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_windows.rs new file mode 100644 index 00000000..4631fbad --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/signal_windows.rs @@ -0,0 +1,88 @@ +use crate::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::signal::windows::{CtrlBreak, CtrlC}; + +/// A wrapper around [`CtrlC`] that implements [`Stream`]. +/// +/// [`CtrlC`]: struct@tokio::signal::windows::CtrlC +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(all(windows, feature = "signal"))))] +pub struct CtrlCStream { + inner: CtrlC, +} + +impl CtrlCStream { + /// Create a new `CtrlCStream`. + pub fn new(interval: CtrlC) -> Self { + Self { inner: interval } + } + + /// Get back the inner `CtrlC`. + pub fn into_inner(self) -> CtrlC { + self.inner + } +} + +impl Stream for CtrlCStream { + type Item = (); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +impl AsRef for CtrlCStream { + fn as_ref(&self) -> &CtrlC { + &self.inner + } +} + +impl AsMut for CtrlCStream { + fn as_mut(&mut self) -> &mut CtrlC { + &mut self.inner + } +} + +/// A wrapper around [`CtrlBreak`] that implements [`Stream`]. 
+/// +/// [`CtrlBreak`]: struct@tokio::signal::windows::CtrlBreak +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(all(windows, feature = "signal"))))] +pub struct CtrlBreakStream { + inner: CtrlBreak, +} + +impl CtrlBreakStream { + /// Create a new `CtrlBreakStream`. + pub fn new(interval: CtrlBreak) -> Self { + Self { inner: interval } + } + + /// Get back the inner `CtrlBreak`. + pub fn into_inner(self) -> CtrlBreak { + self.inner + } +} + +impl Stream for CtrlBreakStream { + type Item = (); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +impl AsRef for CtrlBreakStream { + fn as_ref(&self) -> &CtrlBreak { + &self.inner + } +} + +impl AsMut for CtrlBreakStream { + fn as_mut(&mut self) -> &mut CtrlBreak { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/split.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/split.rs new file mode 100644 index 00000000..ac46a8ba --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/split.rs @@ -0,0 +1,59 @@ +use crate::Stream; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, Split}; + +pin_project! { + /// A wrapper around [`tokio::io::Split`] that implements [`Stream`]. + /// + /// [`tokio::io::Split`]: struct@tokio::io::Split + /// [`Stream`]: trait@crate::Stream + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct SplitStream { + #[pin] + inner: Split, + } +} + +impl SplitStream { + /// Create a new `SplitStream`. + pub fn new(split: Split) -> Self { + Self { inner: split } + } + + /// Get back the inner `Split`. + pub fn into_inner(self) -> Split { + self.inner + } + + /// Obtain a pinned reference to the inner `Split`. 
+ pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Split> { + self.project().inner + } +} + +impl Stream for SplitStream { + type Item = io::Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .inner + .poll_next_segment(cx) + .map(Result::transpose) + } +} + +impl AsRef> for SplitStream { + fn as_ref(&self) -> &Split { + &self.inner + } +} + +impl AsMut> for SplitStream { + fn as_mut(&mut self) -> &mut Split { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/tcp_listener.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/tcp_listener.rs new file mode 100644 index 00000000..ce7cb163 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/tcp_listener.rs @@ -0,0 +1,54 @@ +use crate::Stream; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{TcpListener, TcpStream}; + +/// A wrapper around [`TcpListener`] that implements [`Stream`]. +/// +/// [`TcpListener`]: struct@tokio::net::TcpListener +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "net")))] +pub struct TcpListenerStream { + inner: TcpListener, +} + +impl TcpListenerStream { + /// Create a new `TcpListenerStream`. + pub fn new(listener: TcpListener) -> Self { + Self { inner: listener } + } + + /// Get back the inner `TcpListener`. 
+ pub fn into_inner(self) -> TcpListener { + self.inner + } +} + +impl Stream for TcpListenerStream { + type Item = io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.inner.poll_accept(cx) { + Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsRef for TcpListenerStream { + fn as_ref(&self) -> &TcpListener { + &self.inner + } +} + +impl AsMut for TcpListenerStream { + fn as_mut(&mut self) -> &mut TcpListener { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/unix_listener.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/unix_listener.rs new file mode 100644 index 00000000..0beba588 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/unix_listener.rs @@ -0,0 +1,54 @@ +use crate::Stream; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{UnixListener, UnixStream}; + +/// A wrapper around [`UnixListener`] that implements [`Stream`]. +/// +/// [`UnixListener`]: struct@tokio::net::UnixListener +/// [`Stream`]: trait@crate::Stream +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "net"))))] +pub struct UnixListenerStream { + inner: UnixListener, +} + +impl UnixListenerStream { + /// Create a new `UnixListenerStream`. + pub fn new(listener: UnixListener) -> Self { + Self { inner: listener } + } + + /// Get back the inner `UnixListener`. 
+ pub fn into_inner(self) -> UnixListener { + self.inner + } +} + +impl Stream for UnixListenerStream { + type Item = io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.inner.poll_accept(cx) { + Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsRef for UnixListenerStream { + fn as_ref(&self) -> &UnixListener { + &self.inner + } +} + +impl AsMut for UnixListenerStream { + fn as_mut(&mut self) -> &mut UnixListener { + &mut self.inner + } +} diff --git a/wrappers/tokio/impls/tokio-stream/src/wrappers/watch.rs b/wrappers/tokio/impls/tokio-stream/src/wrappers/watch.rs new file mode 100644 index 00000000..05efb64d --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/src/wrappers/watch.rs @@ -0,0 +1,75 @@ +use std::pin::Pin; +use tokio::sync::watch::Receiver; + +use futures_core::Stream; +use tokio_util::sync::ReusableBoxFuture; + +use std::fmt; +use std::task::{Context, Poll}; +use tokio::sync::watch::error::RecvError; + +/// A wrapper around [`tokio::sync::watch::Receiver`] that implements [`Stream`]. +/// +/// This stream will start by yielding the current value when the `WatchStream` is polled, +/// regardless of whether it was the initial value or sent afterwards, +/// unless you use [`WatchStream::from_changes`]. +/// +/// [`tokio::sync::watch::Receiver`]: struct@tokio::sync::watch::Receiver +/// [`Stream`]: trait@crate::Stream +#[cfg_attr(docsrs, doc(cfg(feature = "sync")))] +pub struct WatchStream { + inner: ReusableBoxFuture<'static, (Result<(), RecvError>, Receiver)>, +} + +async fn make_future( + mut rx: Receiver, +) -> (Result<(), RecvError>, Receiver) { + let result = rx.changed().await; + (result, rx) +} + +impl WatchStream { + /// Create a new `WatchStream`. 
+ pub fn new(rx: Receiver) -> Self { + Self { + inner: ReusableBoxFuture::new(async move { (Ok(()), rx) }), + } + } + + /// Create a new `WatchStream` that waits for the value to be changed. + pub fn from_changes(rx: Receiver) -> Self { + Self { + inner: ReusableBoxFuture::new(make_future(rx)), + } + } +} + +impl Stream for WatchStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (result, mut rx) = ready!(self.inner.poll(cx)); + if let Ok(()) = result { + let received = (*rx.borrow_and_update()).clone(); + self.inner.set(make_future(rx)); + Poll::Ready(Some(received)) + } else { + self.inner.set(make_future(rx)); + Poll::Ready(None) + } + } +} + +impl Unpin for WatchStream {} + +impl fmt::Debug for WatchStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WatchStream").finish() + } +} + +impl From> for WatchStream { + fn from(recv: Receiver) -> Self { + Self::new(recv) + } +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/async_send_sync.rs b/wrappers/tokio/impls/tokio-stream/tests/async_send_sync.rs new file mode 100644 index 00000000..128f9fc1 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/async_send_sync.rs @@ -0,0 +1,110 @@ +#![allow(clippy::diverging_sub_expression)] + +use std::rc::Rc; + +#[allow(dead_code)] +type BoxStream = std::pin::Pin>>; + +#[allow(dead_code)] +fn require_send(_t: &T) {} +#[allow(dead_code)] +fn require_sync(_t: &T) {} +#[allow(dead_code)] +fn require_unpin(_t: &T) {} + +#[allow(dead_code)] +struct Invalid; + +#[allow(unused)] +trait AmbiguousIfSend { + fn some_item(&self) {} +} +impl AmbiguousIfSend<()> for T {} +impl AmbiguousIfSend for T {} + +#[allow(unused)] +trait AmbiguousIfSync { + fn some_item(&self) {} +} +impl AmbiguousIfSync<()> for T {} +impl AmbiguousIfSync for T {} + +#[allow(unused)] +trait AmbiguousIfUnpin { + fn some_item(&self) {} +} +impl AmbiguousIfUnpin<()> for T {} +impl AmbiguousIfUnpin for T {} + 
+macro_rules! into_todo { + ($typ:ty) => {{ + let x: $typ = todo!(); + x + }}; +} + +macro_rules! async_assert_fn { + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); + require_send(&f); + require_sync(&f); + }; + }; + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & !Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); + require_send(&f); + AmbiguousIfSync::some_item(&f); + }; + }; + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); + AmbiguousIfSend::some_item(&f); + require_sync(&f); + }; + }; + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & !Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); + AmbiguousIfSend::some_item(&f); + AmbiguousIfSync::some_item(&f); + }; + }; + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Unpin) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); + AmbiguousIfUnpin::some_item(&f); + }; + }; + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Unpin) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f = $($f $(::<$($generic),*>)? 
)::+( $( into_todo!($arg) ),* ); + require_unpin(&f); + }; + }; +} + +async_assert_fn!(shuttle_tokio_stream_impl::empty>(): Send & Sync); +async_assert_fn!(shuttle_tokio_stream_impl::pending>(): Send & Sync); +async_assert_fn!(shuttle_tokio_stream_impl::iter(std::vec::IntoIter): Send & Sync); + +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::next(&mut BoxStream<()>): !Unpin); +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::try_next(&mut BoxStream>): !Unpin); +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::all(&mut BoxStream<()>, fn(())->bool): !Unpin); +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::any(&mut BoxStream<()>, fn(())->bool): !Unpin); +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::fold(&mut BoxStream<()>, (), fn((), ())->()): !Unpin); +async_assert_fn!(shuttle_tokio_stream_impl::StreamExt::collect>(&mut BoxStream<()>): !Unpin); diff --git a/wrappers/tokio/impls/tokio-stream/tests/chunks_timeout.rs b/wrappers/tokio/impls/tokio-stream/tests/chunks_timeout.rs new file mode 100644 index 00000000..a9bf6df6 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/chunks_timeout.rs @@ -0,0 +1,82 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "time", feature = "sync", feature = "io-util"))] + +use shuttle_tokio_stream_impl::{self as stream, StreamExt}; +use tokio::time; +use tokio_test::assert_pending; +use tokio_test::task; + +use futures::FutureExt; +use std::time::Duration; + +// TODO: The tests here are `ignore`d because they rely on pausing and manually advancing time, +// and thus fail as time is currently unsupported. However, they are good tests to have if +// we ever add time support, which is why they are not removed entirely. 
+ +#[ignore] +#[tokio::test(start_paused = true)] +async fn usage() { + let iter = vec![1, 2, 3].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![4].into_iter(); + let stream1 = stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(3)).map(move |()| n)); + + let chunk_stream = stream0.chain(stream1).chunks_timeout(4, Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![4])); +} + +#[ignore] +#[tokio::test(start_paused = true)] +async fn full_chunk_with_timeout() { + let iter = vec![1, 2].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![3].into_iter(); + let stream1 = stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(1)).map(move |()| n)); + + let iter = vec![4].into_iter(); + let stream2 = stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(3)).map(move |()| n)); + + let chunk_stream = stream0 + .chain(stream1) + .chain(stream2) + .chunks_timeout(3, Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![4])); +} + +#[tokio::test] +#[ignore] +async fn real_time() { + let iter = vec![1, 2, 3, 4].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![5].into_iter(); + let stream1 = stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(5)).map(move |()| n)); + + let chunk_stream = stream0.chain(stream1).chunks_timeout(3, 
Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + assert_eq!(chunk_stream.next().await, Some(vec![4])); + assert_eq!(chunk_stream.next().await, Some(vec![5])); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_chain.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_chain.rs new file mode 100644 index 00000000..465907cf --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_chain.rs @@ -0,0 +1,97 @@ +use shuttle_tokio_stream_impl::{self as stream, Stream, StreamExt}; +use tokio_test::{assert_pending, assert_ready, task}; + +mod support { + pub(crate) mod mpsc; +} + +use support::mpsc; + +#[tokio::test] +async fn basic_usage() { + let one = stream::iter(vec![1, 2, 3]); + let two = stream::iter(vec![4, 5, 6]); + + let mut stream = one.chain(two); + + assert_eq!(stream.size_hint(), (6, Some(6))); + assert_eq!(stream.next().await, Some(1)); + + assert_eq!(stream.size_hint(), (5, Some(5))); + assert_eq!(stream.next().await, Some(2)); + + assert_eq!(stream.size_hint(), (4, Some(4))); + assert_eq!(stream.next().await, Some(3)); + + assert_eq!(stream.size_hint(), (3, Some(3))); + assert_eq!(stream.next().await, Some(4)); + + assert_eq!(stream.size_hint(), (2, Some(2))); + assert_eq!(stream.next().await, Some(5)); + + assert_eq!(stream.size_hint(), (1, Some(1))); + assert_eq!(stream.next().await, Some(6)); + + assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(stream.next().await, None); + + assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(stream.next().await, None); +} + +#[tokio::test] +async fn pending_first() { + let (tx1, rx1) = mpsc::unbounded_channel_stream(); + let (tx2, rx2) = mpsc::unbounded_channel_stream(); + + let mut stream = task::spawn(rx1.chain(rx2)); + assert_eq!(stream.size_hint(), (0, None)); + + assert_pending!(stream.poll_next()); + + tx2.send(2).unwrap(); + assert!(!stream.is_woken()); + + 
assert_pending!(stream.poll_next()); + + tx1.send(1).unwrap(); + assert!(stream.is_woken()); + assert_eq!(Some(1), assert_ready!(stream.poll_next())); + + assert_pending!(stream.poll_next()); + + drop(tx1); + + assert_eq!(stream.size_hint(), (0, None)); + + assert!(stream.is_woken()); + assert_eq!(Some(2), assert_ready!(stream.poll_next())); + + assert_eq!(stream.size_hint(), (0, None)); + + drop(tx2); + + assert_eq!(stream.size_hint(), (0, None)); + assert_eq!(None, assert_ready!(stream.poll_next())); +} + +#[test] +fn size_overflow() { + struct Monster; + + impl shuttle_tokio_stream_impl::Stream for Monster { + type Item = (); + fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + panic!() + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, Some(usize::MAX)) + } + } + + let m1 = Monster; + let m2 = Monster; + let m = m1.chain(m2); + assert_eq!(m.size_hint(), (usize::MAX, None)); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_close.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_close.rs new file mode 100644 index 00000000..c09a5410 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_close.rs @@ -0,0 +1,11 @@ +use shuttle_tokio_stream_impl::{StreamExt, StreamNotifyClose}; + +#[tokio::test] +async fn basic_usage() { + let mut stream = StreamNotifyClose::new(shuttle_tokio_stream_impl::iter(vec![0, 1])); + + assert_eq!(stream.next().await, Some(Some(0))); + assert_eq!(stream.next().await, Some(Some(1))); + assert_eq!(stream.next().await, Some(None)); + assert_eq!(stream.next().await, None); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_collect.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_collect.rs new file mode 100644 index 00000000..f7c7182e --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_collect.rs @@ -0,0 +1,146 @@ +use shuttle_tokio_stream_impl::{self as stream, StreamExt}; +use tokio_test::{assert_pending, assert_ready, 
assert_ready_err, assert_ready_ok, task}; + +mod support { + pub(crate) mod mpsc; +} + +use support::mpsc; + +#[allow(clippy::let_unit_value)] +#[tokio::test] +async fn empty_unit() { + // Drains the stream. + let mut iter = vec![(), (), ()].into_iter(); + let _: () = stream::iter(&mut iter).collect().await; + assert!(iter.next().is_none()); +} + +#[tokio::test] +async fn empty_vec() { + let coll: Vec = stream::empty().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_box_slice() { + let coll: Box<[u32]> = stream::empty().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_string() { + let coll: String = stream::empty::<&str>().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_result() { + let coll: Result, &str> = stream::empty().collect().await; + assert_eq!(Ok(vec![]), coll); +} + +#[tokio::test] +async fn collect_vec_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(vec![1, 2], coll); +} + +#[tokio::test] +async fn collect_string_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let mut fut = task::spawn(rx.collect::()); + + assert_pending!(fut.poll()); + + tx.send("hello ".to_string()).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send("world".to_string()).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_str_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let mut fut = 
task::spawn(rx.collect::()); + + assert_pending!(fut.poll()); + + tx.send("hello ").unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send("world").unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_results_ok() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(Ok("hello ")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(Ok("world")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready_ok!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_results_err() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(Ok("hello ")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(Err("oh no")).unwrap(); + assert!(fut.is_woken()); + let err = assert_ready_err!(fut.poll()); + assert_eq!("oh no", err); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_empty.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_empty.rs new file mode 100644 index 00000000..981801de --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_empty.rs @@ -0,0 +1,11 @@ +use shuttle_tokio_stream_impl::{self as stream, Stream, StreamExt}; + +#[tokio::test] +async fn basic_usage() { + let mut stream = stream::empty::(); + + for _ in 0..2 { + assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(None, stream.next().await); + } +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_fuse.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_fuse.rs new file mode 100644 index 00000000..16ebffd7 --- /dev/null +++ 
b/wrappers/tokio/impls/tokio-stream/tests/stream_fuse.rs @@ -0,0 +1,50 @@ +use shuttle_tokio_stream_impl::{Stream, StreamExt}; + +use std::pin::Pin; +use std::task::{Context, Poll}; + +// a stream which alternates between Some and None +struct Alternate { + state: i32, +} + +impl Stream for Alternate { + type Item = i32; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let val = self.state; + self.state += 1; + + // if it's even, Some(i32), else None + if val % 2 == 0 { + Poll::Ready(Some(val)) + } else { + Poll::Ready(None) + } + } +} + +#[tokio::test] +async fn basic_usage() { + let mut stream = Alternate { state: 0 }; + + // the stream goes back and forth + assert_eq!(stream.next().await, Some(0)); + assert_eq!(stream.next().await, None); + assert_eq!(stream.next().await, Some(2)); + assert_eq!(stream.next().await, None); + + // however, once it is fused + let mut stream = stream.fuse(); + + assert_eq!(stream.size_hint(), (0, None)); + assert_eq!(stream.next().await, Some(4)); + + assert_eq!(stream.size_hint(), (0, None)); + assert_eq!(stream.next().await, None); + + // it will always return `None` after the first time. 
+ assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(stream.next().await, None); + assert_eq!(stream.size_hint(), (0, Some(0))); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_iter.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_iter.rs new file mode 100644 index 00000000..c728e898 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_iter.rs @@ -0,0 +1,18 @@ +use shuttle_tokio_stream_impl as stream; +use tokio_test::task; + +use std::iter; + +#[tokio::test] +async fn coop() { + let mut stream = task::spawn(stream::iter(iter::repeat(1))); + + for _ in 0..10_000 { + if stream.poll_next().is_pending() { + assert!(stream.is_woken()); + return; + } + } + + panic!("did not yield"); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_merge.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_merge.rs new file mode 100644 index 00000000..fa966740 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_merge.rs @@ -0,0 +1,80 @@ +use shuttle_tokio_stream_impl::{self as stream, Stream, StreamExt}; +use tokio_test::task; +use tokio_test::{assert_pending, assert_ready}; + +mod support { + pub(crate) mod mpsc; +} + +use support::mpsc; + +#[tokio::test] +async fn merge_sync_streams() { + let mut s = stream::iter(vec![0, 2, 4, 6]).merge(stream::iter(vec![1, 3, 5])); + + for i in 0..7 { + let rem = 7 - i; + assert_eq!(s.size_hint(), (rem, Some(rem))); + assert_eq!(Some(i), s.next().await); + } + + assert!(s.next().await.is_none()); +} + +#[tokio::test] +async fn merge_async_streams() { + let (tx1, rx1) = mpsc::unbounded_channel_stream(); + let (tx2, rx2) = mpsc::unbounded_channel_stream(); + + let mut rx = task::spawn(rx1.merge(rx2)); + + assert_eq!(rx.size_hint(), (0, None)); + + assert_pending!(rx.poll_next()); + + tx1.send(1).unwrap(); + + assert!(rx.is_woken()); + assert_eq!(Some(1), assert_ready!(rx.poll_next())); + + assert_pending!(rx.poll_next()); + tx2.send(2).unwrap(); + + assert!(rx.is_woken()); + 
assert_eq!(Some(2), assert_ready!(rx.poll_next())); + assert_pending!(rx.poll_next()); + + drop(tx1); + assert!(rx.is_woken()); + assert_pending!(rx.poll_next()); + + tx2.send(3).unwrap(); + assert!(rx.is_woken()); + assert_eq!(Some(3), assert_ready!(rx.poll_next())); + assert_pending!(rx.poll_next()); + + drop(tx2); + assert!(rx.is_woken()); + assert_eq!(None, assert_ready!(rx.poll_next())); +} + +#[test] +fn size_overflow() { + struct Monster; + + impl shuttle_tokio_stream_impl::Stream for Monster { + type Item = (); + fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + panic!() + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, Some(usize::MAX)) + } + } + + let m1 = Monster; + let m2 = Monster; + let m = m1.merge(m2); + assert_eq!(m.size_hint(), (usize::MAX, None)); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_once.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_once.rs new file mode 100644 index 00000000..64cd1961 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_once.rs @@ -0,0 +1,12 @@ +use shuttle_tokio_stream_impl::{self as stream, Stream, StreamExt}; + +#[tokio::test] +async fn basic_usage() { + let mut one = stream::once(1); + + assert_eq!(one.size_hint(), (1, Some(1))); + assert_eq!(Some(1), one.next().await); + + assert_eq!(one.size_hint(), (0, Some(0))); + assert_eq!(None, one.next().await); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_panic.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_panic.rs new file mode 100644 index 00000000..dfec4ff2 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_panic.rs @@ -0,0 +1,54 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "time", not(target_os = "wasi")))] // Wasi does not support panic recovery +#![cfg(panic = "unwind")] + +use shuttle::sync::Arc; +use shuttle::sync::Mutex; +use shuttle_tokio_stream_impl::{self as stream, StreamExt}; +use std::panic; +use 
tokio::time::Duration; + +fn test_panic(func: Func) -> Option { + let panic_mutex = Mutex::new(()); + + { + let _guard = panic_mutex.lock(); + let panic_file: Arc>> = Arc::new(Mutex::new(None)); + + let prev_hook = panic::take_hook(); + { + let panic_file = panic_file.clone(); + panic::set_hook(Box::new(move |panic_info| { + let panic_location = panic_info.location().unwrap(); + panic_file + .lock() + .unwrap() + .clone_from(&Some(panic_location.file().to_string())); + })); + } + + let result = panic::catch_unwind(func); + // Return to the previously set panic hook (maybe default) so that we get nice error + // messages in the tests. + panic::set_hook(prev_hook); + + if result.is_err() { + panic_file.lock().unwrap().clone() + } else { + None + } + } +} + +#[tokio::test] +async fn stream_chunks_timeout_panic_caller() { + let panic_location_file = test_panic(|| { + let iter = vec![1, 2, 3].into_iter(); + let stream0 = stream::iter(iter); + + let _chunk_stream = stream0.chunks_timeout(0, Duration::from_secs(2)); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_pending.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_pending.rs new file mode 100644 index 00000000..2ec4f90b --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_pending.rs @@ -0,0 +1,14 @@ +use shuttle_tokio_stream_impl::{self as stream, Stream, StreamExt}; +use tokio_test::{assert_pending, task}; + +#[tokio::test] +async fn basic_usage() { + let mut stream = stream::pending::(); + + for _ in 0..2 { + assert_eq!(stream.size_hint(), (0, None)); + + let mut next = task::spawn(async { stream.next().await }); + assert_pending!(next.poll()); + } +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_stream_map.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_stream_map.rs new file mode 100644 index 00000000..eda9585f --- /dev/null +++ 
b/wrappers/tokio/impls/tokio-stream/tests/stream_stream_map.rs @@ -0,0 +1,322 @@ +use shuttle_tokio_stream_impl::{self as stream, pending, Stream, StreamExt, StreamMap}; +use tokio_test::{assert_ok, assert_pending, assert_ready, task}; + +mod support { + pub(crate) mod mpsc; +} + +use support::mpsc; + +use std::pin::Pin; + +macro_rules! assert_ready_some { + ($($t:tt)*) => { + match assert_ready!($($t)*) { + Some(v) => v, + None => panic!("expected `Some`, got `None`"), + } + }; +} + +macro_rules! assert_ready_none { + ($($t:tt)*) => { + match assert_ready!($($t)*) { + None => {} + Some(v) => panic!("expected `None`, got `Some({:?})`", v), + } + }; +} + +#[tokio::test] +async fn empty() { + let mut map = StreamMap::<&str, stream::Pending<()>>::new(); + + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + + assert!(map.next().await.is_none()); + assert!(map.next().await.is_none()); + + assert!(map.remove("foo").is_none()); +} + +#[tokio::test] +async fn single_entry() { + let mut map = task::spawn(StreamMap::new()); + let (tx, rx) = mpsc::unbounded_channel_stream(); + let rx = Box::pin(rx); + + assert_ready_none!(map.poll_next()); + + assert!(map.insert("foo", rx).is_none()); + assert!(map.contains_key("foo")); + assert!(!map.contains_key("bar")); + + assert_eq!(map.len(), 1); + assert!(!map.is_empty()); + + assert_pending!(map.poll_next()); + + assert_ok!(tx.send(1)); + + assert!(map.is_woken()); + let (k, v) = assert_ready_some!(map.poll_next()); + assert_eq!(k, "foo"); + assert_eq!(v, 1); + + assert_pending!(map.poll_next()); + + assert_ok!(tx.send(2)); + + assert!(map.is_woken()); + let (k, v) = assert_ready_some!(map.poll_next()); + assert_eq!(k, "foo"); + assert_eq!(v, 2); + + assert_pending!(map.poll_next()); + drop(tx); + assert!(map.is_woken()); + assert_ready_none!(map.poll_next()); +} + +#[tokio::test] +async fn multiple_entries() { + let mut map = task::spawn(StreamMap::new()); + let (tx1, rx1) = mpsc::unbounded_channel_stream(); + let (tx2, rx2) = 
mpsc::unbounded_channel_stream(); + + let rx1 = Box::pin(rx1); + let rx2 = Box::pin(rx2); + + map.insert("foo", rx1); + map.insert("bar", rx2); + + assert_pending!(map.poll_next()); + + assert_ok!(tx1.send(1)); + + assert!(map.is_woken()); + let (k, v) = assert_ready_some!(map.poll_next()); + assert_eq!(k, "foo"); + assert_eq!(v, 1); + + assert_pending!(map.poll_next()); + + assert_ok!(tx2.send(2)); + + assert!(map.is_woken()); + let (k, v) = assert_ready_some!(map.poll_next()); + assert_eq!(k, "bar"); + assert_eq!(v, 2); + + assert_pending!(map.poll_next()); + + assert_ok!(tx1.send(3)); + assert_ok!(tx2.send(4)); + + assert!(map.is_woken()); + + // Given the randomization, there is no guarantee what order the values will + // be received in. + let mut v = (0..2).map(|_| assert_ready_some!(map.poll_next())).collect::>(); + + assert_pending!(map.poll_next()); + + v.sort_unstable(); + assert_eq!(v[0].0, "bar"); + assert_eq!(v[0].1, 4); + assert_eq!(v[1].0, "foo"); + assert_eq!(v[1].1, 3); + + drop(tx1); + assert!(map.is_woken()); + assert_pending!(map.poll_next()); + drop(tx2); + + assert_ready_none!(map.poll_next()); +} + +#[tokio::test] +async fn insert_remove() { + let mut map = task::spawn(StreamMap::new()); + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let rx = Box::pin(rx); + + assert_ready_none!(map.poll_next()); + + assert!(map.insert("foo", rx).is_none()); + let rx = map.remove("foo").unwrap(); + + assert_ok!(tx.send(1)); + + assert!(!map.is_woken()); + assert_ready_none!(map.poll_next()); + + assert!(map.insert("bar", rx).is_none()); + + let v = assert_ready_some!(map.poll_next()); + assert_eq!(v.0, "bar"); + assert_eq!(v.1, 1); + + assert!(map.remove("bar").is_some()); + assert_ready_none!(map.poll_next()); + + assert!(map.is_empty()); + assert_eq!(0, map.len()); +} + +#[tokio::test] +async fn replace() { + let mut map = task::spawn(StreamMap::new()); + let (tx1, rx1) = mpsc::unbounded_channel_stream(); + let (tx2, rx2) = 
mpsc::unbounded_channel_stream(); + + let rx1 = Box::pin(rx1); + let rx2 = Box::pin(rx2); + + assert!(map.insert("foo", rx1).is_none()); + + assert_pending!(map.poll_next()); + + let _rx1 = map.insert("foo", rx2).unwrap(); + + assert_pending!(map.poll_next()); + + tx1.send(1).unwrap(); + assert_pending!(map.poll_next()); + + tx2.send(2).unwrap(); + assert!(map.is_woken()); + let v = assert_ready_some!(map.poll_next()); + assert_eq!(v.0, "foo"); + assert_eq!(v.1, 2); +} + +#[test] +fn size_hint_with_upper() { + let mut map = StreamMap::new(); + + map.insert("a", stream::iter(vec![1])); + map.insert("b", stream::iter(vec![1, 2])); + map.insert("c", stream::iter(vec![1, 2, 3])); + + assert_eq!(3, map.len()); + assert!(!map.is_empty()); + + let size_hint = map.size_hint(); + assert_eq!(size_hint, (6, Some(6))); +} + +#[test] +fn size_hint_without_upper() { + let mut map = StreamMap::new(); + + map.insert("a", pin_box(stream::iter(vec![1]))); + map.insert("b", pin_box(stream::iter(vec![1, 2]))); + map.insert("c", pin_box(pending())); + + let size_hint = map.size_hint(); + assert_eq!(size_hint, (3, None)); +} + +#[test] +fn new_capacity_zero() { + let map = StreamMap::<&str, stream::Pending<()>>::new(); + assert_eq!(0, map.capacity()); + + assert!(map.keys().next().is_none()); +} + +#[test] +fn with_capacity() { + let map = StreamMap::<&str, stream::Pending<()>>::with_capacity(10); + assert!(10 <= map.capacity()); + + assert!(map.keys().next().is_none()); +} + +#[test] +fn iter_keys() { + let mut map = StreamMap::new(); + + map.insert("a", pending::()); + map.insert("b", pending()); + map.insert("c", pending()); + + let mut keys = map.keys().collect::>(); + keys.sort_unstable(); + + assert_eq!(&keys[..], &[&"a", &"b", &"c"]); +} + +#[test] +fn iter_values() { + let mut map = StreamMap::new(); + + map.insert("a", stream::iter(vec![1])); + map.insert("b", stream::iter(vec![1, 2])); + map.insert("c", stream::iter(vec![1, 2, 3])); + + let mut size_hints = 
map.values().map(|s| s.size_hint().0).collect::>(); + + size_hints.sort_unstable(); + + assert_eq!(&size_hints[..], &[1, 2, 3]); +} + +#[test] +fn iter_values_mut() { + let mut map = StreamMap::new(); + + map.insert("a", stream::iter(vec![1])); + map.insert("b", stream::iter(vec![1, 2])); + map.insert("c", stream::iter(vec![1, 2, 3])); + + let mut size_hints = map.values_mut().map(|s: &mut _| s.size_hint().0).collect::>(); + + size_hints.sort_unstable(); + + assert_eq!(&size_hints[..], &[1, 2, 3]); +} + +#[tokio::test] +async fn clear() { + let mut map = task::spawn(StreamMap::new()); + + map.insert("a", stream::iter(vec![1])); + map.insert("b", stream::iter(vec![1, 2])); + map.insert("c", stream::iter(vec![1, 2, 3])); + + assert_ready_some!(map.poll_next()); + + map.clear(); + + assert_ready_none!(map.poll_next()); + assert!(map.is_empty()); +} + +#[test] +fn contains_key_borrow() { + let mut map = StreamMap::new(); + map.insert("foo".to_string(), pending::<()>()); + + assert!(map.contains_key("foo")); +} + +#[tokio::test] +async fn one_ready_many_none() { + let mut map = task::spawn(StreamMap::new()); + + map.insert(0, pin_box(stream::empty())); + map.insert(1, pin_box(stream::empty())); + map.insert(2, pin_box(stream::once("hello"))); + map.insert(3, pin_box(stream::pending())); + + let v = assert_ready_some!(map.poll_next()); + assert_eq!(v, (2, "hello")); +} + +fn pin_box + 'static, U>(s: T) -> Pin>> { + Box::pin(s) +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/stream_timeout.rs b/wrappers/tokio/impls/tokio-stream/tests/stream_timeout.rs new file mode 100644 index 00000000..202fa154 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/stream_timeout.rs @@ -0,0 +1,114 @@ +#![cfg(all(feature = "time", feature = "sync", feature = "io-util"))] + +use shuttle_tokio_stream_impl::{self, StreamExt}; +use tokio::time::{self, sleep, Duration}; +use tokio_test::*; + +use futures::stream; + +async fn maybe_sleep(idx: i32) -> i32 { + if idx % 2 == 0 { + 
sleep(ms(200)).await; + } + idx +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} + +// TODO: The tests here are `ignore`d because they rely on pausing and manually advancing time, +// and thus fail as time is currently unsupported. However, they are good tests to have if +// we ever add time support, which is why they are not removed entirely. + +#[ignore] +#[tokio::test] +async fn basic_usage() { + time::pause(); + + // Items 2 and 4 time out. If we run the stream until it completes, + // we end up with the following items: + // + // [Ok(1), Err(Elapsed), Ok(2), Ok(3), Err(Elapsed), Ok(4)] + + let stream = stream::iter(1..=4).then(maybe_sleep).timeout(ms(100)); + let mut stream = task::spawn(stream); + + // First item completes immediately + assert_ready_eq!(stream.poll_next(), Some(Ok(1))); + + // Second item is delayed 200ms, times out after 100ms + assert_pending!(stream.poll_next()); + + time::advance(ms(150)).await; + let v = assert_ready!(stream.poll_next()); + assert!(v.unwrap().is_err()); + + assert_pending!(stream.poll_next()); + + time::advance(ms(100)).await; + assert_ready_eq!(stream.poll_next(), Some(Ok(2))); + + // Third item is ready immediately + assert_ready_eq!(stream.poll_next(), Some(Ok(3))); + + // Fourth item is delayed 200ms, times out after 100ms + assert_pending!(stream.poll_next()); + + time::advance(ms(60)).await; + assert_pending!(stream.poll_next()); // nothing ready yet + + time::advance(ms(60)).await; + let v = assert_ready!(stream.poll_next()); + assert!(v.unwrap().is_err()); // timeout! + + time::advance(ms(120)).await; + assert_ready_eq!(stream.poll_next(), Some(Ok(4))); + + // Done. 
+ assert_ready_eq!(stream.poll_next(), None); +} + +#[ignore] +#[tokio::test] +async fn return_elapsed_errors_only_once() { + time::pause(); + + let stream = stream::iter(1..=3).then(maybe_sleep).timeout(ms(50)); + let mut stream = task::spawn(stream); + + // First item completes immediately + assert_ready_eq!(stream.poll_next(), Some(Ok(1))); + + // Second item is delayed 200ms, times out after 50ms. Only one `Elapsed` + // error is returned. + assert_pending!(stream.poll_next()); + // + time::advance(ms(51)).await; + let v = assert_ready!(stream.poll_next()); + assert!(v.unwrap().is_err()); // timeout! + + // deadline elapses again, but no error is returned + time::advance(ms(50)).await; + assert_pending!(stream.poll_next()); + + time::advance(ms(100)).await; + assert_ready_eq!(stream.poll_next(), Some(Ok(2))); + assert_ready_eq!(stream.poll_next(), Some(Ok(3))); + + // Done + assert_ready_eq!(stream.poll_next(), None); +} + +#[ignore] +#[tokio::test] +async fn no_timeouts() { + let stream = stream::iter(vec![1, 3, 5]).then(maybe_sleep).timeout(ms(100)); + + let mut stream = task::spawn(stream); + + assert_ready_eq!(stream.poll_next(), Some(Ok(1))); + assert_ready_eq!(stream.poll_next(), Some(Ok(3))); + assert_ready_eq!(stream.poll_next(), Some(Ok(5))); + assert_ready_eq!(stream.poll_next(), None); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/support/mpsc.rs b/wrappers/tokio/impls/tokio-stream/tests/support/mpsc.rs new file mode 100644 index 00000000..5d66bb79 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/support/mpsc.rs @@ -0,0 +1,15 @@ +use async_stream::stream; +use shuttle_tokio_stream_impl::Stream; +use tokio::sync::mpsc::{self, UnboundedSender}; + +pub fn unbounded_channel_stream() -> (UnboundedSender, impl Stream) { + let (tx, mut rx) = mpsc::unbounded_channel(); + + let stream = stream! 
{ + while let Some(item) = rx.recv().await { + yield item; + } + }; + + (tx, stream) +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/time_throttle.rs b/wrappers/tokio/impls/tokio-stream/tests/time_throttle.rs new file mode 100644 index 00000000..b89290ce --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/time_throttle.rs @@ -0,0 +1,33 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "time", feature = "sync", feature = "io-util"))] + +use shuttle_tokio_stream_impl::StreamExt; +use tokio::time; +use tokio_test::*; + +use std::time::Duration; + +// TODO: This test is `ignore`d because it relies on pausing and manually advancing time, +// and thus fail as time is currently unsupported. However, it is good tests to have if +// we ever add time support, which is why it is not removed entirely. + +#[ignore] +#[tokio::test] +async fn usage() { + time::pause(); + + let mut stream = task::spawn(futures::stream::repeat(()).throttle(Duration::from_millis(100))); + + assert_ready!(stream.poll_next()); + assert_pending!(stream.poll_next()); + + time::advance(Duration::from_millis(90)).await; + + assert_pending!(stream.poll_next()); + + time::advance(Duration::from_millis(101)).await; + + assert!(stream.is_woken()); + + assert_ready!(stream.poll_next()); +} diff --git a/wrappers/tokio/impls/tokio-stream/tests/watch.rs b/wrappers/tokio/impls/tokio-stream/tests/watch.rs new file mode 100644 index 00000000..bdce3065 --- /dev/null +++ b/wrappers/tokio/impls/tokio-stream/tests/watch.rs @@ -0,0 +1,55 @@ +#![cfg(feature = "sync")] + +use shuttle_tokio_stream_impl::wrappers::WatchStream; +use shuttle_tokio_stream_impl::StreamExt; +use tokio::sync::watch; +use tokio_test::assert_pending; +use tokio_test::task::spawn; + +#[tokio::test] +async fn watch_stream_message_not_twice() { + let (tx, rx) = watch::channel("hello"); + + let mut counter = 0; + let mut stream = WatchStream::new(rx).map(move |payload| { + println!("{payload}"); + if payload == "goodbye" { + counter 
+= 1; + } + assert!((counter < 2), "too many goodbyes"); + }); + + let task = tokio::spawn(async move { while stream.next().await.is_some() {} }); + + // Send goodbye just once + tx.send("goodbye").unwrap(); + + drop(tx); + task.await.unwrap(); +} + +#[tokio::test] +async fn watch_stream_from_rx() { + let (tx, rx) = watch::channel("hello"); + + let mut stream = WatchStream::from(rx); + + assert_eq!(stream.next().await.unwrap(), "hello"); + + tx.send("bye").unwrap(); + + assert_eq!(stream.next().await.unwrap(), "bye"); +} + +#[tokio::test] +async fn watch_stream_from_changes() { + let (tx, rx) = watch::channel("hello"); + + let mut stream = WatchStream::from_changes(rx); + + assert_pending!(spawn(&mut stream).poll_next()); + + tx.send("bye").unwrap(); + + assert_eq!(stream.next().await.unwrap(), "bye"); +} diff --git a/wrappers/tokio/impls/tokio-test/Cargo.toml b/wrappers/tokio/impls/tokio-test/Cargo.toml new file mode 100644 index 00000000..845b18a4 --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "shuttle-tokio-test-impl" +version = "0.1.0" # Forked from: "0.4.3" +edition = "2021" +rust-version = "1.63" + +# Tokio publishes this, but we currently don't. 
Package exists for tests in the other tokio-related packages +publish = false + + +[dependencies] +tokio_orig = { package = "tokio", version = "*", features = ["rt", "sync", "time", "test-util"] } +tokio = { package = "shuttle-tokio-impl", version = "*", path = "../tokio", features = ["rt", "sync", "time", "test-util"] } +shuttle = { path = "../../../../shuttle", version = "*" } +tokio-stream = { version = "0.1.0", path = "../tokio-stream", package = "shuttle-tokio-stream-impl" } +async-stream = "0.3.3" + +bytes = "1.0.0" +futures-core = "0.3.0" + +[dev-dependencies] +#tokio = { version = "1.2.0", path = "../tokio", features = ["full"] } +futures-util = "0.3.0" diff --git a/wrappers/tokio/impls/tokio-test/src/io.rs b/wrappers/tokio/impls/tokio-test/src/io.rs new file mode 100644 index 00000000..edf578bf --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/src/io.rs @@ -0,0 +1,500 @@ +//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`]. +//! +//! +//! # Overview +//! +//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] that can be configured +//! to handle an arbitrary sequence of read and write operations. This is useful +//! for writing unit tests for networking services as using an actual network +//! type is fairly non deterministic. +//! +//! # Usage +//! +//! Attempting to write data that the mock isn't expecting will result in a +//! panic. +//! +//! [`AsyncRead`]: tokio::io::AsyncRead +//! 
[`AsyncWrite`]: tokio::io::AsyncWrite + +// Get from tokio_orig as they are not currently exported by shuttle-tokio +use tokio_orig::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use tokio::sync::mpsc; +use tokio::time::{self, Duration, Instant, Sleep}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +use futures_core::{ready, Stream}; +use shuttle::sync::Arc; +use std::collections::VecDeque; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll, Waker}; +use std::{cmp, io}; + +/// An I/O object that follows a predefined script. +/// +/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It +/// follows the scenario described by the builder and panics otherwise. +#[derive(Debug)] +pub struct Mock { + inner: Inner, +} + +/// A handle to send additional actions to the related `Mock`. +#[derive(Debug)] +pub struct Handle { + tx: mpsc::UnboundedSender, +} + +/// Builds `Mock` instances. +#[derive(Debug, Clone, Default)] +pub struct Builder { + // Sequence of actions for the Mock to take + actions: VecDeque, +} + +#[derive(Debug, Clone)] +enum Action { + Read(Vec), + Write(Vec), + Wait(Duration), + // Wrapped in Arc so that Builder can be cloned and Send. + // Mock is not cloned as does not need to check Rc for ref counts. + ReadError(Option>), + WriteError(Option>), +} + +struct Inner { + actions: VecDeque, + waiting: Option, + sleep: Option>>, + read_wait: Option, + rx: UnboundedReceiverStream, +} + +impl Builder { + /// Return a new, empty `Builder`. + pub fn new() -> Self { + Self::default() + } + + /// Sequence a `read` operation. + /// + /// The next operation in the mock's script will be to expect a `read` call + /// and return `buf`. + pub fn read(&mut self, buf: &[u8]) -> &mut Self { + self.actions.push_back(Action::Read(buf.into())); + self + } + + /// Sequence a `read` operation that produces an error. 
+ /// + /// The next operation in the mock's script will be to expect a `read` call + /// and return `error`. + pub fn read_error(&mut self, error: io::Error) -> &mut Self { + let error = Some(error.into()); + self.actions.push_back(Action::ReadError(error)); + self + } + + /// Sequence a `write` operation. + /// + /// The next operation in the mock's script will be to expect a `write` + /// call. + pub fn write(&mut self, buf: &[u8]) -> &mut Self { + self.actions.push_back(Action::Write(buf.into())); + self + } + + /// Sequence a `write` operation that produces an error. + /// + /// The next operation in the mock's script will be to expect a `write` + /// call that provides `error`. + pub fn write_error(&mut self, error: io::Error) -> &mut Self { + let error = Some(error.into()); + self.actions.push_back(Action::WriteError(error)); + self + } + + /// Sequence a wait. + /// + /// The next operation in the mock's script will be to wait without doing so + /// for `duration` amount of time. + pub fn wait(&mut self, duration: Duration) -> &mut Self { + let duration = cmp::max(duration, Duration::from_millis(1)); + self.actions.push_back(Action::Wait(duration)); + self + } + + /// Build a `Mock` value according to the defined script. + pub fn build(&mut self) -> Mock { + let (mock, _) = self.build_with_handle(); + mock + } + + /// Build a `Mock` value paired with a handle + pub fn build_with_handle(&mut self) -> (Mock, Handle) { + let (inner, handle) = Inner::new(self.actions.clone()); + + let mock = Mock { inner }; + + (mock, handle) + } +} + +impl Handle { + /// Sequence a `read` operation. + /// + /// The next operation in the mock's script will be to expect a `read` call + /// and return `buf`. + pub fn read(&mut self, buf: &[u8]) -> &mut Self { + self.tx.send(Action::Read(buf.into())).unwrap(); + self + } + + /// Sequence a `read` operation error. + /// + /// The next operation in the mock's script will be to expect a `read` call + /// and return `error`. 
+ pub fn read_error(&mut self, error: io::Error) -> &mut Self { + let error = Some(error.into()); + self.tx.send(Action::ReadError(error)).unwrap(); + self + } + + /// Sequence a `write` operation. + /// + /// The next operation in the mock's script will be to expect a `write` + /// call. + pub fn write(&mut self, buf: &[u8]) -> &mut Self { + self.tx.send(Action::Write(buf.into())).unwrap(); + self + } + + /// Sequence a `write` operation error. + /// + /// The next operation in the mock's script will be to expect a `write` + /// call error. + pub fn write_error(&mut self, error: io::Error) -> &mut Self { + let error = Some(error.into()); + self.tx.send(Action::WriteError(error)).unwrap(); + self + } +} + +impl Inner { + fn new(actions: VecDeque) -> (Inner, Handle) { + let (tx, rx) = mpsc::unbounded_channel(); + + let rx = UnboundedReceiverStream::new(rx); + + let inner = Inner { + actions, + sleep: None, + read_wait: None, + rx, + waiting: None, + }; + + let handle = Handle { tx }; + + (inner, handle) + } + + fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.rx).poll_next(cx) + } + + fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> { + match self.action() { + Some(&mut Action::Read(ref mut data)) => { + // Figure out how much to copy + let n = cmp::min(dst.remaining(), data.len()); + + // Copy the data into the `dst` slice + dst.put_slice(&data[..n]); + + // Drain the data from the source + data.drain(..n); + + Ok(()) + } + Some(&mut Action::ReadError(ref mut err)) => { + // As the + let err = err.take().expect("Should have been removed from actions."); + let err = Arc::try_unwrap(err).expect("There are no other references."); + Err(err) + } + Some(_) => { + // Either waiting or expecting a write + Err(io::ErrorKind::WouldBlock.into()) + } + None => Ok(()), + } + } + + fn write(&mut self, mut src: &[u8]) -> io::Result { + let mut ret = 0; + + if self.actions.is_empty() { + return Err(io::ErrorKind::BrokenPipe.into()); 
+ } + + if let Some(&mut Action::Wait(..)) = self.action() { + return Err(io::ErrorKind::WouldBlock.into()); + } + + if let Some(&mut Action::WriteError(ref mut err)) = self.action() { + let err = err.take().expect("Should have been removed from actions."); + let err = Arc::try_unwrap(err).expect("There are no other references."); + return Err(err); + } + + for i in 0..self.actions.len() { + match self.actions[i] { + Action::Write(ref mut expect) => { + let n = cmp::min(src.len(), expect.len()); + + assert_eq!(&src[..n], &expect[..n]); + + // Drop data that was matched + expect.drain(..n); + src = &src[n..]; + + ret += n; + + if src.is_empty() { + return Ok(ret); + } + } + Action::Wait(..) | Action::WriteError(..) => { + break; + } + _ => {} + } + + // TODO: remove write + } + + Ok(ret) + } + + fn remaining_wait(&mut self) -> Option { + match self.action() { + Some(&mut Action::Wait(dur)) => Some(dur), + _ => None, + } + } + + fn action(&mut self) -> Option<&mut Action> { + loop { + if self.actions.is_empty() { + return None; + } + + match self.actions[0] { + Action::Read(ref mut data) => { + if !data.is_empty() { + break; + } + } + Action::Write(ref mut data) => { + if !data.is_empty() { + break; + } + } + Action::Wait(ref mut dur) => { + if let Some(until) = self.waiting { + let now = Instant::now(); + + if now < until { + break; + } + self.waiting = None; + } else { + self.waiting = Some(Instant::now() + *dur); + break; + } + } + Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => { + if error.is_some() { + break; + } + } + } + + let _action = self.actions.pop_front(); + } + + self.actions.front_mut() + } +} + +// ===== impl Inner ===== + +impl Mock { + fn maybe_wakeup_reader(&mut self) { + match self.inner.action() { + Some(&mut (Action::Read(_) | Action::ReadError(_))) | None => { + if let Some(waker) = self.inner.read_wait.take() { + waker.wake(); + } + } + _ => {} + } + } +} + +impl AsyncRead for Mock { + fn poll_read(mut self: Pin<&mut 
Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + loop { + if let Some(ref mut sleep) = self.inner.sleep { + ready!(Pin::new(sleep).poll(cx)); + } + + // If a sleep is set, it has already fired + self.inner.sleep = None; + + // Capture 'filled' to monitor if it changed + let filled = buf.filled().len(); + + match self.inner.read(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if let Some(rem) = self.inner.remaining_wait() { + let until = Instant::now() + rem; + self.inner.sleep = Some(Box::pin(time::sleep_until(until))); + } else { + self.inner.read_wait = Some(cx.waker().clone()); + return Poll::Pending; + } + } + Ok(()) => { + if buf.filled().len() == filled { + match ready!(self.inner.poll_action(cx)) { + Some(action) => { + self.inner.actions.push_back(action); + continue; + } + None => { + return Poll::Ready(Ok(())); + } + } + } + return Poll::Ready(Ok(())); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } +} + +impl AsyncWrite for Mock { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + loop { + if let Some(ref mut sleep) = self.inner.sleep { + ready!(Pin::new(sleep).poll(cx)); + } + + // If a sleep is set, it has already fired + self.inner.sleep = None; + + if self.inner.actions.is_empty() { + match self.inner.poll_action(cx) { + Poll::Pending => { + // do not propagate pending + } + Poll::Ready(Some(action)) => { + self.inner.actions.push_back(action); + } + Poll::Ready(None) => { + panic!("unexpected write"); + } + } + } + + match self.inner.write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if let Some(rem) = self.inner.remaining_wait() { + let until = Instant::now() + rem; + self.inner.sleep = Some(Box::pin(time::sleep_until(until))); + } else { + panic!("unexpected WouldBlock"); + } + } + Ok(0) => { + // TODO: Is this correct? 
+ if !self.inner.actions.is_empty() { + return Poll::Pending; + } + + // TODO: Extract + match ready!(self.inner.poll_action(cx)) { + Some(action) => { + self.inner.actions.push_back(action); + continue; + } + None => { + panic!("unexpected write"); + } + } + } + ret => { + self.maybe_wakeup_reader(); + return Poll::Ready(ret); + } + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// Ensures that Mock isn't dropped with data "inside". +impl Drop for Mock { + fn drop(&mut self) { + // Avoid double panicking, since makes debugging much harder. + if std::thread::panicking() { + return; + } + + self.inner.actions.iter().for_each(|a| match a { + Action::Read(data) => assert!(data.is_empty(), "There is still data left to read."), + Action::Write(data) => assert!(data.is_empty(), "There is still data left to write."), + _ => (), + }); + } +} +/* +/// Returns `true` if called from the context of a futures-rs Task +fn is_task_ctx() -> bool { + use std::panic; + + // Save the existing panic hook + let h = panic::take_hook(); + + // Install a new one that does nothing + panic::set_hook(Box::new(|_| {})); + + // Attempt to call the fn + let r = panic::catch_unwind(|| task::current()).is_ok(); + + // Re-install the old one + panic::set_hook(h); + + // Return the result + r +} +*/ + +impl fmt::Debug for Inner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Inner {{...}}") + } +} diff --git a/wrappers/tokio/impls/tokio-test/src/lib.rs b/wrappers/tokio/impls/tokio-test/src/lib.rs new file mode 100644 index 00000000..2208f0e0 --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/src/lib.rs @@ -0,0 +1,31 @@ +#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), 
allow(dead_code, unused_variables)) +))] + +//! A fork of [https://docs.rs/crate/tokio-test](https://docs.rs/crate/tokio-test) +//! This package is not intended to be depended on directly, and is not currently published. +//! It is a fork of the tokio-util package with imports swapped so that it builds atop Shuttle. +//! It exists for the other packages in the shuttle-tokio-"family" to use. + +pub mod io; +pub mod stream_mock; + +mod macros; +pub mod task; + +/// Runs the provided future, blocking the current thread until the +/// future completes. +/// +/// For more information, see the documentation for +/// [`tokio::runtime::Runtime::block_on`][runtime-block-on]. +/// +/// [runtime-block-on]: https://docs.rs/tokio/1.3.0/tokio/runtime/struct.Runtime.html#method.block_on +pub fn block_on(future: F) -> F::Output { + use tokio::runtime; + + let rt = runtime::Builder::new_current_thread().enable_all().build().unwrap(); + + rt.block_on(future) +} diff --git a/wrappers/tokio/impls/tokio-test/src/macros.rs b/wrappers/tokio/impls/tokio-test/src/macros.rs new file mode 100644 index 00000000..345a5c2b --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/src/macros.rs @@ -0,0 +1,295 @@ +//! A collection of useful macros for testing futures and tokio based code + +/// Asserts a `Poll` is ready, returning the value. +/// +/// This will invoke `panic!` if the provided `Poll` does not evaluate to `Poll::Ready` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use futures_util::future; +/// use shuttle_tokio_test_impl::{assert_ready, task}; +/// +/// let mut fut = task::spawn(future::ready(())); +/// assert_ready!(fut.poll()); +/// ``` +#[macro_export] +macro_rules! 
assert_ready { + ($e:expr) => {{ + use core::task::Poll; + match $e { + Poll::Ready(v) => v, + Poll::Pending => panic!("pending"), + } + }}; + ($e:expr, $($msg:tt)+) => {{ + use core::task::Poll; + match $e { + Poll::Ready(v) => v, + Poll::Pending => { + panic!("pending; {}", format_args!($($msg)+)) + } + } + }}; +} + +/// Asserts a `Poll>` is ready and `Ok`, returning the value. +/// +/// This will invoke `panic!` if the provided `Poll` does not evaluate to `Poll::Ready(Ok(..))` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use futures_util::future; +/// use shuttle_tokio_test_impl::{assert_ready_ok, task}; +/// +/// let mut fut = task::spawn(future::ok::<_, ()>(())); +/// assert_ready_ok!(fut.poll()); +/// ``` +#[macro_export] +macro_rules! assert_ready_ok { + ($e:expr) => {{ + use tokio_test::{assert_ready, assert_ok}; + let val = assert_ready!($e); + assert_ok!(val) + }}; + ($e:expr, $($msg:tt)+) => {{ + use tokio_test::{assert_ready, assert_ok}; + let val = assert_ready!($e, $($msg)*); + assert_ok!(val, $($msg)*) + }}; +} + +/// Asserts a `Poll>` is ready and `Err`, returning the error. +/// +/// This will invoke `panic!` if the provided `Poll` does not evaluate to `Poll::Ready(Err(..))` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use futures_util::future; +/// use shuttle_tokio_test_impl::{assert_ready_err, task}; +/// +/// let mut fut = task::spawn(future::err::<(), _>(())); +/// assert_ready_err!(fut.poll()); +/// ``` +#[macro_export] +macro_rules! 
assert_ready_err { + ($e:expr) => {{ + use tokio_test::{assert_ready, assert_err}; + let val = assert_ready!($e); + assert_err!(val) + }}; + ($e:expr, $($msg:tt)+) => {{ + use tokio_test::{assert_ready, assert_err}; + let val = assert_ready!($e, $($msg)*); + assert_err!(val, $($msg)*) + }}; +} + +/// Asserts a `Poll` is pending. +/// +/// This will invoke `panic!` if the provided `Poll` does not evaluate to `Poll::Pending` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use futures_util::future; +/// use shuttle_tokio_test_impl::{assert_pending, task}; +/// +/// let mut fut = task::spawn(future::pending::<()>()); +/// assert_pending!(fut.poll()); +/// ``` +#[macro_export] +macro_rules! assert_pending { + ($e:expr) => {{ + use core::task::Poll; + match $e { + Poll::Pending => {} + Poll::Ready(v) => panic!("ready; value = {:?}", v), + } + }}; + ($e:expr, $($msg:tt)+) => {{ + use core::task::Poll; + match $e { + Poll::Pending => {} + Poll::Ready(v) => { + panic!("ready; value = {:?}; {}", v, format_args!($($msg)+)) + } + } + }}; +} + +/// Asserts if a poll is ready and check for equality on the value +/// +/// This will invoke `panic!` if the provided `Poll` does not evaluate to `Poll::Ready` at +/// runtime and the value produced does not partially equal the expected value. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use futures_util::future; +/// use shuttle_tokio_test_impl::{assert_ready_eq, task}; +/// +/// let mut fut = task::spawn(future::ready(42)); +/// assert_ready_eq!(fut.poll(), 42); +/// ``` +#[macro_export] +macro_rules! 
assert_ready_eq { + ($e:expr, $expect:expr) => { + let val = $crate::assert_ready!($e); + assert_eq!(val, $expect) + }; + + ($e:expr, $expect:expr, $($msg:tt)+) => { + let val = $crate::assert_ready!($e, $($msg)*); + assert_eq!(val, $expect, $($msg)*) + }; +} + +/// Asserts that the expression evaluates to `Ok` and returns the value. +/// +/// This will invoke the `panic!` macro if the provided expression does not evaluate to `Ok` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use shuttle_tokio_test_impl::assert_ok; +/// +/// let n: u32 = assert_ok!("123".parse()); +/// +/// let s = "123"; +/// let n: u32 = assert_ok!(s.parse(), "testing parsing {:?} as a u32", s); +/// ``` +#[macro_export] +macro_rules! assert_ok { + ($e:expr) => { + assert_ok!($e,) + }; + ($e:expr,) => {{ + use std::result::Result::*; + match $e { + Ok(v) => v, + Err(e) => panic!("assertion failed: Err({:?})", e), + } + }}; + ($e:expr, $($arg:tt)+) => {{ + use std::result::Result::*; + match $e { + Ok(v) => v, + Err(e) => panic!("assertion failed: Err({:?}): {}", e, format_args!($($arg)+)), + } + }}; +} + +/// Asserts that the expression evaluates to `Err` and returns the error. +/// +/// This will invoke the `panic!` macro if the provided expression does not evaluate to `Err` at +/// runtime. +/// +/// # Custom Messages +/// +/// This macro has a second form, where a custom panic message can be provided with or without +/// arguments for formatting. +/// +/// # Examples +/// +/// ```ignore +/// use shuttle_tokio_test_impl::assert_err; +/// use std::str::FromStr; +/// +/// +/// let err = assert_err!(u32::from_str("fail")); +/// +/// let msg = "fail"; +/// let err = assert_err!(u32::from_str(msg), "testing parsing {:?} as u32", msg); +/// ``` +#[macro_export] +macro_rules! 
assert_err { + ($e:expr) => { + assert_err!($e,); + }; + ($e:expr,) => {{ + use std::result::Result::*; + match $e { + Ok(v) => panic!("assertion failed: Ok({:?})", v), + Err(e) => e, + } + }}; + ($e:expr, $($arg:tt)+) => {{ + use std::result::Result::*; + match $e { + Ok(v) => panic!("assertion failed: Ok({:?}): {}", v, format_args!($($arg)+)), + Err(e) => e, + } + }}; +} + +/// Asserts that an exact duration has elapsed since the start instant ±1ms. +/// +/// ```ignore +/// use tokio::time::{self, Instant}; +/// use std::time::Duration; +/// use shuttle_tokio_test_impl::assert_elapsed; +/// # async fn test_time_passed() { +/// +/// let start = Instant::now(); +/// let dur = Duration::from_millis(50); +/// time::sleep(dur).await; +/// assert_elapsed!(start, dur); +/// # } +/// ``` +/// +/// This 1ms buffer is required because Tokio's hashed-wheel timer has finite time resolution and +/// will not always sleep for the exact interval. +#[macro_export] +macro_rules! assert_elapsed { + ($start:expr, $dur:expr) => {{ + let elapsed = $start.elapsed(); + // type ascription improves compiler error when wrong type is passed + let lower: std::time::Duration = $dur; + + // Handles ms rounding + assert!( + elapsed >= lower && elapsed <= lower + std::time::Duration::from_millis(1), + "actual = {:?}, expected = {:?}", + elapsed, + lower + ); + }}; +} diff --git a/wrappers/tokio/impls/tokio-test/src/stream_mock.rs b/wrappers/tokio/impls/tokio-test/src/stream_mock.rs new file mode 100644 index 00000000..0be71b66 --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/src/stream_mock.rs @@ -0,0 +1,165 @@ +//! A mock stream implementing [`Stream`]. +//! +//! # Overview +//! This crate provides a `StreamMock` that can be used to test code that interacts with streams. +//! It allows you to mock the behavior of a stream and control the items it yields and the waiting +//! intervals between items. +//! +//! # Usage +//! 
To use the `StreamMock`, you need to create a builder using [`StreamMockBuilder`]. The builder +//! allows you to enqueue actions such as returning items or waiting for a certain duration. +//! +//! # Example +//! ```ignore +//! +//! use futures_util::StreamExt; +//! use std::time::Duration; +//! use tokio_test::stream_mock::StreamMockBuilder; +//! +//! async fn test_stream_mock_wait() { +//! let mut stream_mock = StreamMockBuilder::new() +//! .next(1) +//! .wait(Duration::from_millis(300)) +//! .next(2) +//! .build(); +//! +//! assert_eq!(stream_mock.next().await, Some(1)); +//! let start = std::time::Instant::now(); +//! assert_eq!(stream_mock.next().await, Some(2)); +//! let elapsed = start.elapsed(); +//! assert!(elapsed >= Duration::from_millis(300)); +//! assert_eq!(stream_mock.next().await, None); +//! } +//! ``` + +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::Poll; +use std::time::Duration; + +use futures_core::{ready, Stream}; +use std::future::Future; +use tokio::time::{sleep_until, Instant, Sleep}; + +#[derive(Debug, Clone)] +enum Action { + Next(T), + Wait(Duration), +} + +/// A builder for [`StreamMock`] +#[derive(Debug, Clone)] +pub struct StreamMockBuilder { + actions: VecDeque>, +} + +impl StreamMockBuilder { + /// Create a new empty [`StreamMockBuilder`] + pub fn new() -> Self { + StreamMockBuilder::default() + } + + /// Queue an item to be returned by the stream + pub fn next(mut self, value: T) -> Self { + self.actions.push_back(Action::Next(value)); + self + } + + // Queue an item to be consumed by the sink, + // commented out until Sink is implemented. 
+ // + // pub fn consume(mut self, value: T) -> Self { + // self.actions.push_back(Action::Consume(value)); + // self + // } + + /// Queue the stream to wait for a duration + pub fn wait(mut self, duration: Duration) -> Self { + self.actions.push_back(Action::Wait(duration)); + self + } + + /// Build the [`StreamMock`] + pub fn build(self) -> StreamMock { + StreamMock { + actions: self.actions, + sleep: None, + } + } +} + +impl Default for StreamMockBuilder { + fn default() -> Self { + StreamMockBuilder { + actions: VecDeque::new(), + } + } +} + +/// A mock stream implementing [`Stream`] +/// +/// See [`StreamMockBuilder`] for more information. +#[derive(Debug)] +pub struct StreamMock { + actions: VecDeque>, + sleep: Option>>, +} + +impl StreamMock { + fn next_action(&mut self) -> Option> { + self.actions.pop_front() + } +} + +impl Stream for StreamMock { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Try polling the sleep future first + if let Some(ref mut sleep) = self.sleep { + ready!(Pin::new(sleep).poll(cx)); + // Since we're ready, discard the sleep future + self.sleep.take(); + } + + match self.next_action() { + Some(action) => match action { + Action::Next(item) => Poll::Ready(Some(item)), + Action::Wait(duration) => { + // Set up a sleep future and schedule this future to be polled again for it. + self.sleep = Some(Box::pin(sleep_until(Instant::now() + duration))); + cx.waker().wake_by_ref(); + + Poll::Pending + } + }, + None => Poll::Ready(None), + } + } +} + +impl Drop for StreamMock { + fn drop(&mut self) { + // Avoid double panicking to make debugging easier. 
+ if std::thread::panicking() { + return; + } + + let undropped_count = self + .actions + .iter() + .filter(|action| match action { + Action::Next(_) => true, + Action::Wait(_) => false, + }) + .count(); + + assert!( + undropped_count == 0, + "StreamMock was dropped before all actions were consumed, {undropped_count} actions were not consumed" + ); + } +} diff --git a/wrappers/tokio/impls/tokio-test/src/task.rs b/wrappers/tokio/impls/tokio-test/src/task.rs new file mode 100644 index 00000000..3fb8424e --- /dev/null +++ b/wrappers/tokio/impls/tokio-test/src/task.rs @@ -0,0 +1,281 @@ +//! Futures task based helpers to easily test futures and manually written futures. +//! +//! The [`Spawn`] type is used as a mock task harness that allows you to poll futures +//! without needing to setup pinning or context. Any future can be polled but if the +//! future requires the tokio async context you will need to ensure that you poll the +//! [`Spawn`] within a tokio context, this means that as long as you are inside the +//! runtime it will work and you can poll it via [`Spawn`]. +//! +//! [`Spawn`] also supports [`Stream`] to call `poll_next` without pinning +//! or context. +//! +//! In addition to circumventing the need for pinning and context, [`Spawn`] also tracks +//! the amount of times the future/task was woken. This can be useful to track if some +//! leaf future notified the root task correctly. +//! +//! # Example +//! +//! ```ignore +//! use shuttle_tokio_test_impl::task; +//! +//! let fut = async {}; +//! +//! let mut task = task::spawn(fut); +//! +//! assert!(task.poll().is_ready(), "Task was not ready!"); +//! ``` + +use shuttle::sync::{Arc, Condvar, Mutex}; +use std::future::Future; +use std::mem; +use std::ops; +use std::pin::Pin; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + +use tokio_stream::Stream; + +/// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor. 
+/// +/// This can be used to spawn a [`Future`] or a [`Stream`]. +/// +/// For more information, check the module docs. +pub fn spawn(task: T) -> Spawn { + Spawn { + task: MockTask::new(), + future: Box::pin(task), + } +} + +/// Future spawned on a mock task that can be used to poll the future or stream +/// without needing pinning or context types. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Spawn { + task: MockTask, + future: Pin>, +} + +#[derive(Debug, Clone)] +struct MockTask { + waker: Arc, +} + +#[derive(Debug)] +struct ThreadWaker { + state: Mutex, + condvar: Condvar, +} + +const IDLE: usize = 0; +const WAKE: usize = 1; +const SLEEP: usize = 2; + +impl Spawn { + /// Consumes `self` returning the inner value + pub fn into_inner(self) -> T + where + T: Unpin, + { + *Pin::into_inner(self.future) + } + + /// Returns `true` if the inner future has received a wake notification + /// since the last call to `enter`. + pub fn is_woken(&self) -> bool { + self.task.is_woken() + } + + /// Returns the number of references to the task waker + /// + /// The task itself holds a reference. The return value will never be zero. + pub fn waker_ref_count(&self) -> usize { + self.task.waker_ref_count() + } + + /// Enter the task context + pub fn enter(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R, + { + let fut = self.future.as_mut(); + self.task.enter(|cx| f(cx, fut)) + } +} + +impl ops::Deref for Spawn { + type Target = T; + + fn deref(&self) -> &T { + &self.future + } +} + +impl ops::DerefMut for Spawn { + fn deref_mut(&mut self) -> &mut T { + &mut self.future + } +} + +impl Spawn { + /// If `T` is a [`Future`] then poll it. This will handle pinning and the context + /// type for the future. + pub fn poll(&mut self) -> Poll { + let fut = self.future.as_mut(); + self.task.enter(|cx| fut.poll(cx)) + } +} + +impl Spawn { + /// If `T` is a [`Stream`] then `poll_next` it. 
This will handle pinning and the context + /// type for the stream. + pub fn poll_next(&mut self) -> Poll> { + let stream = self.future.as_mut(); + self.task.enter(|cx| stream.poll_next(cx)) + } +} + +impl Future for Spawn { + type Output = T::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.future.as_mut().poll(cx) + } +} + +impl Stream for Spawn { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.future.as_mut().poll_next(cx) + } +} + +impl MockTask { + /// Creates new mock task + fn new() -> Self { + MockTask { + waker: Arc::new(ThreadWaker::new()), + } + } + + /// Runs a closure from the context of the task. + /// + /// Any wake notifications resulting from the execution of the closure are + /// tracked. + fn enter(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>) -> R, + { + self.waker.clear(); + let waker = self.waker(); + let mut cx = Context::from_waker(&waker); + + f(&mut cx) + } + + /// Returns `true` if the inner future has received a wake notification + /// since the last call to `enter`. + fn is_woken(&self) -> bool { + self.waker.is_woken() + } + + /// Returns the number of references to the task waker + /// + /// The task itself holds a reference. The return value will never be zero. + fn waker_ref_count(&self) -> usize { + Arc::strong_count(&self.waker) + } + + fn waker(&self) -> Waker { + unsafe { + let raw = to_raw(self.waker.clone()); + Waker::from_raw(raw) + } + } +} + +impl Default for MockTask { + fn default() -> Self { + Self::new() + } +} + +impl ThreadWaker { + fn new() -> Self { + ThreadWaker { + state: Mutex::new(IDLE), + condvar: Condvar::new(), + } + } + + /// Clears any previously received wakes, avoiding potential spurious + /// wake notifications. This should only be called immediately before running the + /// task. 
+ fn clear(&self) { + *self.state.lock().unwrap() = IDLE; + } + + fn is_woken(&self) -> bool { + match *self.state.lock().unwrap() { + IDLE => false, + WAKE => true, + _ => unreachable!(), + } + } + + fn wake(&self) { + // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. + let mut state = self.state.lock().unwrap(); + let prev = *state; + + if prev == WAKE { + return; + } + + *state = WAKE; + + if prev == IDLE { + return; + } + + // The other half is sleeping, so we wake it up. + assert_eq!(prev, SLEEP); + self.condvar.notify_one(); + } +} + +static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); + +unsafe fn to_raw(waker: Arc) -> RawWaker { + RawWaker::new(Arc::into_raw(waker).cast::<()>(), &VTABLE) +} + +unsafe fn from_raw(raw: *const ()) -> Arc { + Arc::from_raw(raw.cast::()) +} + +unsafe fn clone(raw: *const ()) -> RawWaker { + let waker = from_raw(raw); + + // Increment the ref count + mem::forget(waker.clone()); + + to_raw(waker) +} + +unsafe fn wake(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); + + // We don't actually own a reference to the unparker + mem::forget(waker); +} + +unsafe fn drop_waker(raw: *const ()) { + let _ = from_raw(raw); +} diff --git a/wrappers/tokio/impls/tokio-util/Cargo.toml b/wrappers/tokio/impls/tokio-util/Cargo.toml new file mode 100644 index 00000000..3e95f613 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "shuttle-tokio-util-impl" +version = "0.1.0" # Forked from "0.7.11" +edition = "2021" + +publish = false + +[features] +# Shorthand for enabling everything +full = ["codec", "compat", "io-util", "time", "net", "rt"] + +net = ["tokio/net"] +compat = [] +codec = ["tokio-util-orig/codec"] +time = ["tokio/time"] +io = [] +io-util = ["io", "tokio/rt", "tokio/io-util"] +rt = ["tokio/rt", "tokio/sync"] + +__docs_rs = 
[] + +[dependencies] +pin-project-lite = "0.2.11" +tokio = { version = "*", package = "shuttle-tokio-impl", path = "../tokio", features = ["sync"] } +shuttle = { path = "../../../../shuttle", version = "*" } +tokio-util-orig = { version = "0.7.11", package = "tokio-util" } + +[dev-dependencies] +tokio-util = { path = ".", package = "shuttle-tokio-util-impl", features = ["full"] } +tokio = { package = "shuttle-tokio-impl", path = "../tokio", features = ["full"] } +tokio-test = { package = "shuttle-tokio-test-impl", path = "../tokio-test" } +tokio-stream = { package = "shuttle-tokio-stream-impl", path = "../tokio-stream" } + +futures-test = "0.3.5" diff --git a/wrappers/tokio/impls/tokio-util/src/cfg.rs b/wrappers/tokio/impls/tokio-util/src/cfg.rs new file mode 100644 index 00000000..1946f1bc --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/cfg.rs @@ -0,0 +1,73 @@ +#![allow(unused)] + +macro_rules! cfg_codec { + ($($item:item)*) => { + $( + #[cfg(feature = "codec")] + #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] + $item + )* + } +} + +macro_rules! cfg_compat { + ($($item:item)*) => { + $( + #[cfg(feature = "compat")] + #[cfg_attr(docsrs, doc(cfg(feature = "compat")))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(all(feature = "net", feature = "codec"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))] + $item + )* + } +} + +macro_rules! cfg_io { + ($($item:item)*) => { + $( + #[cfg(feature = "io")] + #[cfg_attr(docsrs, doc(cfg(feature = "io")))] + $item + )* + } +} + +cfg_io! { + macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } + } +} + +macro_rules! cfg_rt { + ($($item:item)*) => { + $( + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + $item + )* + } +} + +macro_rules! 
cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} diff --git a/wrappers/tokio/impls/tokio-util/src/codec/mod.rs b/wrappers/tokio/impls/tokio-util/src/codec/mod.rs new file mode 100644 index 00000000..fe4ffa35 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/codec/mod.rs @@ -0,0 +1,6 @@ +// There is nothing in tokio-util::codec which relies on tokio (apart from the traits AsyncRead/AsyncWrite) or std::sync. +// It should therefore be fine to simply reexport codec from tokio-util. +// Note that this is not code which is really used, or is going to be used much, and the reason we have it here +// is mostly for convenience, ie., it should "just work" to swap out tokio-util with shuttle-tokio-util. + +pub use tokio_util_orig::codec; \ No newline at end of file diff --git a/wrappers/tokio/impls/tokio-util/src/lib.rs b/wrappers/tokio/impls/tokio-util/src/lib.rs new file mode 100644 index 00000000..0aef8eb4 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/lib.rs @@ -0,0 +1,20 @@ +#![allow(unknown_lints, unexpected_cfgs)] +//! A fork of [https://docs.rs/crate/tokio-util](https://docs.rs/crate/tokio-util) +//! This package is not intended to be depended on directly, and is not currently published. +//! It is a fork of the tokio-util package with imports swapped so that it builds atop Shuttle. +//! It exists for the other packages in the shuttle-tokio-"family" to use. + +#[macro_use] +mod cfg; + +cfg_rt! { + pub mod task; +} + +cfg_codec! { + pub mod codec; +} + +pub mod sync; + +mod util; diff --git a/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token.rs b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token.rs new file mode 100644 index 00000000..142c3431 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token.rs @@ -0,0 +1,384 @@ +//! An asynchronously awaitable `CancellationToken`. +//! 
The token allows to signal a cancellation request to one or more tasks. +pub(crate) mod guard; +mod tree_node; + +use crate::util::MaybeDangling; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use shuttle::sync::Arc; + +use guard::DropGuard; +use pin_project_lite::pin_project; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```ignore +/// use tokio::select; +/// use tokio_util::sync::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::sleep(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: Arc, +} + +impl std::panic::UnwindSafe for CancellationToken {} +impl std::panic::RefUnwindSafe for CancellationToken {} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFuture<'a> { + cancellation_token: &'a CancellationToken, + #[pin] + future: tokio::sync::futures::Notified<'a>, + } +} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. 
+ /// + /// This is the counterpart to [`WaitForCancellationFuture`] that takes + /// [`CancellationToken`] by value instead of using a reference. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFutureOwned { + // This field internally has a reference to the cancellation token, but camouflages + // the relationship with `'static`. To avoid Undefined Behavior, we must ensure + // that the reference is only used while the cancellation token is still alive. To + // do that, we ensure that the future is the first field, so that it is dropped + // before the cancellation token. + // + // We use `MaybeDanglingFuture` here because without it, the compiler could assert + // the reference inside `future` to be valid even after the destructor of that + // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed + // as an argument to a function, the reference can be asserted to be valid for the + // rest of that function.) To avoid that, we use `MaybeDangling` which tells the + // compiler that the reference stored inside it might not be valid. + // + // See + // for more info. + #[pin] + future: MaybeDangling>, + cancellation_token: CancellationToken, + } +} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + /// Creates a clone of the `CancellationToken` which will get cancelled + /// whenever the current token gets cancelled, and vice versa. 
+ fn clone(&self) -> Self { + tree_node::increase_handle_refcount(&self.inner); + CancellationToken { + inner: self.inner.clone(), + } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + tree_node::decrease_handle_refcount(&self.inner); + } +} + +impl Default for CancellationToken { + fn default() -> CancellationToken { + CancellationToken::new() + } +} + +impl CancellationToken { + /// Creates a new `CancellationToken` in the non-cancelled state. + pub fn new() -> CancellationToken { + CancellationToken { + inner: Arc::new(tree_node::TreeNode::new()), + } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. Unlike a cloned `CancellationToken`, + /// cancelling a child token does not cancel the parent token. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```ignore + /// use tokio::select; + /// use tokio_util::sync::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: tree_node::child_node(&self.inner), + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. 
+ /// + /// This will wake up all tasks which are waiting for cancellation. + /// + /// Be aware that cancellation is not an atomic operation. It is possible + /// for another thread running in parallel with a call to `cancel` to first + /// receive `true` from `is_cancelled` on one child node, and then receive + /// `false` from `is_cancelled` on another child node. However, once the + /// call to `cancel` returns, all child nodes have been fully cancelled. + pub fn cancel(&self) { + tree_node::cancel(&self.inner); + } + + /// Returns `true` if the `CancellationToken` is cancelled. + pub fn is_cancelled(&self) -> bool { + tree_node::is_cancelled(&self.inner) + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: self, + future: self.inner.notified(), + } + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// The function takes self by value and returns a future that owns the + /// token. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned { + WaitForCancellationFutureOwned::new(self) + } + + /// Creates a `DropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. + pub fn drop_guard(self) -> DropGuard { + DropGuard { inner: Some(self) } + } + + /// Runs a future to completion and returns its result wrapped inside of an `Option` + /// unless the `CancellationToken` is cancelled. 
In that case the function returns + /// `None` and the future gets dropped. + /// + /// # Cancel safety + /// + /// This method is only cancel safe if `fut` is cancel safe. + pub async fn run_until_cancelled(&self, fut: F) -> Option + where + F: Future, + { + pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled or a given Future gets resolved. It is biased towards the + /// Future completion. + #[must_use = "futures do nothing unless polled"] + struct RunUntilCancelledFuture<'a, F: Future> { + #[pin] + cancellation: WaitForCancellationFuture<'a>, + #[pin] + future: F, + } + } + + impl Future for RunUntilCancelledFuture<'_, F> { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + if let Poll::Ready(res) = this.future.poll(cx) { + Poll::Ready(Some(res)) + } else if this.cancellation.poll(cx).is_ready() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + } + + RunUntilCancelledFuture { + cancellation: self.cancelled(), + future: fut, + } + .await + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl core::fmt::Debug for WaitForCancellationFuture<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl Future for WaitForCancellationFuture<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. 
+ if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + this.future.set(this.cancellation_token.inner.notified()); + } + } +} + +// ===== impl WaitForCancellationFutureOwned ===== + +impl core::fmt::Debug for WaitForCancellationFutureOwned { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFutureOwned").finish() + } +} + +impl WaitForCancellationFutureOwned { + fn new(cancellation_token: CancellationToken) -> Self { + WaitForCancellationFutureOwned { + // cancellation_token holds a heap allocation and is guaranteed to have a + // stable deref, thus it would be ok to move the cancellation_token while + // the future holds a reference to it. + // + // # Safety + // + // cancellation_token is dropped after future due to the field ordering. + future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }), + cancellation_token, + } + } + + /// # Safety + /// The returned future must be destroyed before the cancellation token is + /// destroyed. + unsafe fn new_future(cancellation_token: &CancellationToken) -> tokio::sync::futures::Notified<'static> { + let inner_ptr = Arc::as_ptr(&cancellation_token.inner); + // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains + // valid until the strong count of the Arc drops to zero, and the caller + // guarantees that they will drop the future before that happens. + (*inner_ptr).notified() + } +} + +impl Future for WaitForCancellationFutureOwned { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. 
+ if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + // # Safety + // + // cancellation_token is dropped after future due to the field ordering. + this.future + .set(MaybeDangling::new(unsafe { Self::new_future(this.cancellation_token) })); + } + } +} diff --git a/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/guard.rs b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/guard.rs new file mode 100644 index 00000000..a9740118 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/guard.rs @@ -0,0 +1,25 @@ +use crate::sync::CancellationToken; + +/// A wrapper for cancellation token which automatically cancels +/// it on drop. It is created using `drop_guard` method on the `CancellationToken`. +#[derive(Debug)] +pub struct DropGuard { + pub(super) inner: Option, +} + +impl DropGuard { + /// Returns stored cancellation token and removes this drop guard instance + /// (i.e. it will no longer cancel token). Other guards for this token + /// are not affected. + pub fn disarm(mut self) -> CancellationToken { + self.inner.take().expect("`inner` can be only None in a destructor") + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(inner) = &self.inner { + inner.cancel(); + } + } +} diff --git a/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/tree_node.rs b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/tree_node.rs new file mode 100644 index 00000000..50dfa04d --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/sync/cancellation_token/tree_node.rs @@ -0,0 +1,374 @@ +//! This mod provides the logic for the inner tree structure of the `CancellationToken`. +//! +//! `CancellationTokens` are only light handles with references to [`TreeNode`]. +//! All the logic is actually implemented in the [`TreeNode`]. +//! +//! A [`TreeNode`] is part of the cancellation tree and may have one parent and an arbitrary number of +//! children. +//! +//! 
A [`TreeNode`] can receive the request to perform a cancellation through a `CancellationToken`. +//! This cancellation request will cancel the node and all of its descendants. +//! +//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no +//! more `CancellationTokens` pointing to it any more), it gets removed from the tree, to keep the +//! tree as small as possible. +//! +//! # Invariants +//! +//! Those invariants shall be true at any time. +//! +//! 1. A node that has no parents and no handles can no longer be cancelled. +//! This is important during both cancellation and refcounting. +//! +//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A. +//! This is important for deadlock safety, as it is used for lock order. +//! Node B can only become the child of node A in two ways: +//! - being created with `child_node()`, in which case it is trivially true that +//! node A already existed when node B was created +//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()` +//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C +//! was younger than A, therefore B is also younger than A. +//! +//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of +//! node A. It is important to always restore that invariant before dropping the lock of a node. +//! +//! # Deadlock safety +//! +//! We always lock in the order of creation time. We can prove this through invariant #2. +//! Specifically, through invariant #2, we know that we always have to lock a parent +//! before its child. +//! +use shuttle::sync::Arc; + +// NOTE: Different from Tokio. +// This implementation uses tokio::sync::Mutex instead of std::sync::Mutex to avoid +// internal consistency violations in Shuttle. 
+// NOTE: Uncomment the `TryLockError` parts below (2 branches in `with_locked_node_and_parent`) if we swap back to std::sync. +use tokio::sync::{Mutex, MutexGuard}; + +/// A node of the cancellation tree structure +/// +/// The actual data it holds is wrapped inside a mutex for synchronization. +pub(crate) struct TreeNode { + inner: Mutex, + waker: tokio::sync::Notify, +} +impl TreeNode { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + } + } + + pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> { + self.waker.notified() + } +} + +/// The data contained inside a `TreeNode`. +/// +/// This struct exists so that the data of the node can be wrapped +/// in a Mutex. +struct Inner { + parent: Option>, + parent_idx: usize, + children: Vec>, + is_cancelled: bool, + num_handles: usize, +} + +/// Returns whether or not the node is cancelled +pub(crate) fn is_cancelled(node: &Arc) -> bool { + node.inner.blocking_lock().is_cancelled +} + +/// Creates a child node +pub(crate) fn child_node(parent: &Arc) -> Arc { + let mut locked_parent = parent.inner.blocking_lock(); + + // Do not register as child if we are already cancelled. + // Cancelled trees can never be uncancelled and therefore + // need no connection to parents or children any more. 
+ if locked_parent.is_cancelled { + return Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: true, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + } + + let child = Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: Some(parent.clone()), + parent_idx: locked_parent.children.len(), + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + + locked_parent.children.push(child.clone()); + + child +} + +/// Disconnects the given parent from all of its children. +/// +/// Takes a reference to [Inner] to make sure the parent is already locked. +fn disconnect_children(node: &mut Inner) { + for child in std::mem::take(&mut node.children) { + let mut locked_child = child.inner.blocking_lock(); + locked_child.parent_idx = 0; + locked_child.parent = None; + } +} + +/// Figures out the parent of the node and locks the node and its parent atomically. +/// +/// The basic principle of preventing deadlocks in the tree is +/// that we always lock the parent first, and then the child. +/// For more info look at *deadlock safety* and *invariant #2*. +/// +/// Sadly, it's impossible to figure out the parent of a node without +/// locking it. To then achieve locking order consistency, the node +/// has to be unlocked before the parent gets locked. +/// This leaves a small window where we already assume that we know the parent, +/// but neither the parent nor the node is locked. Therefore, the parent could change. +/// +/// To prevent that this problem leaks into the rest of the code, it is abstracted +/// in this function. +/// +/// The locked child and optionally its locked parent, if a parent exists, get passed +/// to the `func` argument via (node, None) or (node, Some(parent)). 
+fn with_locked_node_and_parent(node: &Arc, func: F) -> Ret +where + F: FnOnce(MutexGuard<'_, Inner>, Option>) -> Ret, +{ + let mut locked_node = node.inner.blocking_lock(); + + // Every time this fails, the number of ancestors of the node decreases, + // so the loop must succeed after a finite number of iterations. + loop { + // Look up the parent of the currently locked node. + let potential_parent = match locked_node.parent.as_ref() { + Some(potential_parent) => potential_parent.clone(), + None => return func(locked_node, None), + }; + + // Lock the parent. This may require unlocking the child first. + let locked_parent = match potential_parent.inner.try_lock() { + Ok(locked_parent) => locked_parent, + // NOTE: Uncomment the `TryLockError::WouldBlock` part here and uncomment below if we swap back to std::sync + Err(_ /* TryLockError::WouldBlock */) => { + drop(locked_node); + // Deadlock safety: + // + // Due to invariant #2, the potential parent must come before + // the child in the creation order. Therefore, we can safely + // lock the child while holding the parent lock. + let locked_parent = potential_parent.inner.blocking_lock(); + locked_node = node.inner.blocking_lock(); + locked_parent + } // NOTE: Uncomment this and the `TryLockError::WouldBlock` part above if we swap back to std::sync + /* + // https://github.com/tokio-rs/tokio/pull/6273#discussion_r1443752911 + #[allow(clippy::unnecessary_literal_unwrap)] + Err(TryLockError::Poisoned(err)) => Err(err).unwrap(), + */ + }; + + // If we unlocked the child, then the parent may have changed. Check + // that we still have the right parent. + if let Some(actual_parent) = locked_node.parent.as_ref() { + if Arc::ptr_eq(actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } + } + } +} + +/// Moves all children from `node` to `parent`. 
+/// +/// `parent` MUST have been a parent of the node when they both got locked, +/// otherwise there is a potential for a deadlock as invariant #2 would be violated. +/// +/// To acquire the locks for node and parent, use [`with_locked_node_and_parent`]. +fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { + // Pre-allocate in the parent, for performance + parent.children.reserve(node.children.len()); + + for child in std::mem::take(&mut node.children) { + { + let mut child_locked = child.inner.blocking_lock(); + child_locked.parent.clone_from(&node.parent); + child_locked.parent_idx = parent.children.len(); + } + parent.children.push(child); + } +} + +/// Removes a child from the parent. +/// +/// `parent` MUST be the parent of `node`. +/// To acquire the locks for node and parent, use [`with_locked_node_and_parent`]. +fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { + // Query the position from where to remove a node + let pos = node.parent_idx; + node.parent = None; + node.parent_idx = 0; + + // Unlock node, so that only one child at a time is locked. + // Otherwise we would violate the lock order (see 'deadlock safety') as we + // don't know the creation order of the child nodes + drop(node); + + // If `node` is the last element in the list, we don't need any swapping + if parent.children.len() == pos + 1 { + parent.children.pop().unwrap(); + } else { + // If `node` is not the last element in the list, we need to + // replace it with the last element + let replacement_child = parent.children.pop().unwrap(); + replacement_child.inner.blocking_lock().parent_idx = pos; + parent.children[pos] = replacement_child; + } + + let len = parent.children.len(); + if 4 * len <= parent.children.capacity() { + parent.children.shrink_to(2 * len); + } +} + +/// Increases the reference count of handles. 
+pub(crate) fn increase_handle_refcount(node: &Arc) { + let mut locked_node = node.inner.blocking_lock(); + + // Once no handles are left over, the node gets detached from the tree. + // There should never be a new handle once all handles are dropped. + assert!(locked_node.num_handles > 0); + + locked_node.num_handles += 1; +} + +/// Decreases the reference count of handles. +/// +/// Once no handle is left, we can remove the node from the +/// tree and connect its parent directly to its children. +pub(crate) fn decrease_handle_refcount(node: &Arc) { + let num_handles = { + let mut locked_node = node.inner.blocking_lock(); + locked_node.num_handles -= 1; + locked_node.num_handles + }; + + if num_handles == 0 { + with_locked_node_and_parent(node, |mut node, parent| { + // Remove the node from the tree + match parent { + Some(mut parent) => { + // As we want to remove ourselves from the tree, + // we have to move the children to the parent, so that + // they still receive the cancellation event without us. + // Moving them does not violate invariant #1. + move_children_to_parent(&mut node, &mut parent); + + // Remove the node from the parent + remove_child(&mut parent, node); + } + None => { + // Due to invariant #1, we can assume that our + // children can no longer be cancelled through us. + // (as we now have neither a parent nor handles) + // Therefore we can disconnect them. + disconnect_children(&mut node); + } + } + }); + } +} + +/// Cancels a node and its children. +pub(crate) fn cancel(node: &Arc) { + let mut locked_node = node.inner.blocking_lock(); + + if locked_node.is_cancelled { + return; + } + + // One by one, adopt grandchildren and then cancel and detach the child + while let Some(child) = locked_node.children.pop() { + // This can't deadlock because the mutex we are already + // holding is the parent of child. 
+ let mut locked_child = child.inner.blocking_lock(); + + // Detach the child from node + // No need to modify node.children, as the child already got removed with `.pop` + locked_child.parent = None; + locked_child.parent_idx = 0; + + // If child is already cancelled, detaching is enough + if locked_child.is_cancelled { + continue; + } + + // Cancel or adopt grandchildren + while let Some(grandchild) = locked_child.children.pop() { + // This can't deadlock because the two mutexes we are already + // holding is the parent and grandparent of grandchild. + let mut locked_grandchild = grandchild.inner.blocking_lock(); + + // Detach the grandchild + locked_grandchild.parent = None; + locked_grandchild.parent_idx = 0; + + // If grandchild is already cancelled, detaching is enough + if locked_grandchild.is_cancelled { + continue; + } + + // For performance reasons, only adopt grandchildren that have children. + // Otherwise, just cancel them right away, no need for another iteration. + if locked_grandchild.children.is_empty() { + // Cancel the grandchild + locked_grandchild.is_cancelled = true; + locked_grandchild.children = Vec::new(); + drop(locked_grandchild); + grandchild.waker.notify_waiters(); + } else { + // Otherwise, adopt grandchild + locked_grandchild.parent = Some(node.clone()); + locked_grandchild.parent_idx = locked_node.children.len(); + drop(locked_grandchild); + locked_node.children.push(grandchild); + } + } + + // Cancel the child + locked_child.is_cancelled = true; + locked_child.children = Vec::new(); + drop(locked_child); + child.waker.notify_waiters(); + + // Now the child is cancelled and detached and all its children are adopted. + // Just continue until all (including adopted) children are cancelled and detached. + } + + // Cancel the node itself. 
+ locked_node.is_cancelled = true; + locked_node.children = Vec::new(); + drop(locked_node); + node.waker.notify_waiters(); +} diff --git a/wrappers/tokio/impls/tokio-util/src/sync/mod.rs b/wrappers/tokio/impls/tokio-util/src/sync/mod.rs new file mode 100644 index 00000000..a330d6aa --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/sync/mod.rs @@ -0,0 +1,6 @@ +//! Synchronization primitives + +mod cancellation_token; +pub use cancellation_token::{ + guard::DropGuard, CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned, +}; diff --git a/wrappers/tokio/impls/tokio-util/src/task/mod.rs b/wrappers/tokio/impls/tokio-util/src/task/mod.rs new file mode 100644 index 00000000..7ac06f10 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/task/mod.rs @@ -0,0 +1,4 @@ +//! Extra utilities for spawning tasks + +pub mod task_tracker; +pub use task_tracker::TaskTracker; \ No newline at end of file diff --git a/wrappers/tokio/impls/tokio-util/src/task/task_tracker.rs b/wrappers/tokio/impls/tokio-util/src/task/task_tracker.rs new file mode 100644 index 00000000..d2ff810f --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/task/task_tracker.rs @@ -0,0 +1,722 @@ +//! Types related to the [`TaskTracker`] collection. +//! +//! See the documentation of [`TaskTracker`] for more information. + +use pin_project_lite::pin_project; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use shuttle::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{futures::Notified, Notify}; + +#[cfg(feature = "rt")] +use tokio::{ + runtime::Handle, + task::JoinHandle, +}; + +/// A task tracker used for waiting until tasks exit. +/// +/// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The +/// `CancellationToken` is used to signal to tasks that they should shut down, and the +/// `TaskTracker` is used to wait for them to finish shutting down. 
+///
+/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case
+/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the
+/// [`wait`] method will wait until *both* of the following happen at the same time:
+///
+/// * The `TaskTracker` must be closed using the [`close`] method.
+/// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited.
+///
+/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
+/// the destructor of the future has finished running. However, there might be a short amount of
+/// time where [`JoinHandle::is_finished`] returns false.
+///
+/// # Comparison to `JoinSet`
+///
+/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a
+/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its
+/// unique features is required:
+///
+/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory.
+/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from returning even if
+///    the `TaskTracker` is empty.
+/// 3. A `TaskTracker` does not require mutable access to insert tasks.
+/// 4. A `TaskTracker` can be cloned to share it with many tasks.
+///
+/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of
+/// every inserted task. This means that if the caller keeps inserting tasks and never calls
+/// [`join_next`], then their return values will keep building up and consuming memory, _even if_
+/// most of the tasks have already exited. This can cause the process to run out of memory. With a
+/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the
+/// `TaskTracker`.
+///
+/// # Examples
+///
+/// For more examples, please see the topic page on [graceful shutdown].
+/// +/// ## Spawn tasks and wait for them to exit +/// +/// This is a simple example. For this case, [`JoinSet`] should probably be used instead. +/// +/// ```ignore +/// use tokio_util::task::TaskTracker; +/// +/// #[tokio::main] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// +/// for i in 0..10 { +/// tracker.spawn(async move { +/// println!("Task {} is running!", i); +/// }); +/// } +/// // Once we spawned everything, we close the tracker. +/// tracker.close(); +/// +/// // Wait for everything to finish. +/// tracker.wait().await; +/// +/// println!("This is printed after all of the tasks."); +/// } +/// ``` +/// +/// ## Wait for tasks to exit +/// +/// This example shows the intended use-case of `TaskTracker`. It is used together with +/// [`CancellationToken`] to implement graceful shutdown. +/// ```ignore +/// use tokio_util::sync::CancellationToken; +/// use tokio_util::task::TaskTracker; +/// use tokio::time::{self, Duration}; +/// +/// async fn background_task(num: u64) { +/// for i in 0..10 { +/// time::sleep(Duration::from_millis(100*num)).await; +/// println!("Background task {} in iteration {}.", num, i); +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _hidden() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// let token = CancellationToken::new(); +/// +/// for i in 0..10 { +/// let token = token.clone(); +/// tracker.spawn(async move { +/// // Use a `tokio::select!` to kill the background task if the token is +/// // cancelled. +/// tokio::select! { +/// () = background_task(i) => { +/// println!("Task {} exiting normally.", i); +/// }, +/// () = token.cancelled() => { +/// // Do some cleanup before we really exit. +/// time::sleep(Duration::from_millis(50)).await; +/// println!("Task {} finished cleanup.", i); +/// }, +/// } +/// }); +/// } +/// +/// // Spawn a background task that will send the shutdown signal. 
+/// { +/// let tracker = tracker.clone(); +/// tokio::spawn(async move { +/// // Normally you would use something like ctrl-c instead of +/// // sleeping. +/// time::sleep(Duration::from_secs(2)).await; +/// tracker.close(); +/// token.cancel(); +/// }); +/// } +/// +/// // Wait for all tasks to exit. +/// tracker.wait().await; +/// +/// println!("All tasks have exited now."); +/// } +/// ``` +/// +/// [`CancellationToken`]: crate::sync::CancellationToken +/// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished +/// [`JoinSet`]: tokio::task::JoinSet +/// [`close`]: Self::close +/// [`join_next`]: tokio::task::JoinSet::join_next +/// [`wait`]: Self::wait +/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown +pub struct TaskTracker { + inner: Arc, +} + +/// Represents a task tracked by a [`TaskTracker`]. +#[must_use] +#[derive(Debug)] +pub struct TaskTrackerToken { + task_tracker: TaskTracker, +} + +struct TaskTrackerInner { + /// Keeps track of the state. + /// + /// The lowest bit is whether the task tracker is closed. + /// + /// The rest of the bits count the number of tracked tasks. + state: AtomicUsize, + /// Used to notify when the last task exits. + on_last_exit: Notify, +} + +pin_project! { + /// A future that is tracked as a task by a [`TaskTracker`]. + /// + /// The associated [`TaskTracker`] cannot complete until this future is dropped. + /// + /// This future is returned by [`TaskTracker::track_future`]. + #[must_use = "futures do nothing unless polled"] + pub struct TrackedFuture { + #[pin] + future: F, + token: TaskTrackerToken, + } +} + +pin_project! { + /// A future that completes when the [`TaskTracker`] is empty and closed. + /// + /// This future is returned by [`TaskTracker::wait`]. 
+ #[must_use = "futures do nothing unless polled"] + pub struct TaskTrackerWaitFuture<'a> { + #[pin] + future: Notified<'a>, + inner: Option<&'a TaskTrackerInner>, + } +} + +impl TaskTrackerInner { + #[inline] + fn new() -> Self { + Self { + state: AtomicUsize::new(0), + on_last_exit: Notify::new(), + } + } + + #[inline] + fn is_closed_and_empty(&self) -> bool { + // If empty and closed bit set, then we are done. + // + // The acquire load will synchronize with the release store of any previous call to + // `set_closed` and `drop_task`. + self.state.load(Ordering::Acquire) == 1 + } + + #[inline] + fn set_closed(&self) -> bool { + // The AcqRel ordering makes the closed bit behave like a `Mutex` for synchronization + // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}` + // more meaningful for the user. Without these orderings, this assert could fail: + // ``` + // // thread 1 + // some_other_atomic.store(true, Relaxed); + // tracker.close(); + // + // // thread 2 + // if tracker.reopen() { + // assert!(some_other_atomic.load(Relaxed)); + // } + // ``` + // However, with the AcqRel ordering, we establish a happens-before relationship from the + // call to `close` and the later call to `reopen` that returned true. + let state = self.state.fetch_or(1, Ordering::AcqRel); + + // If there are no tasks, and if it was not already closed: + if state == 0 { + self.notify_now(); + } + + (state & 1) == 0 + } + + #[inline] + fn set_open(&self) -> bool { + // See `set_closed` regarding the AcqRel ordering. + let state = self.state.fetch_and(!1, Ordering::AcqRel); + (state & 1) == 1 + } + + #[inline] + fn add_task(&self) { + self.state.fetch_add(2, Ordering::Relaxed); + } + + #[inline] + fn drop_task(&self) { + let state = self.state.fetch_sub(2, Ordering::Release); + + // If this was the last task and we are closed: + if state == 3 { + self.notify_now(); + } + } + + #[cold] + fn notify_now(&self) { + // Insert an acquire fence. 
This matters for `drop_task` but doesn't matter for + // `set_closed` since it already uses AcqRel. + // + // This synchronizes with the release store of any other call to `drop_task`, and with the + // release store in the call to `set_closed`. That ensures that everything that happened + // before those other calls to `drop_task` or `set_closed` will be visible after this load, + // and those things will also be visible to anything woken by the call to `notify_waiters`. + self.state.load(Ordering::Acquire); + + self.on_last_exit.notify_waiters(); + } +} + +impl TaskTracker { + /// Creates a new `TaskTracker`. + /// + /// The `TaskTracker` will start out as open. + #[must_use] + pub fn new() -> Self { + Self { + inner: Arc::new(TaskTrackerInner::new()), + } + } + + /// Waits until this `TaskTracker` is both closed and empty. + /// + /// If the `TaskTracker` is already closed and empty when this method is called, then it + /// returns immediately. + /// + /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker` + /// becomes both closed and empty for a short amount of time, then it is guarantee that all + /// `wait` futures that were created before the short time interval will trigger, even if they + /// are not polled during that short time interval. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + /// + /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the + /// condition in a `tokio::select!` loop. + /// + /// [aba]: https://en.wikipedia.org/wiki/ABA_problem + #[inline] + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + TaskTrackerWaitFuture { + future: self.inner.on_last_exit.notified(), + inner: if self.inner.is_closed_and_empty() { + None + } else { + Some(&self.inner) + }, + } + } + + /// Close this `TaskTracker`. + /// + /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks. 
+ /// + /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn close(&self) -> bool { + self.inner.set_closed() + } + + /// Reopen this `TaskTracker`. + /// + /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty. + /// + /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn reopen(&self) -> bool { + self.inner.set_open() + } + + /// Returns `true` if this `TaskTracker` is [closed](Self::close). + #[inline] + #[must_use] + pub fn is_closed(&self) -> bool { + (self.inner.state.load(Ordering::Acquire) & 1) != 0 + } + + /// Returns the number of tasks tracked by this `TaskTracker`. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.inner.state.load(Ordering::Acquire) >> 1 + } + + /// Returns `true` if there are no tasks in this `TaskTracker`. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.inner.state.load(Ordering::Acquire) <= 1 + } + + /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::spawn(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + tokio::task::spawn(self.track_future(task)) + } + + /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn(tracker.track_future(task))`. 
+ #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_on(&self, task: F, handle: &Handle) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + handle.spawn(self.track_future(task)) + } + + // TODO: Requires `shuttle_tokio::task::LocalSet` + /* + /// Spawn the provided future on the current [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local(&self, task: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + tokio::task::spawn_local(self.track_future(task)) + } + + /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local_on(&self, task: F, local_set: &LocalSet) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + local_set.spawn_local(self.track_future(task)) + } + */ + + /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_blocking(tracker.track_future(task))`. 
+ #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking(&self, task: F) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + tokio::task::spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn_blocking(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking_on(&self, task: F, handle: &Handle) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + handle.spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Track the provided future. + /// + /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will + /// prevent calls to [`wait`] from returning until the task is dropped. + /// + /// The task is removed from the collection when it is dropped, not when [`poll`] returns + /// [`Poll::Ready`]. + /// + /// # Examples + /// + /// Track a future spawned with [`tokio::spawn`]. + /// + /// ```ignore + /// # async fn my_async_fn() {} + /// use tokio_util::task::TaskTracker; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let tracker = TaskTracker::new(); + /// + /// tokio::spawn(tracker.track_future(my_async_fn())); + /// # } + /// ``` + /// + /// Track a future spawned on a [`JoinSet`]. 
+    /// ```ignore
+    /// # async fn my_async_fn() {}
+    /// use tokio::task::JoinSet;
+    /// use tokio_util::task::TaskTracker;
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() {
+    /// let tracker = TaskTracker::new();
+    /// let mut join_set = JoinSet::new();
+    ///
+    /// join_set.spawn(tracker.track_future(my_async_fn()));
+    /// # }
+    /// ```
+    ///
+    /// [`JoinSet`]: tokio::task::JoinSet
+    /// [`Poll::Pending`]: std::task::Poll::Pending
+    /// [`poll`]: std::future::Future::poll
+    /// [`wait`]: Self::wait
+    #[inline]
+    pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
+        TrackedFuture {
+            future,
+            token: self.token(),
+        }
+    }
+
+    /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
+    ///
+    /// This token is a lower-level utility than the spawn methods. Each token is considered to
+    /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
+    /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
+    ///
+    /// Dropping the token indicates to the `TaskTracker` that the task has exited.
+    ///
+    /// [`len`]: TaskTracker::len
+    #[inline]
+    pub fn token(&self) -> TaskTrackerToken {
+        self.inner.add_task();
+        TaskTrackerToken {
+            task_tracker: self.clone(),
+        }
+    }
+
+    /// Returns `true` if both task trackers correspond to the same set of tasks.
+    ///
+    /// # Examples
+    ///
+    /// ```ignore
+    /// use tokio_util::task::TaskTracker;
+    ///
+    /// let tracker_1 = TaskTracker::new();
+    /// let tracker_2 = TaskTracker::new();
+    /// let tracker_1_clone = tracker_1.clone();
+    ///
+    /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
+    /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
+    /// ```
+    #[inline]
+    #[must_use]
+    pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
+        Arc::ptr_eq(&left.inner, &right.inner)
+    }
+}
+
+impl Default for TaskTracker {
+    /// Creates a new `TaskTracker`.
+    ///
+    /// The `TaskTracker` will start out as open.
+    #[inline]
+    fn default() -> TaskTracker {
+        TaskTracker::new()
+    }
+}
+
+impl Clone for TaskTracker {
+    /// Returns a new `TaskTracker` that tracks the same set of tasks.
+    ///
+    /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
+    /// all other clones.
+    ///
+    /// # Examples
+    ///
+    /// ```ignore
+    /// use tokio_util::task::TaskTracker;
+    ///
+    /// #[tokio::main]
+    /// # async fn _hidden() {}
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// async fn main() {
+    ///     let tracker = TaskTracker::new();
+    ///     let cloned = tracker.clone();
+    ///
+    ///     // Spawns on `tracker` are visible in `cloned`.
+    ///     tracker.spawn(std::future::pending::<()>());
+    ///     assert_eq!(cloned.len(), 1);
+    ///
+    ///     // Spawns on `cloned` are visible in `tracker`.
+    ///     cloned.spawn(std::future::pending::<()>());
+    ///     assert_eq!(tracker.len(), 2);
+    ///
+    ///     // Calling `close` is visible to `cloned`.
+    ///     tracker.close();
+    ///     assert!(cloned.is_closed());
+    ///
+    ///     // Calling `reopen` is visible to `tracker`.
+    ///     cloned.reopen();
+    ///     assert!(!tracker.is_closed());
+    /// }
+    /// ```
+    #[inline]
+    fn clone(&self) -> TaskTracker {
+        Self {
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    let state = inner.state.load(Ordering::Acquire);
+    let is_closed = (state & 1) != 0;
+    let len = state >> 1;
+
+    f.debug_struct("TaskTracker")
+        .field("len", &len)
+        .field("is_closed", &is_closed)
+        .field("inner", &std::ptr::from_ref::<TaskTrackerInner>(inner))
+        .finish()
+}
+
+impl fmt::Debug for TaskTracker {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        debug_inner(&self.inner, f)
+    }
+}
+
+impl TaskTrackerToken {
+    /// Returns the [`TaskTracker`] that this token is associated with.
+    #[inline]
+    #[must_use]
+    pub fn task_tracker(&self) -> &TaskTracker {
+        &self.task_tracker
+    }
+}
+
+impl Clone for TaskTrackerToken {
+    /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
+    ///
+    /// This is equivalent to `token.task_tracker().token()`.
+    #[inline]
+    fn clone(&self) -> TaskTrackerToken {
+        self.task_tracker.token()
+    }
+}
+
+impl Drop for TaskTrackerToken {
+    /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
+    #[inline]
+    fn drop(&mut self) {
+        self.task_tracker.inner.drop_task();
+    }
+}
+
+impl<F: Future> Future for TrackedFuture<F> {
+    type Output = F::Output;
+
+    #[inline]
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.project().future.poll(cx)
+    }
+}
+
+impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TrackedFuture")
+            .field("future", &self.future)
+            .field("task_tracker", self.token.task_tracker())
+            .finish()
+    }
+}
+
+impl Future for TaskTrackerWaitFuture<'_> {
+    type Output = ();
+
+    #[inline]
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+        let me = self.project();
+
+        let inner = match me.inner.as_ref() {
+            None => return Poll::Ready(()),
+            Some(inner) => inner,
+        };
+
+        let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
+        if ready {
+            *me.inner = None;
+            Poll::Ready(())
+        } else {
+            Poll::Pending
+        }
+    }
+}
+
+impl fmt::Debug for TaskTrackerWaitFuture<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        struct Helper<'a>(&'a TaskTrackerInner);
+
+        impl fmt::Debug for Helper<'_> {
+            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                debug_inner(self.0, f)
+            }
+        }
+
+        f.debug_struct("TaskTrackerWaitFuture")
+            .field("future", &self.future)
+            .field("task_tracker", &self.inner.map(Helper))
+            .finish()
+    }
+}
diff --git a/wrappers/tokio/impls/tokio-util/src/util/maybe_dangling.rs
b/wrappers/tokio/impls/tokio-util/src/util/maybe_dangling.rs
new file mode 100644
index 00000000..c29a0894
--- /dev/null
+++ b/wrappers/tokio/impls/tokio-util/src/util/maybe_dangling.rs
@@ -0,0 +1,67 @@
+use core::future::Future;
+use core::mem::MaybeUninit;
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+/// A wrapper type that tells the compiler that the contents might not be valid.
+///
+/// This is necessary mainly when `T` contains a reference. In that case, the
+/// compiler will sometimes assume that the reference is always valid; in some
+/// cases it will assume this even after the destructor of `T` runs. For
+/// example, when a reference is used as a function argument, then the compiler
+/// will assume that the reference is valid until the function returns, even if
+/// the reference is destroyed during the function. When the reference is used
+/// as part of a self-referential struct, that assumption can be false. Wrapping
+/// the reference in this type prevents the compiler from making that
+/// assumption.
+///
+/// # Invariants
+///
+/// The `MaybeUninit` will always contain a valid value until the destructor runs.
+//
+// Reference
+// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
+//
+// TODO: replace this with an official solution once RFC #3336 or similar is available.
+//
+#[repr(transparent)]
+pub(crate) struct MaybeDangling<T>(MaybeUninit<T>);
+
+impl<T> Drop for MaybeDangling<T> {
+    fn drop(&mut self) {
+        // Safety: `0` is always initialized.
+        unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) };
+    }
+}
+
+impl<T> MaybeDangling<T> {
+    pub(crate) fn new(inner: T) -> Self {
+        Self(MaybeUninit::new(inner))
+    }
+}
+
+impl<F: Future> Future for MaybeDangling<F> {
+    type Output = F::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
+        // Safety: `0` is always initialized.
+ let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) }; + fut.poll(cx) + } +} + +#[test] +fn maybedangling_runs_drop() { + struct SetOnDrop<'a>(&'a mut bool); + + impl Drop for SetOnDrop<'_> { + fn drop(&mut self) { + *self.0 = true; + } + } + + let mut success = false; + + drop(MaybeDangling::new(SetOnDrop(&mut success))); + assert!(success); +} diff --git a/wrappers/tokio/impls/tokio-util/src/util/mod.rs b/wrappers/tokio/impls/tokio-util/src/util/mod.rs new file mode 100644 index 00000000..90e563e6 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/src/util/mod.rs @@ -0,0 +1,3 @@ +mod maybe_dangling; + +pub(crate) use maybe_dangling::MaybeDangling; diff --git a/wrappers/tokio/impls/tokio-util/tests/sync_cancellation_token.rs b/wrappers/tokio/impls/tokio-util/tests/sync_cancellation_token.rs new file mode 100644 index 00000000..59bdaa29 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/tests/sync_cancellation_token.rs @@ -0,0 +1,450 @@ +#![warn(rust_2018_idioms)] + +use shuttle_tokio_util_impl::sync::{CancellationToken, WaitForCancellationFuture}; +use tokio::pin; +use tokio::sync::oneshot; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[tokio::test] +async fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!(Poll::Pending, wait_fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[tokio::test] +async fn cancel_token_owned() { + let (waker, 
wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.clone().cancelled_owned(); + pin!(wait_fut); + + assert_eq!(Poll::Pending, wait_fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.clone().cancelled_owned(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[tokio::test] +async fn cancel_token_owned_drop_test() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let future = token.cancelled_owned(); + pin!(future); + + assert_eq!(Poll::Pending, future.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!(wake_counter, 0); + + // let future be dropped while pinned and under pending state to + // find potential memory related bugs. 
+} + +#[tokio::test] +async fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!(Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[tokio::test] +async fn cancel_grandchild_token_through_parent_if_child_was_dropped() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let intermediate_token = token.child_token(); + let child_token = intermediate_token.child_token(); + drop(intermediate_token); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!(Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[tokio::test] +async fn cancel_child_token_without_parent() { + let 
(waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!(Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert!(!token.is_cancelled()); + assert!(child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert!(token.is_cancelled()); + assert!(child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[tokio::test] +async fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().copied() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + 
assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[tokio::test] +async fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for child in &mut child_tokens { + *child = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[tokio::test] +async fn cancel_only_all_descendants() { + // ARRANGE + let (waker, wake_counter) = new_count_waker(); + + let parent_token = CancellationToken::new(); + let token = parent_token.child_token(); + let sibling_token = parent_token.child_token(); + let child1_token = token.child_token(); + let child2_token = token.child_token(); + let grandchild_token = child1_token.child_token(); + let grandchild2_token = child1_token.child_token(); + let great_grandchild_token = grandchild_token.child_token(); + + assert!(!parent_token.is_cancelled()); + assert!(!token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(!child1_token.is_cancelled()); + assert!(!child2_token.is_cancelled()); + assert!(!grandchild_token.is_cancelled()); + assert!(!grandchild2_token.is_cancelled()); + assert!(!great_grandchild_token.is_cancelled()); + + let parent_fut = parent_token.cancelled(); + let fut = token.cancelled(); + let sibling_fut = sibling_token.cancelled(); + let child1_fut = child1_token.cancelled(); + let child2_fut = child2_token.cancelled(); + let grandchild_fut = grandchild_token.cancelled(); + let grandchild2_fut = 
grandchild2_token.cancelled(); + let great_grandchild_fut = great_grandchild_token.cancelled(); + + pin!(parent_fut); + pin!(fut); + pin!(sibling_fut); + pin!(child1_fut); + pin!(child2_fut); + pin!(grandchild_fut); + pin!(grandchild2_fut); + pin!(great_grandchild_fut); + + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(Poll::Pending, fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!( + Poll::Pending, + sibling_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + great_grandchild_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + // ACT + token.cancel(); + + // ASSERT + assert_eq!(wake_counter, 6); + assert!(!parent_token.is_cancelled()); + assert!(token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(child1_token.is_cancelled()); + assert!(child2_token.is_cancelled()); + assert!(grandchild_token.is_cancelled()); + assert!(grandchild2_token.is_cancelled()); + assert!(great_grandchild_token.is_cancelled()); + + assert_eq!(Poll::Ready(()), fut.as_mut().poll(&mut Context::from_waker(&waker))); + assert_eq!( + Poll::Ready(()), + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( 
+ Poll::Ready(()), + great_grandchild_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 6); +} + +#[tokio::test] +async fn drop_parent_before_child_tokens() { + let token = CancellationToken::new(); + let child1 = token.child_token(); + let child2 = token.child_token(); + + drop(token); + assert!(!child1.is_cancelled()); + + drop(child1); + drop(child2); +} + +#[tokio::test] +async fn derives_send_sync() { + fn assert_send() {} + fn assert_sync() {} + + assert_send::(); + assert_sync::(); + + assert_send::>(); + assert_sync::>(); +} + +#[tokio::test] +async fn run_until_cancelled_test() { + let (waker, _) = new_count_waker(); + + { + let token = CancellationToken::new(); + + let fut = token.run_until_cancelled(std::future::pending::<()>()); + pin!(fut); + + assert_eq!(Poll::Pending, fut.as_mut().poll(&mut Context::from_waker(&waker))); + + token.cancel(); + + assert_eq!(Poll::Ready(None), fut.as_mut().poll(&mut Context::from_waker(&waker))); + } + + { + let (tx, rx) = oneshot::channel::<()>(); + + let token = CancellationToken::new(); + let fut = token.run_until_cancelled(async move { + rx.await.unwrap(); + 42 + }); + pin!(fut); + + assert_eq!(Poll::Pending, fut.as_mut().poll(&mut Context::from_waker(&waker))); + + tx.send(()).unwrap(); + + assert_eq!( + Poll::Ready(Some(42)), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + } +} diff --git a/wrappers/tokio/impls/tokio-util/tests/task_tracker.rs b/wrappers/tokio/impls/tokio-util/tests/task_tracker.rs new file mode 100644 index 00000000..b818fcd2 --- /dev/null +++ b/wrappers/tokio/impls/tokio-util/tests/task_tracker.rs @@ -0,0 +1,178 @@ +#![warn(rust_2018_idioms)] + +use shuttle_tokio_util_impl::task::TaskTracker; +use tokio_test::{assert_pending, assert_ready, task}; + +#[tokio::test] +async fn open_close() { + let tracker = TaskTracker::new(); + assert!(!tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + 
assert!(tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.reopen(); + assert!(!tracker.is_closed()); + tracker.reopen(); + assert!(!tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + assert!(tracker.is_closed()); + tracker.close(); + assert!(tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); +} + +#[tokio::test] +async fn token_len() { + let tracker = TaskTracker::new(); + + let mut tokens = Vec::new(); + for i in 0..10 { + assert_eq!(tracker.len(), i); + tokens.push(tracker.token()); + } + + assert!(!tracker.is_empty()); + assert_eq!(tracker.len(), 10); + + for (i, token) in tokens.into_iter().enumerate() { + drop(token); + assert_eq!(tracker.len(), 9 - i); + } +} + +#[tokio::test] +async fn notify_immediately() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + assert_ready!(wait.poll()); +} + +#[tokio::test] +async fn notify_immediately_on_reopen() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[tokio::test] +async fn notify_on_close() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + assert_ready!(wait.poll()); +} + +#[tokio::test] +async fn notify_on_close_reopen() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[tokio::test] +async fn notify_on_last_task() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + assert_ready!(wait.poll()); +} + +#[tokio::test] +async fn 
notify_on_last_task_respawn() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_ready!(wait.poll()); + drop(token2); +} + +#[tokio::test] +async fn no_notify_on_respawn_if_open() { + let tracker = TaskTracker::new(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_pending!(wait.poll()); + drop(token2); +} + +#[tokio::test] +async fn close_during_exit() { + const ITERS: usize = 5; + + for close_spot in 0..=ITERS { + let tracker = TaskTracker::new(); + let tokens: Vec<_> = (0..ITERS).map(|_| tracker.token()).collect(); + + let mut wait = task::spawn(tracker.wait()); + + for (i, token) in tokens.into_iter().enumerate() { + assert_pending!(wait.poll()); + if i == close_spot { + tracker.close(); + assert_pending!(wait.poll()); + } + drop(token); + } + + if close_spot == ITERS { + assert_pending!(wait.poll()); + tracker.close(); + } + + assert_ready!(wait.poll()); + } +} + +#[tokio::test] +async fn notify_many() { + let tracker = TaskTracker::new(); + + let mut waits: Vec<_> = (0..10).map(|_| task::spawn(tracker.wait())).collect(); + + for wait in &mut waits { + assert_pending!(wait.poll()); + } + + tracker.close(); + + for wait in &mut waits { + assert_ready!(wait.poll()); + } +} diff --git a/wrappers/tokio/impls/tokio/Cargo.toml b/wrappers/tokio/impls/tokio/Cargo.toml new file mode 100644 index 00000000..d98124b7 --- /dev/null +++ b/wrappers/tokio/impls/tokio/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "shuttle-tokio-impl" +version = "0.1.0" +edition = "2021" + +[features] +# Include nothing by default +default = [] + +# enable everything +full = ["tokio-orig/full", "shuttle-tokio-impl-inner/full"] + +fs = ["tokio-orig/fs", "shuttle-tokio-impl-inner/fs"] +io-util = 
["tokio-orig/io-util", "shuttle-tokio-impl-inner/io-util"] +# stdin, stdout, stderr +io-std = ["tokio-orig/io-std", "shuttle-tokio-impl-inner/io-std"] +macros = ["tokio-orig/macros", "shuttle-tokio-impl-inner/macros"] +net = ["tokio-orig/net", "shuttle-tokio-impl-inner/net"] +process = ["tokio-orig/process", "shuttle-tokio-impl-inner/process"] +# Includes basic task execution capabilities +rt = ["tokio-orig/rt", "shuttle-tokio-impl-inner/rt"] +rt-multi-thread = ["tokio-orig/rt-multi-thread", "shuttle-tokio-impl-inner/rt-multi-thread"] +signal = ["tokio-orig/signal", "shuttle-tokio-impl-inner/signal"] +sync = ["tokio-orig/sync", "shuttle-tokio-impl-inner/sync"] +test-util = ["tokio-orig/test-util", "shuttle-tokio-impl-inner/test-util"] +time = ["tokio-orig/time", "shuttle-tokio-impl-inner/time"] +# Unstable feature. Requires `--cfg tokio_unstable` to enable. +io-uring = ["tokio-orig/io-uring", "shuttle-tokio-impl-inner/io-uring"] +# Unstable feature. Requires `--cfg tokio_unstable` to enable. +taskdump = ["tokio-orig/taskdump", "shuttle-tokio-impl-inner/taskdump"] + +[dependencies] +cfg-if = "1.0" +tokio-orig = { package = "tokio", version = "1.43" } +shuttle-tokio-impl-inner = { version = "0.1.0", path = "./inner" } + +[dev-dependencies] +shuttle = { version = "*", path = "../../../../shuttle" } +futures = "0.3.15" diff --git a/wrappers/tokio/impls/tokio/inner/Cargo.toml b/wrappers/tokio/impls/tokio/inner/Cargo.toml new file mode 100644 index 00000000..0cb4e02b --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "shuttle-tokio-impl-inner" +version = "0.1.0" +edition = "2021" + +[features] +# Include nothing by default +default = [] + +# enable everything +full = [] + +fs = [] +io-util = [] +# stdin, stdout, stderr +io-std = [] +macros = [] +net = [] +process = [] +# Includes basic task execution capabilities +rt = [] +rt-multi-thread = [] +signal = [] +sync = [] +test-util = [] +time = [] +# Unstable feature. 
Requires `--cfg tokio_unstable` to enable. +io-uring = [] +# Unstable feature. Requires `--cfg tokio_unstable` to enable. +taskdump = [] + +[dependencies] +criterion = { version = "0.5", features = [ + "async", + "async_tokio", + "html_reports", +] } +futures = "0.3.24" +pin-project = "1.0.12" +regex = "1.10.1" +shuttle = { version = "*", path = "../../../../../shuttle" } +smallvec = "1.6.1" +test-log = { version = "0.2.11", default-features = false, features = [ + "trace", +] } +tracing = { version = "0.1.21", default-features = false, features = ["std"] } +tracing-subscriber = { version = "0.3.9", features = ["env-filter"] } +tokio = { version = "1.42", features = ["macros", "time", "sync"] } +tokio-macros = { package = "shuttle-tokio-macros-impl", path = "../../tokio-macros", version = "*" } + +[dev-dependencies] +assert_matches = "1.5" diff --git a/wrappers/tokio/impls/tokio/inner/src/lib.rs b/wrappers/tokio/impls/tokio/inner/src/lib.rs new file mode 100644 index 00000000..f556c2c6 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/lib.rs @@ -0,0 +1,197 @@ +//! This is the "impl-inner" crate implementing [tokio] support for [`Shuttle`]. +//! This crate should not be depended on directly, the intended way to use this crate is via +//! the `shuttle-tokio` crate and feature flag `shuttle`. +//! +//! [`Shuttle`]: +//! +//! [`tokio`]: + +use shuttle::scheduler::{PctScheduler, RandomScheduler}; +use shuttle::{PortfolioRunner, Runner}; +use std::panic; +use tracing_subscriber::util::SubscriberInitExt as _; + +pub mod runtime; +pub mod sync; +pub mod task; +pub mod time; + +pub use task::spawn; + +pub use tokio::pin; + +/// Implementation detail of the `select!` macro. This macro is **not** +/// intended to be used as part of the public API and is permitted to +/// change. +#[doc(hidden)] +pub use tokio_macros::select_priv_declare_output_enum; + +/// Implementation detail of the `select!` macro. 
This macro is **not** +/// intended to be used as part of the public API and is permitted to +/// change. +#[doc(hidden)] +pub use tokio_macros::select_priv_clean_pattern; + +#[macro_use] +#[doc(hidden)] +pub mod macros; + +// TODO: Finish deprecation and move out of this crate +#[deprecated = "`default_shuttle_config` is going to be moved out of shuttle-tokio-impl-inner. Removing it will not be treated as a breaking change."] +#[doc(hidden)] +pub fn default_shuttle_config() -> shuttle::Config { + let mut config = shuttle::Config::new(); + config.stack_size = 0x0008_0000; + config.max_steps = shuttle::MaxSteps::FailAfter(10_000_000); + config +} + +// Default for the `test` macro expansions +#[doc(hidden)] +pub fn __default_shuttle_config() -> shuttle::Config { + let mut config = shuttle::Config::new(); + config.stack_size = 0x0008_0000; + config.max_steps = shuttle::MaxSteps::FailAfter(10_000_000); + config +} + +// This exists so that the test macro can expand to a call to this. +#[doc(hidden)] +pub fn __check(f: F, config: shuttle::Config, max_iterations: usize) +where + F: Fn() + Send + Sync + 'static, +{ + #[allow(deprecated)] + check(f, config, max_iterations) +} + +/// Helper function that allows failing tests to be easily replayed using Shuttle. +/// +/// Overall workflow: +/// 1. Run the failing test under Shuttle to generate a replayable schedule in a specific directory +/// (specified by the environment variable `SHUTTLE_TRACE_DIR`): +/// `$ env SHUTTLE_TRACE_DIR=./mydir cargo test --release --features shuttle-tests -- ` +/// This will generate a schedule file ./mydir/schedule000.txt +/// 2. 
To replay the failure, rerun the test with the environment variable `SHUTTLE_TRACE_FILE` set to +/// the schedule file: +/// `$ env RUST_BACKTRACE=1 SHUTTLE_TRACE_FILE=./mydir/schedule000.txt cargo test --release --features shuttle-tests -- ` +/// +/// Alternatively, if you already have a failing schedule string printed out by Shuttle (e.g., from +/// a dry-run failure) save the schedule string to a file, and run step 2 above with +/// `SHUTTLE_TRACE_FILE` pointing to that file. +/// +/// This function also initializes a tracing subscriber, if none has been set yet. +/// +/// The following environment variables influence Shuttle execution: +/// - `SHUTTLE_ITERATIONS` sets the number of iterations to run, overriding whatever default is set +/// in the test itself. +/// - `SHUTTLE_TIMEOUT_SECS` sets a time limit (in seconds) to run each test for. If set this will +/// ignore both `SHUTTLE_ITERATIONS` and the test's default iteration count, and instead run the +/// test as many times as possible within the given timeout. +/// - `SHUTTLE_PCT_MAX_DEPTH` sets the maximum depth parameter of the PCT scheduler, overriding the +/// default value 3. +/// - `SHUTTLE_SCHEDULER` sets the scheduler that Shuttle uses. The value can be either `PCT` (for the +/// PCT scheduler), or `PORTFOLIO` (which runs PCT and random in parallel). Any other value (or if +/// the variable is not defined) causes Shuttle to use the random scheduler. +/// - `SHUTTLE_INTERVAL_TICKS` sets the max number of ticks each `Interval` generates. If the variable +/// not defined, the value defaults to `usize::MAX` (essentially, each Interval will generate ticks +/// forever). Setting the value to 0 means Intervals don't generate any ticks. +/// - `SHUTTLE_HIDE_TRACE` initializes a tracing subscriber that swallows everything. This is useful +/// for code that accesses synchronization primitives in tracing statements, which causes schedules +/// to not replay across verbosity levels. 
The idea is that we can run randomized tests with something
+/// like `RUST_LOG=trace SHUTTLE_HIDE_TRACE=true` without drowning in log messages. A failing schedule
+/// can then still be replayed with `RUST_LOG=trace`, giving access to all log messages for debugging.
+// TODO: Remove. It's around due to the volume of code still using it and due to `shuttle_tokio::test` being built on it.
+#[deprecated = "`check` is going to be moved out of shuttle-tokio-impl-inner. Removing it will not be treated as a breaking change."]
+#[doc(hidden)]
+pub fn check<F>(f: F, mut config: shuttle::Config, max_iterations: usize)
+where
+    F: Fn() + Send + Sync + 'static,
+{
+    match std::env::var("SHUTTLE_HIDE_TRACE").as_ref().map(String::as_str) {
+        Ok("true" | "1") => {
+            _ = tracing_subscriber::fmt::Subscriber::builder()
+                .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+                .with_writer(std::io::sink)
+                .finish()
+                .try_init();
+        }
+        _ => _ = tracing_subscriber::fmt::try_init(),
+    }
+
+    if let Ok(path) = std::env::var("SHUTTLE_TRACE_FILE") {
+        // Don't spew the schedule out; it's already in a file!
+        config.failure_persistence = shuttle::FailurePersistence::None;
+
+        let scheduler = shuttle::scheduler::ReplayScheduler::new_from_file(path).expect("could not read schedule file");
+        let runner = shuttle::Runner::new(scheduler, config);
+        runner.run(f);
+    } else {
+        if let Ok(path) = std::env::var("SHUTTLE_TRACE_DIR") {
+            config.failure_persistence = shuttle::FailurePersistence::File(Some(std::path::PathBuf::from(path)));
+        }
+
+        let max_iterations = if let Some(timeout) = std::env::var("SHUTTLE_TIMEOUT_SECS")
+            .ok()
+            .and_then(|v| v.parse::<u64>().ok())
+        {
+            config.max_time = Some(std::time::Duration::from_secs(timeout));
+            usize::MAX
+        } else {
+            std::env::var("SHUTTLE_ITERATIONS")
+                .ok()
+                .and_then(|v| v.parse::<usize>().ok())
+                .unwrap_or(max_iterations)
+        };
+
+        let max_depth = std::env::var("SHUTTLE_PCT_MAX_DEPTH")
+            .ok()
+            .and_then(|v| v.parse::<usize>().ok())
+            .unwrap_or(3);
+
+        match std::env::var("SHUTTLE_SCHEDULER") {
+            Ok(s) if s == "PORTFOLIO" => {
+                let mut runner = PortfolioRunner::new(true, config);
+                runner.add(RandomScheduler::new(max_iterations));
+                runner.add(PctScheduler::new(max_depth, max_iterations));
+                runner.run(f);
+            }
+            Ok(s) if s == "PCT" => {
+                let scheduler = PctScheduler::new(max_depth, max_iterations);
+                std::thread::spawn(|| Runner::new(scheduler, config).run(f))
+                    .join()
+                    .unwrap_or_else(|e| panic::resume_unwind(e));
+            }
+            _ => {
+                let scheduler = RandomScheduler::new(max_iterations);
+                std::thread::spawn(|| Runner::new(scheduler, config).run(f))
+                    .join()
+                    .unwrap_or_else(|e| panic::resume_unwind(e));
+            }
+        }
+    }
+}
+
+/// Run the given function under a scheduler that checks whether the function
+/// contains randomness which is not controlled by Shuttle.
+/// Each iteration will check a different random schedule and replay that schedule once.
+#[deprecated = "`check` is going to be moved out of shuttle-tokio-impl-inner.
Removing it will not be treated as a breaking change."]
+#[doc(hidden)]
+pub fn check_for_uncontrolled_nondeterminism<F>(f: F, config: shuttle::Config, max_iterations: usize)
+where
+    F: Fn() + Send + Sync + 'static,
+{
+    use shuttle::scheduler::UncontrolledNondeterminismCheckScheduler;
+    let _ = tracing_subscriber::fmt::try_init();
+
+    let scheduler = UncontrolledNondeterminismCheckScheduler::new(RandomScheduler::new(max_iterations));
+
+    std::thread::spawn(|| Runner::new(scheduler, config).run(f))
+        .join()
+        .unwrap_or_else(|e| panic::resume_unwind(e));
+}
+
+// Note that we are just pubbing `test`, and not `main`. `main` and `test` share the same code path, so
+// both should work, but `main` has not been tested, and we have also yet to experience a need for running
+// the `main` function under Shuttle.
+pub use tokio_macros::test;
diff --git a/wrappers/tokio/impls/tokio/inner/src/macros/mod.rs b/wrappers/tokio/impls/tokio/inner/src/macros/mod.rs
new file mode 100644
index 00000000..570ac13e
--- /dev/null
+++ b/wrappers/tokio/impls/tokio/inner/src/macros/mod.rs
@@ -0,0 +1,6 @@
+#[macro_use]
+mod select;
+
+// Includes re-exports needed to implement macros
+#[doc(hidden)]
+pub mod support;
diff --git a/wrappers/tokio/impls/tokio/inner/src/macros/select.rs b/wrappers/tokio/impls/tokio/inner/src/macros/select.rs
new file mode 100644
index 00000000..7d09f0a9
--- /dev/null
+++ b/wrappers/tokio/impls/tokio/inner/src/macros/select.rs
@@ -0,0 +1,852 @@
+//! This file is lifted from [tokio/src/macros/select.rs](https://github.com/tokio-rs/tokio/blob/911ab21d7012a50e53971ad1292a9f18de22d4c8/tokio/src/macros/select.rs),
+//! and has had the following changes applied to it:
+//! 1. Documentation (what is gated behind `doc!`) has been removed
+//! 2. `poll_budget_available` is changed to always be `Poll::Ready`, as we don't model cooperative scheduling.
+//! 3. `thread_rng_n` is changed to be Shuttle-compatible.
+
+#[macro_export]
+macro_rules!
select { + // Uses a declarative macro to do **most** of the work. While it is possible + // to implement fully with a declarative macro, a procedural macro is used + // to enable improved error messages. + // + // The macro is structured as a tt-muncher. All branches are processed and + // normalized. Once the input is normalized, it is passed to the top-most + // rule. When entering the macro, `@{ }` is inserted at the front. This is + // used to collect the normalized input. + // + // The macro only recurses once per branch. This allows using `select!` + // without requiring the user to increase the recursion limit. + + // All input is normalized, now transform. + (@ { + // The index of the future to poll first (in bias mode), or the RNG + // expression to use to pick a future to poll first. + start=$start:expr; + + // One `_` for each branch in the `select!` macro. Passing this to + // `count!` converts $skip to an integer. + ( $($count:tt)* ) + + // Normalized select branches. `( $skip )` is a set of `_` characters. + // There is one `_` for each select branch **before** this one. Given + // that all input futures are stored in a tuple, $skip is useful for + // generating a pattern to reference the future for the current branch. + // $skip is also used as an argument to `count!`, returning the index of + // the current select branch. + $( ( $($skip:tt)* ) $bind:pat = $fut:expr, if $c:expr => $handle:expr, )+ + + // Fallback expression used when all select branches have been disabled. + ; $else:expr + + }) => {{ + // Enter a context where stable "function-like" proc macros can be used. + // + // This module is defined within a scope and should not leak out of this + // macro. 
+ #[doc(hidden)] + mod __tokio_select_util { + // Generate an enum with one variant per select branch + $crate::select_priv_declare_output_enum!( ( $($count)* ) ); + } + + // `tokio::macros::support` is a public, but doc(hidden) module + // including a re-export of all types needed by this macro. + use $crate::macros::support::Future; + use $crate::macros::support::Pin; + use $crate::macros::support::Poll::{Ready, Pending}; + + const BRANCHES: u32 = $crate::count!( $($count)* ); + + let mut disabled: __tokio_select_util::Mask = Default::default(); + + // First, invoke all the pre-conditions. For any that return true, + // set the appropriate bit in `disabled`. + $( + if !$c { + let mask: __tokio_select_util::Mask = 1 << $crate::count!( $($skip)* ); + disabled |= mask; + } + )* + + // Create a scope to separate polling from handling the output. This + // adds borrow checker flexibility when using the macro. + let mut output = { + // Store each future directly first (that is, without wrapping the future in a call to + // `IntoFuture::into_future`). This allows the `$fut` expression to make use of + // temporary lifetime extension. + // + // https://doc.rust-lang.org/1.58.1/reference/destructors.html#temporary-lifetime-extension + let futures_init = ($( $fut, )+); + + // Safety: Nothing must be moved out of `futures`. This is to + // satisfy the requirement of `Pin::new_unchecked` called below. + // + // We can't use the `pin!` macro for this because `futures` is a + // tuple and the standard library provides no way to pin-project to + // the fields of a tuple. + let mut futures = ($( $crate::macros::support::IntoFuture::into_future( + $crate::count_field!( futures_init.$($skip)* ) + ),)+); + + // This assignment makes sure that the `poll_fn` closure only has a + // reference to the futures, instead of taking ownership of them. 
+ // This mitigates the issue described in + // + let mut futures = &mut futures; + + $crate::macros::support::poll_fn(|cx| { + // Return `Pending` when the task budget is depleted since budget-aware futures + // are going to yield anyway and other futures will not cooperate. + + // SHUTTLE_CHANGES: The line is the same, but `poll_budget_available` always resolves to `Poll::Ready` + ::std::task::ready!($crate::macros::support::poll_budget_available(cx)); + + // Track if any branch returns pending. If no branch completes + // **or** returns pending, this implies that all branches are + // disabled. + let mut is_pending = false; + + // Choose a starting index to begin polling the futures at. In + // practice, this will either be a pseudo-randomly generated + // number by default, or the constant 0 if `biased;` is + // supplied. + let start = $start; + + for i in 0..BRANCHES { + let branch; + #[allow(clippy::modulo_one)] + { + branch = (start + i) % BRANCHES; + } + match branch { + $( + #[allow(unreachable_code)] + $crate::count!( $($skip)* ) => { + // First, if the future has previously been + // disabled, do not poll it again. This is done + // by checking the associated bit in the + // `disabled` bit field. + let mask = 1 << branch; + + if disabled & mask == mask { + // The future has been disabled. + continue; + } + + // Extract the future for this branch from the + // tuple + let ( $($skip,)* fut, .. ) = &mut *futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling it + let out = match Future::poll(fut, cx) { + Ready(out) => out, + Pending => { + // Track that at least one future is + // still pending and continue polling. + is_pending = true; + continue; + } + }; + + // Disable the future from future polling. + disabled |= mask; + + // The future returned a value, check if matches + // the specified pattern. 
+ #[allow(unused_variables)] + #[allow(unused_mut)] + match &out { + $crate::select_priv_clean_pattern!($bind) => {} + _ => continue, + } + + // The select is complete, return the value + return Ready($crate::select_variant!(__tokio_select_util::Out, ($($skip)*))(out)); + } + )* + _ => unreachable!("reaching this means there probably is an off by one bug"), + } + } + + if is_pending { + Pending + } else { + // All branches have been disabled. + Ready(__tokio_select_util::Out::Disabled) + } + }).await + }; + + match output { + $( + $crate::select_variant!(__tokio_select_util::Out, ($($skip)*) ($bind)) => $handle, + )* + __tokio_select_util::Out::Disabled => $else, + _ => unreachable!("failed to match bind"), + } + }}; + + // ==== Normalize ===== + + // These rules match a single `select!` branch and normalize it for + // processing by the first rule. + + (@ { start=$start:expr; $($t:tt)* } ) => { + // No `else` branch + $crate::select!(@{ start=$start; $($t)*; panic!("all branches are disabled and there is no else branch") }) + }; + (@ { start=$start:expr; $($t:tt)* } else => $else:expr $(,)?) 
=> { + $crate::select!(@{ start=$start; $($t)*; $else }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:block, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:block $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:expr ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:expr, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + + // ===== Entry point ===== + + ($(biased;)? else => $else:expr $(,)? ) => {{ + $else + }}; + + (biased; $p:pat = $($t:tt)* ) => { + $crate::select!(@{ start=0; () } $p = $($t)*) + }; + + ( $p:pat = $($t:tt)* ) => { + // Randomly generate a starting point. 
This makes `select!` a bit more + // fair and avoids always polling the first future. + + // SHUTTLE_CHANGHES: `thread_rng_n` is changed to be Shuttle-compatible + $crate::select!(@{ start={ $crate::macros::support::thread_rng_n(BRANCHES) }; () } $p = $($t)*) + }; + + () => { + compile_error!("select! requires at least one branch.") + }; +} + +// And here... we manually list out matches for up to 64 branches... I'm not +// happy about it either, but this is how we manage to use a declarative macro! + +#[macro_export] +#[doc(hidden)] +macro_rules! count { + () => { + 0 + }; + (_) => { + 1 + }; + (_ _) => { + 2 + }; + (_ _ _) => { + 3 + }; + (_ _ _ _) => { + 4 + }; + (_ _ _ _ _) => { + 5 + }; + (_ _ _ _ _ _) => { + 6 + }; + (_ _ _ _ _ _ _) => { + 7 + }; + (_ _ _ _ _ _ _ _) => { + 8 + }; + (_ _ _ _ _ _ _ _ _) => { + 9 + }; + (_ _ _ _ _ _ _ _ _ _) => { + 10 + }; + (_ _ _ _ _ _ _ _ _ _ _) => { + 11 + }; + (_ _ _ _ _ _ _ _ _ _ _ _) => { + 12 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _) => { + 13 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 14 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 15 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 16 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 17 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 18 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 19 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 20 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 21 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 22 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 23 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 24 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 25 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 26 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 27 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 28 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 29 + }; + (_ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 30 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 31 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 32 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 33 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 34 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 35 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 36 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 37 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 38 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 39 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 40 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 41 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 42 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 43 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 44 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 45 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 46 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 47 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 48 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 49 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 50 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 51 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 52 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 53 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 54 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 55 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 56 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 57 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 58 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 59 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 60 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 61 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 62 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 63 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 64 + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! 
count_field { + ($var:ident. ) => { + $var.0 + }; + ($var:ident. _) => { + $var.1 + }; + ($var:ident. _ _) => { + $var.2 + }; + ($var:ident. _ _ _) => { + $var.3 + }; + ($var:ident. _ _ _ _) => { + $var.4 + }; + ($var:ident. _ _ _ _ _) => { + $var.5 + }; + ($var:ident. _ _ _ _ _ _) => { + $var.6 + }; + ($var:ident. _ _ _ _ _ _ _) => { + $var.7 + }; + ($var:ident. _ _ _ _ _ _ _ _) => { + $var.8 + }; + ($var:ident. _ _ _ _ _ _ _ _ _) => { + $var.9 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _) => { + $var.10 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _) => { + $var.11 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.12 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.13 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.14 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.15 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.16 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.17 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.18 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.19 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.20 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.21 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.22 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.23 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.24 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.25 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.26 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.27 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.28 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.29 + }; + ($var:ident. 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.30 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.31 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.32 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.33 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.34 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.35 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.36 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.37 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.38 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.39 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.40 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.41 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.42 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.43 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.44 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.45 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.46 + }; + ($var:ident. 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.47 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.48 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.49 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.50 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.51 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.52 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.53 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.54 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.55 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.56 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.57 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.58 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.59 + }; + ($var:ident. 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.60 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.61 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.62 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.63 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.64 + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! select_variant { + ($($p:ident)::*, () $($t:tt)*) => { + $($p)::*::_0 $($t)* + }; + ($($p:ident)::*, (_) $($t:tt)*) => { + $($p)::*::_1 $($t)* + }; + ($($p:ident)::*, (_ _) $($t:tt)*) => { + $($p)::*::_2 $($t)* + }; + ($($p:ident)::*, (_ _ _) $($t:tt)*) => { + $($p)::*::_3 $($t)* + }; + ($($p:ident)::*, (_ _ _ _) $($t:tt)*) => { + $($p)::*::_4 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _) $($t:tt)*) => { + $($p)::*::_5 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_6 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_7 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_8 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_9 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_10 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_11 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_12 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_13 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => 
{ + $($p)::*::_14 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_15 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_16 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_17 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_18 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_19 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_20 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_21 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_22 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_23 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_24 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_25 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_26 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_27 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_28 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_29 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_30 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_31 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_) $($t:tt)*) => { + $($p)::*::_32 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_33 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_34 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_35 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_36 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_37 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_38 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_39 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_40 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_41 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_42 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_43 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_44 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_45 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_46 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_47 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_48 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_49 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_50 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_51 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_52 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_53 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_54 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_55 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_56 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_57 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_58 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_59 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_60 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_61 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_62 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_63 $($t)* + }; +} diff --git a/wrappers/tokio/impls/tokio/inner/src/macros/support.rs b/wrappers/tokio/impls/tokio/inner/src/macros/support.rs new file mode 100644 index 00000000..6b2eb154 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/macros/support.rs @@ -0,0 +1,19 @@ +pub use std::future::poll_fn; + +// SHUTTLE_CHANGES: Changed to use rand from Shuttle +#[doc(hidden)] +pub fn thread_rng_n(n: u32) -> u32 { + use shuttle::rand::RngCore; + shuttle::rand::thread_rng().next_u32() % n +} + +// SHUTTLE_CHANGES: Changed to always be `Poll::Ready(())` (ie. same as Tokio when coop is not enabled). We don't model cooperative scheduling. 
+#[doc(hidden)] +#[inline] +pub fn poll_budget_available(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) +} + +pub use std::future::{Future, IntoFuture}; +pub use std::pin::Pin; +pub use std::task::{Context, Poll}; diff --git a/wrappers/tokio/impls/tokio/inner/src/runtime.rs b/wrappers/tokio/impls/tokio/inner/src/runtime.rs new file mode 100644 index 00000000..5291f7a5 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/runtime.rs @@ -0,0 +1,392 @@ +//! Shuttle stubs for the tokio runtime + +use crate::runtime::dump::Dump; +use crate::task::JoinHandle; +use criterion::async_executor::AsyncExecutor; +use std::fmt; +use std::future::Future; +use std::num::{NonZeroU32, NonZeroU64}; +use std::ops::Range; +use std::time::Duration; + +/// Runtime which is used to spawn and run async tasks +pub struct Runtime { + handle: Handle, +} + +impl Runtime { + /// Create a new Runtime with default configuration + pub fn new() -> std::io::Result<Runtime> { + Builder::new_multi_thread().enable_all().build() + } + + /// Spawns a future onto the runtime + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.handle.spawn(future) + } + + // TODO: See comment in `Handle::spawn_blocking` + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on this runtime. + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + shuttle::future::block_on(future) + } + + /// Returns a handle to the runtime's spawner.
+ pub fn handle(&self) -> &Handle { + &self.handle + } +} + +impl AsyncExecutor for Runtime { + fn block_on<T>(&self, future: impl Future<Output = T>) -> T { + self.block_on(future) + } +} + +impl AsyncExecutor for &Runtime { + fn block_on<T>(&self, future: impl Future<Output = T>) -> T { + (*self).block_on(future) + } +} + +// NOTE: This only exists to get stuff to compile, and currently returns `0`s, `false` and `Duration::new(0, 0)`s +// Protected behind tokio_unstable and feature rt. +#[derive(Clone, Debug)] +pub struct RuntimeMetrics {} + +impl RuntimeMetrics { + pub fn num_workers(&self) -> usize { + 0 + } + + pub fn num_blocking_threads(&self) -> usize { + 0 + } + + pub fn active_tasks_count(&self) -> usize { + 0 + } + + pub fn num_alive_tasks(&self) -> usize { + 0 + } + + pub fn num_idle_blocking_threads(&self) -> usize { + 0 + } + + pub fn remote_schedule_count(&self) -> u64 { + 0 + } + + pub fn budget_forced_yield_count(&self) -> u64 { + 0 + } + + pub fn worker_park_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_noop_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_steal_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_steal_operations(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_poll_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_total_busy_duration(&self, _worker: usize) -> Duration { + Duration::new(0, 0) + } + + pub fn worker_local_schedule_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn worker_overflow_count(&self, _worker: usize) -> u64 { + 0 + } + + pub fn injection_queue_depth(&self) -> usize { + 0 + } + + pub fn worker_local_queue_depth(&self, _worker: usize) -> usize { + 0 + } + + pub fn poll_count_histogram_enabled(&self) -> bool { + false + } + + pub fn poll_count_histogram_num_buckets(&self) -> usize { + 0 + } + + pub fn poll_count_histogram_bucket_range(&self, _bucket: usize) -> Range<Duration> { + Duration::new(0, 0)..Duration::new(0, 0) + } + + pub fn poll_count_histogram_bucket_count(&self,
 _worker: usize, _bucket: usize) -> u64 { + 0 + } + + pub fn worker_mean_poll_time(&self, _worker: usize) -> Duration { + Duration::new(0, 0) + } + + pub fn blocking_queue_depth(&self) -> usize { + 0 + } +} + +// Protected behind cfg! net +impl RuntimeMetrics { + pub fn io_driver_fd_registered_count(&self) -> u64 { + 0 + } + + pub fn io_driver_fd_deregistered_count(&self) -> u64 { + 0 + } + + pub fn io_driver_ready_count(&self) -> u64 { + 0 + } +} + +/// Builds a runtime with custom configuration +pub struct Builder {} + +impl Builder { + /// Returns a new builder with the current thread scheduler + pub fn new_current_thread() -> Self { + Self {} + } + + /// Returns a new builder with the multi thread scheduler + pub fn new_multi_thread() -> Self { + Self::new_current_thread() + } + + /// Enables both I/O and time drivers + pub fn enable_all(&mut self) -> &mut Self { + self + } + + /// Enables the time driver + pub fn enable_time(&mut self) -> &mut Self { + self + } + + /// Start tasks paused + pub fn start_paused(&mut self, _paused: bool) -> &mut Self { + self + } + + /// Sets thread name + pub fn thread_name(&mut self, _name: impl Into<String>) -> &mut Self { + self + } + + /// Sets a function used to generate the name of threads spawned by the `Runtime`'s thread pool. + pub fn thread_name_fn<F>(&mut self, _f: F) -> &mut Self + where + F: Fn() -> String + Send + Sync + 'static, + { + self + } + + /// Sets the number of worker threads the `Runtime` will use.
+ pub fn worker_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Worker threads cannot be set to 0"); + self + } + + /// Creates the configured `Runtime` + pub fn build(&mut self) -> std::io::Result<Runtime> { + Ok(Runtime { handle: Handle {} }) + } +} + +#[derive(Debug, Clone)] +pub struct Handle {} + +#[derive(Debug)] +pub struct TryCurrentError {} + +impl std::fmt::Display for TryCurrentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("TryCurrentError") + } +} + +impl std::error::Error for TryCurrentError {} + +#[derive(Debug, PartialEq, Eq)] +pub enum RuntimeFlavor { + CurrentThread, + MultiThread, +} + +impl Handle { + // TODO?: Make this panic when outside of Shuttle? + /// Returns a `Handle` view over the currently running `Runtime` + pub fn current() -> Self { + Self {} + } + + // TODO: Add a hook to Shuttle to check whether there is currently an ExecutionState, and return an error if there is not. + pub fn try_current() -> Result<Self, TryCurrentError> { + Ok(Self {}) + } + + /// Spawns a future onto the runtime + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + JoinHandle::new(shuttle::future::spawn(future)) + } + + // TODO: + // There is a case to be made for implementing this inside Shuttle, and having it map to `Task::from_closure` + // (over `Task::from_future`, which is "wrong"). + // Deciding not to do that for now, as the API is inherently tied to tokio (as is `spawn` and `shuttle::future::JoinHandle`), + // and I'd rather move those APIs out, than more in. We also generally don't want to encourage the use of `spawn_blocking`, + // since we don't have any notion of an executor dedicated to blocking operations, and the kind of things one would run which + // block the runtime are not the kind of things which should be run under Shuttle. It exists here to enable plug-and-play compilation. + // TODO END.
+ pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + JoinHandle::new(shuttle::future::spawn(async { func() })) + } + + /// Runs a future to completion on this `Handle`'s associated `Runtime`. + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + shuttle::future::block_on(future) + } + + // NOTE: only available on `tokio_unstable` + // NOTE: Shuttle does not have a notion of runtimes so this always returns 1. + pub fn id(&self) -> Id { + Id(NonZeroU64::new(1).unwrap()) + } + + // NOTE: only available on `tokio_unstable` + pub fn metrics(&self) -> RuntimeMetrics { + RuntimeMetrics {} + } + + /// Captures a snapshot of the runtime's state. + pub async fn dump(&self) -> Dump { + unimplemented!(); + } + + /// Returns the flavor of the current `Runtime`. + pub fn runtime_flavor(&self) -> RuntimeFlavor { + unimplemented!() + } +} + +impl AsyncExecutor for Handle { + fn block_on<T>(&self, future: impl Future<Output = T>) -> T { + self.block_on(future) + } +} + +impl AsyncExecutor for &Handle { + fn block_on<T>(&self, future: impl Future<Output = T>) -> T { + (*self).block_on(future) + } +} + +// This whole module is unimplemented. It exists just to get code to compile.
+mod dump { + #[derive(Debug)] + pub struct Dump {} + + #[derive(Debug)] + pub struct Tasks { + tasks: Vec<Task>, + } + + #[derive(Debug)] + pub struct Task {} + + #[derive(Debug)] + pub struct Trace {} + + impl Dump { + pub fn tasks(&self) -> &Tasks { + unimplemented!(); + } + } + + impl Tasks { + pub fn iter(&self) -> impl Iterator<Item = &Task> { + self.tasks.iter() + } + } + + impl Task { + pub fn id(&self) -> crate::task::Id { + unimplemented!() + } + + pub fn trace(&self) -> &Trace { + unimplemented!() + } + } + + impl std::fmt::Display for Trace { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } +} + +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] +pub struct Id(NonZeroU64); + +impl From<NonZeroU64> for Id { + fn from(value: NonZeroU64) -> Self { + Id(value) + } +} + +impl From<NonZeroU32> for Id { + fn from(value: NonZeroU32) -> Self { + Id(value.into()) + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/broadcast.rs b/wrappers/tokio/impls/tokio/inner/src/sync/broadcast.rs new file mode 100644 index 00000000..eee3ba2b --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/broadcast.rs @@ -0,0 +1,71 @@ +//! A multi-producer, multi-consumer broadcast queue. Each sent value is seen by +//! all consumers. +//! +//! Currently a stub implementation, we need this around for cargo check but we +//! don't need this to actually run.
+ +use std::marker::PhantomData; + +pub mod error { + pub use tokio::sync::broadcast::error::*; +} + +pub fn channel<T>(_capacity: usize) -> (Sender<T>, Receiver<T>) { + let tx = Sender(PhantomData); + let rx = Receiver(PhantomData); + (tx, rx) +} + +#[derive(Debug)] +pub struct Sender<T>(PhantomData<T>); + +#[derive(Debug)] +pub struct Receiver<T>(PhantomData<T>); + +impl<T> Receiver<T> { + pub fn len(&self) -> usize { + todo!() + } + + pub fn is_empty(&self) -> bool { + todo!() + } +} + +impl<T: Clone> Receiver<T> { + pub fn resubscribe(&self) -> Self { + todo!() + } + + pub async fn recv(&mut self) -> Result<T, error::RecvError> { + todo!() + } + + pub async fn try_recv(&mut self) -> Result<T, error::TryRecvError> { + todo!() + } +} + +impl<T> Sender<T> { + pub fn len(&self) -> usize { + todo!() + } + + pub fn is_empty(&self) -> bool { + todo!() + } + + pub fn send(&self, _value: T) -> Result<usize, error::SendError<T>> { + todo!() + } + + pub fn subscribe(&self) -> Receiver<T> { + todo!() + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Self(PhantomData) + } +}
in the interaction of Tokio's `select` and Shuttle's `Mutex`, which caused + /// internal consistency violations in Shuttle. The same bug existed for `RwLock` as well. + fn select_mutex_bug() { + use crate::sync::mpsc; + use shuttle::sync::{Arc, Mutex}; + + // async wrapper for `Mutex::lock` + async fn async_lock(m: Arc<Mutex<()>>) { + *m.lock().unwrap(); + } + + shuttle::future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + let mutex = Arc::new(Mutex::new(())); + let mutex2 = mutex.clone(); + + let h1 = shuttle::future::spawn(async move { + tokio::select! { + biased; + _ = rx.recv() => {} + () = async_lock(mutex2) => {} + } + }); + + let h2 = shuttle::future::spawn(async move { + let _m = mutex.lock().unwrap(); + _ = tx.send(()); + }); + + futures::future::join_all([h1, h2]).await; + }); + } + + #[test_log::test] + fn check_select_mutex_bug() { + shuttle::check_dfs(select_mutex_bug, None); + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/mpsc.rs b/wrappers/tokio/impls/tokio/inner/src/sync/mpsc.rs new file mode 100644 index 00000000..dd5fa224 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/mpsc.rs @@ -0,0 +1,750 @@ +//! A multi-producer, single-consumer queue for sending values between +//! asynchronous tasks. + +use shuttle::future::{ + self, + batch_semaphore::{BatchSemaphore, Fairness, TryAcquireError}, +}; +use smallvec::SmallVec; +use std::fmt::{self, Debug}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tracing::trace; + +pub use tokio::sync::mpsc::error; +use tokio::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; + +const MAX_INLINE_MESSAGES: usize = 32; + +// === Base Channel === + +struct Channel<T> { + // If all senders have left and the channel is empty, we want to ensure that the receiver is + // not blocked.
To ensure this, we'll maintain the following invariant + // (state.known_senders == 0 && state.messages.is_empty()) == (recv_semaphore is closed) + bound: Option, // None for an unbounded channel, Some(k) for bounded channel of size k + recv_semaphore: Arc, // semaphore used to signal receivers + send_semaphore: Arc, // semaphore used to block senders. Also tracks whether the channel is closed for sending messages. + state: Arc>>, +} + +impl Debug for Channel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Channel {{ ")?; + write!(f, "recv_semaphore: {:?} ", self.recv_semaphore)?; + write!(f, "send_semaphore: {:?} ", self.send_semaphore)?; + write!(f, "state: {:?} ", self.state)?; + write!(f, "}}") + } +} + +struct ChannelState { + messages: SmallVec<[T; MAX_INLINE_MESSAGES]>, // messages in the channel + known_senders: usize, // number of senders referencing this channel +} + +impl Debug for ChannelState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ChannelState {{ ")?; + write!(f, "num_messages: {} ", self.messages.len())?; + write!(f, "known_senders {}", self.known_senders,)?; + write!(f, "}}") + } +} + +impl Channel { + fn new(bound: Option) -> Self { + let recv_semaphore = Arc::new(BatchSemaphore::new(0, Fairness::StrictlyFair)); + let send_semaphore = Arc::new(BatchSemaphore::new(bound.unwrap_or(usize::MAX), Fairness::StrictlyFair)); + + Self { + bound, + recv_semaphore, + send_semaphore, + state: Arc::new(Mutex::new(ChannelState { + messages: SmallVec::new(), + known_senders: 1, + })), + } + } + + // Send a message on the channel. Note that callers of this method must ensure that + // the channel has enough capacity for the send to be successful. 
+ fn send(&self, message: T) -> Result<(), SendError> { + if self.is_closed() { + return Err(SendError(message)); + } + + let mut state = self.state.try_lock().unwrap(); + + if let Some(bound) = self.bound { + assert!(state.messages.len() < bound); + } + + state.messages.push(message); + trace!( + "sent message on channel {:p} num_messages {}", + self, + state.messages.len() + ); + + Ok(()) + } + + // Receive a message from the channel if one is available + fn recv(&self) -> Option { + let mut state = self.state.try_lock().unwrap(); + trace!( + "receiving message on channel {:p} with {} messages", + self, + state.messages.len() + ); + + // TODO / nit: If we update `is_empty` / `len` / `close` to be `VectorClock`ed functions, then the code below will have wasteful clock work. + if state.messages.is_empty() { + None + } else { + let msg = Some(state.messages.remove(0)); + + if state.messages.is_empty() && state.known_senders == 0 { + trace!( + "closing receiving semaphore {:p} for channel {:p} after having drained the channel post last sender drop", + self.recv_semaphore, + self + ); + + // `close` is a scheduling point, so we need to release the lock on `state` here + drop(state); + + // To ensure the invariant above; when the receiver picks up the last message + // from a channel with no senders, it closes the recv_semaphore + self.recv_semaphore.close(); + } + + msg + } + } + + fn is_closed(&self) -> bool { + self.send_semaphore.is_closed() + } + + fn close(&self) { + trace!( + "closing sending semaphore {:p} for channel {:p}", + self.send_semaphore, + self + ); + self.send_semaphore.close(); + } + + fn drop_receiver(&self) { + trace!("closing channel {:p} on receiver drop", self); + + self.close(); + + // need to drop after releasing lock and closing semaphore to avoid deadlocks + let _unreceived_messages_to_drop = std::mem::take(&mut self.state.try_lock().unwrap().messages); + } + + fn drop_sender(&self) { + // Note that we deliberately limit how long we are 
holding the lock both here and below. + // We have to do this because `BatchSemaphore::close` is a scheduling point. If we were to hold + // the Mutex across a scheduling point, then we run the risk of trying to reacquire the lock, + // deadlocking on ourself. + let known_senders = { + let mut state = self.state.try_lock().unwrap(); + trace!( + "dropping sender for channel {:p} at count {:?}", + self, + state.known_senders + ); + + assert!(state.known_senders > 0); + state.known_senders -= 1; + state.known_senders + }; + + if known_senders == 0 { + self.close(); + + let no_messages_in_channel = { + let state = self.state.try_lock().unwrap(); + state.messages.is_empty() + }; + + // If there are messages, then the `recv_semaphore` will remain open until the last message is `recv`d. + if no_messages_in_channel { + trace!("closing semaphore {:p} on last sender drop", self.recv_semaphore); + // See invariant above; when the last sender leaves an empty channel, it + // closes the recv_semaphore + self.recv_semaphore.close_no_scheduling_point(); + } + } + } + + // TODO: This must be VectorClocked right? If not then we can use this as an AtomicBool/AtomicUsize without any clocking. + /// Returns the number of messages in the channel. + fn len(&self) -> usize { + self.state.try_lock().unwrap().messages.len() + } + + /// Checks if the channel is empty. + /// + /// This method returns `true` if the channel has no messages. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn is_bounded(&self) -> bool { + self.bound.is_some() + } +} + +/// Common building block to build [`Receiver`]/[`UnboundedReceiver`] atop. +struct ReceiverInternal { + chan: Arc>, +} + +impl fmt::Debug for ReceiverInternal { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.chan) + } +} + +impl ReceiverInternal { + pub fn new(chan: Arc>) -> Self { + Self { chan } + } + + /// Receives the next value for this receiver. 
+ pub async fn recv(&mut self) -> Option { + if self.is_closed() && self.is_empty() { + return None; + } + + self.chan.recv_semaphore.acquire(1).await.ok()?; + let message = self.chan.recv()?; + + if self.chan.is_bounded() { + self.chan.send_semaphore.release(1); + } + + Some(message) + } + + /// Tries to receive the next value for this receiver. + pub fn try_recv(&mut self) -> Result { + match self.chan.recv_semaphore.try_acquire(1) { + Err(TryAcquireError::Closed) => Err(TryRecvError::Disconnected), + Err(TryAcquireError::NoPermits) => Err(TryRecvError::Empty), + Ok(()) => { + let message = self.chan.recv().expect( + "Internal Shuttle error. We acquired a permit for an empty channel. This should never happen.", + ); + if self.chan.is_bounded() { + self.chan.send_semaphore.release(1); + } + Ok(message) + } + } + } + + /// Blocking receive to call outside of asynchronous contexts. + pub fn blocking_recv(&mut self) -> Option { + if self.is_closed() && self.is_empty() { + return None; + } + + self.chan.recv_semaphore.acquire_blocking(1).ok()?; + self.chan.recv() + } + + /// Closes the receiving half of a channel, without dropping it. + pub fn close(&mut self) { + self.chan.close(); + } + + /// Checks if a channel is closed. + /// + /// This method returns `true` if the channel has been closed. The channel is closed + /// when all [`UnboundedSender`] have been dropped, or when [`UnboundedReceiver::close`] is called. + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. 
+ pub fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll> { + unimplemented!() + } + + /// Checks if a channel is empty. + /// + /// This method returns `true` if the channel has no messages. + pub fn is_empty(&self) -> bool { + self.chan.is_empty() + } + + /// Returns the number of messages in the channel. + pub fn len(&self) -> usize { + self.chan.len() + } +} + +impl Drop for ReceiverInternal { + fn drop(&mut self) { + self.chan.drop_receiver(); + } +} + +/// Common building block to build [`Sender`]/[`UnboundedSender`] atop. +struct SenderInternal { + chan: Arc>, +} + +impl fmt::Debug for SenderInternal { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.chan) + } +} + +impl SenderInternal { + fn new(chan: Arc>) -> Self { + Self { chan } + } + + /// Sends a value, waiting until there is capacity. + pub async fn send(&self, message: T) -> Result<(), SendError> { + if self.chan.is_bounded() { + match self.chan.send_semaphore.acquire(1).await { + Ok(()) => {} + Err(_) => return Err(SendError(message)), + } + } + + self.chan.send(message)?; + self.chan.recv_semaphore.release(1); + + Ok(()) + } + + /// Completes when the receiver has dropped. + pub async fn closed(&self) { + unimplemented!() + } + + /// Attempts to immediately send a message on this `Sender` + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + match self.chan.send_semaphore.try_acquire(1) { + Err(TryAcquireError::Closed) => Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => Err(TrySendError::Full(message)), + Ok(()) => { + self.chan.send(message)?; + self.chan.recv_semaphore.release(1); + Ok(()) + } + } + } + + /// Blocking send to call outside of asynchronous contexts. + pub fn blocking_send(&self, message: T) -> Result<(), SendError> { + future::block_on(self.send(message)) + } + + /// Checks if the channel has been closed. 
This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Waits for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + pub async fn reserve(&self) -> Result, SendError<()>> { + unimplemented!() + } + + /// Waits for channel capacity, moving the `Sender` and returning an owned + /// permit. Once capacity to send one message is available, it is reserved + /// for the caller. + pub async fn reserve_owned(self) -> Result, SendError<()>> { + unimplemented!() + } + + /// Returns `true` if senders belong to the same channel. + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.chan, &other.chan) + } + + /// Returns the current capacity of the channel. + pub fn capacity(&self) -> usize { + self.chan.send_semaphore.available_permits() + } + + /// Returns the maximum buffer capacity of the channel. + pub fn max_capacity(&self) -> usize { + match self.chan.bound { + None => usize::MAX, + Some(k) => k, + } + } +} + +impl Clone for SenderInternal { + fn clone(&self) -> Self { + { + let mut state = self.chan.state.try_lock().unwrap(); + state.known_senders += 1; + } + + SenderInternal { + chan: self.chan.clone(), + } + } +} + +impl Drop for SenderInternal { + fn drop(&mut self) { + self.chan.drop_sender(); + } +} + +// === Unbounded Channel === + +/// Receive values from the associated `UnboundedSender`. +pub struct UnboundedReceiver { + inner: ReceiverInternal, +} + +impl fmt::Debug for UnboundedReceiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedReceiver") + .field("chan", &self.inner) + .finish() + } +} + +/// Creates an unbounded mpsc channel for communicating between asynchronous +/// tasks without backpressure. 
+pub fn unbounded_channel() -> (UnboundedSender, UnboundedReceiver) { + let chan = Arc::new(Channel::new(None)); + let sender = UnboundedSender { + inner: SenderInternal::new(chan.clone()), + }; + let receiver = UnboundedReceiver { + inner: ReceiverInternal::new(chan), + }; + (sender, receiver) +} + +impl UnboundedReceiver { + /// Receives the next value for this receiver. + pub async fn recv(&mut self) -> Option { + self.inner.recv().await + } + + /// Tries to receive the next value for this receiver. + pub fn try_recv(&mut self) -> Result { + self.inner.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + pub fn blocking_recv(&mut self) -> Option { + self.inner.blocking_recv() + } + + /// Closes the receiving half of a channel, without dropping it. + pub fn close(&mut self) { + self.inner.close(); + } + + /// Checks if a channel is closed. + /// + /// This method returns `true` if the channel has been closed. The channel is closed + /// when all [`UnboundedSender`] have been dropped, or when [`UnboundedReceiver::close`] is called. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } + + /// Checks if a channel is empty. + /// + /// This method returns `true` if the channel has no messages. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns the number of messages in the channel. 
+ pub fn len(&self) -> usize { + self.inner.len() + } +} + +// == UnboundedSender == + +/// Send values to the associated `UnboundedReceiver`. +pub struct UnboundedSender { + inner: SenderInternal, +} + +// Note that this cannot be derived, as then we get a `T: Clone` bound, but `UnboundedSender` should +// be `Clone` even if `T` is not +impl Clone for UnboundedSender { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl fmt::Debug for UnboundedSender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedSender").field("chan", &self.inner).finish() + } +} + +impl UnboundedSender { + /// Attempts to send a message on this `UnboundedSender` without blocking. + pub fn send(&self, message: T) -> Result<(), SendError> { + future::block_on(self.inner.send(message)) + } + + /// Completes when the receiver has dropped. + pub async fn closed(&self) { + self.inner.closed().await; + } + + /// Checks if the channel has been closed. This happens when the + /// [`UnboundedReceiver`] is dropped, or when the + /// [`UnboundedReceiver::close`] method is called. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Returns `true` if senders belong to the same channel. + pub fn same_channel(&self, other: &Self) -> bool { + self.inner.same_channel(&other.inner) + } +} + +// ==== BOUNDED CHANNEL + +/// Receives values from the associated `Sender`. +pub struct Receiver { + inner: ReceiverInternal, +} + +impl fmt::Debug for Receiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver").field("chan", &self.inner).finish() + } +} + +/// Creates a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to send new messages will wait until a message is +/// received from the channel. 
The provided buffer capacity must be at least 1. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is supported. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return `None`. +pub fn channel(bound: usize) -> (Sender, Receiver) { + let chan = Arc::new(Channel::new(Some(bound))); + let sender = Sender { + inner: SenderInternal::new(chan.clone()), + }; + let receiver = Receiver { + inner: ReceiverInternal { chan }, + }; + (sender, receiver) +} + +impl Receiver { + /// Receives the next value for this receiver. + pub async fn recv(&mut self) -> Option { + self.inner.recv().await + } + + /// Tries to receive the next value for this receiver. + pub fn try_recv(&mut self) -> Result { + self.inner.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + pub fn blocking_recv(&mut self) -> Option { + self.inner.blocking_recv() + } + + /// Closes the receiving half of a channel, without dropping it. + pub fn close(&mut self) { + self.inner.close(); + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } + + /// Returns the number of messages in the channel. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Checks if a channel is empty. 
+ /// + /// This method returns `true` if the channel has no messages. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Checks if a channel is closed. + /// + /// This method returns `true` if the channel has been closed. The channel is closed + /// when all [`Sender`] have been dropped, or when [`Receiver::close`] is called. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } +} + +impl Unpin for Receiver {} + +// === BOUNDED SENDER === + +/// Sends values to the associated `Receiver`. +pub struct Sender { + inner: SenderInternal, +} + +// Note that this cannot be derived, as then we get a `T: Clone` bound, but `Sender` should +// be `Clone` even if `T` is not +impl Clone for Sender { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +/// Permits to send one value into the channel. +pub struct Permit { + chan: Arc>, +} + +impl fmt::Debug for Permit { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit").field("chan", &self.chan).finish() + } +} + +/// Owned permit to send one value into the channel. +/// +/// This is identical to the [`Permit`] type, except that it moves the sender +/// rather than borrowing it. +pub struct OwnedPermit { + chan: Option>>, +} + +impl fmt::Debug for OwnedPermit { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OwnedPermit").field("chan", &self.chan).finish() + } +} + +impl Sender { + /// Sends a value, waiting until there is capacity. + pub async fn send(&self, message: T) -> Result<(), SendError> { + self.inner.send(message).await + } + + /// Completes when the receiver has dropped. + pub async fn closed(&self) { + self.inner.closed().await; + } + + /// Attempts to immediately send a message on this `Sender` + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + self.inner.try_send(message) + } + + /// Blocking send to call outside of asynchronous contexts. 
+ pub fn blocking_send(&self, message: T) -> Result<(), SendError> { + self.inner.blocking_send(message) + } + + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Waits for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + pub async fn reserve(&self) -> Result, SendError<()>> { + self.inner.reserve().await + } + + /// Waits for channel capacity, moving the `Sender` and returning an owned + /// permit. Once capacity to send one message is available, it is reserved + /// for the caller. + pub async fn reserve_owned(self) -> Result, SendError<()>> { + self.inner.reserve_owned().await + } + + /// Returns `true` if senders belong to the same channel. + pub fn same_channel(&self, other: &Self) -> bool { + self.inner.same_channel(&other.inner) + } + + /// Returns the current capacity of the channel. + pub fn capacity(&self) -> usize { + self.inner.capacity() + } + + /// Returns the maximum buffer capacity of the channel. + pub fn max_capacity(&self) -> usize { + self.inner.max_capacity() + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender").field("chan", &self.inner).finish() + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/mutex.rs b/wrappers/tokio/impls/tokio/inner/src/sync/mutex.rs new file mode 100644 index 00000000..78762cf6 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/mutex.rs @@ -0,0 +1,212 @@ +//! An asynchronous `Mutex`-like type. 
+ +use shuttle::future::batch_semaphore::{BatchSemaphore, Fairness, TryAcquireError}; +use std::cell::UnsafeCell; +use std::error::Error; +use std::fmt::{self, Debug, Display}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::thread; +use tracing::trace; + +/// An asynchronous semaphore +pub struct Mutex { + semaphore: BatchSemaphore, + inner: UnsafeCell, +} + +/// A handle to a held `Mutex`. The guard can be held across any `.await` point +/// as it is [`Send`]. +pub struct MutexGuard<'a, T: ?Sized> { + mutex: &'a Mutex, +} + +impl Display for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&**self, f) + } +} + +/// An owned handle to a held `Mutex`. +pub struct OwnedMutexGuard { + mutex: Arc>, +} + +/// Error returned from the [`Mutex::try_lock`], `RwLock::try_read` and +/// `RwLock::try_write` functions. +#[derive(Debug)] +pub struct TryLockError(pub(super) ()); + +impl Display for TryLockError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "operation would block") + } +} + +impl Error for TryLockError {} + +// As long as T: Send, it's fine to send and share Mutex between threads. +// If T was not Send, sending and sharing a Mutex would be bad, since you can +// access T through Mutex. +unsafe impl Send for Mutex where T: ?Sized + Send {} +unsafe impl Sync for Mutex where T: ?Sized + Send {} +unsafe impl Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for OwnedMutexGuard where T: ?Sized + Send + Sync {} + +impl Mutex { + /// Creates a new lock in an unlocked state ready for use. + pub fn new(t: T) -> Self + where + T: Sized, + { + Self { + semaphore: BatchSemaphore::new(1, Fairness::StrictlyFair), + inner: UnsafeCell::new(t), + } + } + + async fn acquire(&self) { + trace!("acquiring lock {:p}", self); + self.semaphore.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. 
but, we never explicitly close it, and + // we own it exclusively, which means that this can never happen. + if !thread::panicking() { + unreachable!() + } + }); + trace!("acquired lock {:p}", self); + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + pub async fn lock(&self) -> MutexGuard<'_, T> { + self.acquire().await; + + MutexGuard { mutex: self } + } + + /// Blockingly locks this `Mutex`. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + shuttle::future::block_on(self.lock()) + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, this returns an + /// [`OwnedMutexGuard`]. + pub async fn lock_owned(self: Arc) -> OwnedMutexGuard { + self.acquire().await; + + OwnedMutexGuard { mutex: self } + } + + fn try_acquire(&self) -> Result<(), TryAcquireError> { + self.semaphore.try_acquire(1) + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the + /// lock is currently held somewhere else. + pub fn try_lock(&self) -> Result, TryLockError> { + match self.try_acquire() { + Ok(()) => Ok(MutexGuard { mutex: self }), + Err(_) => Err(TryLockError(())), + } + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock + /// is currently held somewhere else. + pub fn try_lock_owned(self: Arc) -> Result, TryLockError> { + match self.try_acquire() { + Ok(()) => Ok(OwnedMutexGuard { mutex: self }), + Err(_) => Err(TryLockError(())), + } + } + + /// Consumes the mutex, returning the underlying data. 
+ pub fn into_inner(self) -> T + where + T: Sized, + { + self.inner.into_inner() + } +} + +impl Debug for Mutex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // SAFETY: Shuttle is running single-threaded, only we are able to access `inner` at the time of this call. + Debug::fmt(&unsafe { &*self.inner.get() }, f) + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.mutex.inner.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.mutex.inner.get() } + } +} + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + trace!("releasing lock {:p}", self); + self.mutex.semaphore.release(1); + } +} + +impl Debug for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.mutex, f) + } +} + +impl Drop for OwnedMutexGuard { + fn drop(&mut self) { + trace!("releasing owned lock {:p}", self); + self.mutex.semaphore.release(1); + } +} + +impl Deref for OwnedMutexGuard { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.mutex.inner.get() } + } +} + +impl DerefMut for OwnedMutexGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.mutex.inner.get() } + } +} + +impl Debug for OwnedMutexGuard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.mutex, f) + } +} + +impl From for Mutex { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl Default for Mutex +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/notify.rs b/wrappers/tokio/impls/tokio/inner/src/sync/notify.rs new file mode 100644 index 00000000..def3aedc --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/notify.rs @@ -0,0 +1,285 @@ +/// Notifies a single task to wake up. 
+/// +/// `Notify` provides a basic mechanism to notify a single task of an event. +/// `Notify` itself does not carry any data. Instead, it is to be used to signal +/// another task to perform an operation. +use crate::sync::oneshot; +use pin_project::{pin_project, pinned_drop}; +use shuttle::rand::{rngs::ThreadRng, thread_rng, Rng}; +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tracing::trace; + +#[derive(Debug)] +pub struct Notify { + state: Mutex, +} + +#[derive(Debug)] +struct NotifyState { + // For assigning unique ids to waiters + next_id: usize, + + // Seed for selecting waiters to notify at random + rng: ThreadRng, + + // Whether a notify is pending + pending: bool, + + // List of waiters + waiters: VecDeque, +} + +impl NotifyState { + fn new() -> Self { + Self { + next_id: 0, + rng: thread_rng(), + pending: false, + waiters: VecDeque::new(), + } + } + + fn remove_waiter(&mut self, id: usize) -> Waiter { + for i in 0..self.waiters.len() { + if self.waiters[i].id == id { + return self.waiters.remove(i).unwrap(); + } + } + panic!("could not find waiter with id {:?} among {:?}", id, self.waiters); + } +} + +// Encode worker state in usize +const INIT: usize = 0; // not yet polled or enabled +const ENABLED: usize = 1; // enabled but not notified +const NOTIFIED: usize = 2; // notified + +#[derive(Debug)] +struct Waiter { + flag: Arc, + id: usize, + tx: oneshot::Sender<()>, +} + +/// Future returned from [`Notify::notified()`]. +/// +/// This future is fused, so once it has completed, any future calls to poll +/// will immediately return `Poll::Ready`. +#[derive(Debug)] +#[pin_project(PinnedDrop)] +pub struct Notified<'a> { + id: usize, // unique id for this waiter + flag: Arc, + /// The `Notify` being received on. 
+ notify: &'a Notify, + // oneshot being awaited on + #[pin] + rx: oneshot::Receiver<()>, +} + +unsafe impl Send for Notified<'_> {} +unsafe impl Sync for Notified<'_> {} + +impl Notify { + /// Create a new `Notify`, initialized without a permit. + pub fn new() -> Notify { + Notify { + state: Mutex::new(NotifyState::new()), + } + } + + /// Wait for a notification. + pub fn notified(&self) -> Notified<'_> { + let (tx, rx) = oneshot::channel(); + let mut state = self.state.lock().unwrap(); + let id = state.next_id; + state.next_id += 1; + let flag = Arc::new(AtomicUsize::new(INIT)); + let waiter = Waiter { + flag: flag.clone(), + id, + tx, + }; + state.waiters.push_back(waiter); + trace!( + "notified {:?} adding waiter {:?} to waiters {:?}", + self, + id, + state.waiters, + ); + drop(state); + Notified { + id, + flag, + notify: self, + rx, + } + } + + /// Notifies a waiting task. + /// + /// If a task is currently waiting, that task is notified. Otherwise, a + /// permit is stored in this `Notify` value and the **next** call to + /// [`notified().await`] will complete immediately consuming the permit made + /// available by this call to `notify_one()`. 
+ pub fn notify_one(&self) { + let mut state = self.state.lock().unwrap(); + // Need to choose a waiter that is Pending + let mut pending = Vec::with_capacity(state.waiters.len()); + for w in &state.waiters { + let flag = w.flag.load(Ordering::SeqCst); + assert!(flag == INIT || flag == ENABLED); + if flag == ENABLED { + pending.push(w.id); + } + } + trace!("notify_one for {:p} notifying waiters {:?}", self, pending); + if pending.is_empty() { + // No pending waiters, so just record the fact that a notify is pending + state.pending = true; + } else { + // Choose a pending waiter at random + let index = state.rng.gen_range(0..pending.len()); + let id = pending[index]; + // Remove waiter and mark it notified + let waiter = state.remove_waiter(id); + state.pending = false; + drop(state); + trace!("notify_one for {:?} waking waiter {:?}", self, waiter.id); + // Must set flag before notifying waiter + waiter.flag.store(NOTIFIED, Ordering::SeqCst); + waiter.tx.send(()).unwrap(); + } + } + + /// Notifies all waiting tasks. + /// + /// If a task is currently waiting, that task is notified. Unlike with + /// `notify_one()`, no permit is stored to be used by the next call to + /// `notified().await`. The purpose of this method is to notify all + /// already registered waiters. Registering for notification is done by + /// acquiring an instance of the `Notified` future via calling `notified()`. + pub fn notify_waiters(&self) { + let mut state = self.state.lock().unwrap(); + // Notify all waiters, including those not yet enabled + let waiters = std::mem::take(&mut state.waiters); + trace!("notify_waiters for {:p} notifying waiters {:?}", self, waiters); + state.pending = false; + drop(state); + // Since we have removed all the waiters, we need to clear all the + // flags first, before waking any of them. 
This is because sending + // a oneshot to wake a waiter will admit context switches, and we + // could invoke the drop handler for a waiter whose flag has not been + // cleared. + for w in &waiters { + // Must set flag before notifying waiter + let flag = w.flag.swap(NOTIFIED, Ordering::SeqCst); + assert!(flag == INIT || flag == ENABLED); + } + for w in waiters { + let _ = w.tx.send(()); // Note this may fail if the waiter has dropped + } + } +} + +impl Default for Notify { + fn default() -> Notify { + Notify::new() + } +} + +impl Notified<'_> { + /// Adds this future to the list of futures that are ready to receive + /// wakeups from calls to [`Notify::notify_one`]. + /// + /// Polling the future also adds it to the list, so this method should only + /// be used if you want to add the future to the list before the first call + /// to `poll`. (In fact, this method is equivalent to calling `poll` except + /// that no `Waker` is registered.) + /// + /// This has no effect on notifications sent using [`Notify::notify_waiters`], which + /// are received as long as they happen after the creation of the `Notified` + /// regardless of whether `enable` or `poll` has been called. + /// + /// This method returns true if the `Notified` is ready. This happens in the + /// following situations: + /// + /// 1. The `notify_waiters` method was called between the creation of the + /// `Notified` and the call to this method. + /// 2. This is the first call to `enable` or `poll` on this future, and the + /// `Notify` was holding a permit from a previous call to `notify_one`. + /// The call consumes the permit in that case. + /// 3. The future has previously been enabled or polled, and it has since + /// then been marked ready by either consuming a permit from the + /// `Notify`, or by a call to `notify_one` or `notify_waiters` that + /// removed it from the list of futures ready to receive wakeups. 
+ /// + /// If this method returns true, any future calls to poll on the same future + /// will immediately return `Poll::Ready`. + pub fn enable(self: Pin<&mut Self>) -> bool { + self.poll_inner() + } + + fn poll_inner(&self) -> bool { + let flag = self.flag.load(Ordering::SeqCst); + if flag == NOTIFIED { + return true; + } + if flag == INIT { + // Not yet polled or enabled, so mark it enabled + self.flag.store(ENABLED, Ordering::SeqCst); + } + let mut state = self.notify.state.lock().unwrap(); + if std::mem::replace(&mut state.pending, false) { + trace!("waiter {} in state {:?} consuming permit", self.id, flag); + // We just consumed a permit, so mark this future ready and remove + // the waiter + let waiter = state.remove_waiter(self.id); + drop(state); + waiter.flag.store(NOTIFIED, Ordering::SeqCst); + waiter.tx.send(()).unwrap(); + true + } else { + false + } + } +} + +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let enabled = self.poll_inner(); + if enabled { + Poll::Ready(()) + } else { + let mut this = self.project(); + match this.rx.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(()), + } + } + } +} + +#[pinned_drop] +impl PinnedDrop for Notified<'_> { + fn drop(self: Pin<&mut Self>) { + trace!( + "dropping waiter {:?} with flag {:?}", + self.id, + self.flag.load(Ordering::SeqCst) + ); + // We're using std::sync::Atomics here, so no context switching will happen here + if self.flag.load(Ordering::SeqCst) != NOTIFIED { + // If the waiter hasn't been notified, remove it from the waiter queue + let mut state = self.notify.state.lock().unwrap(); + let _ = state.remove_waiter(self.id); + } + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/once_cell.rs b/wrappers/tokio/impls/tokio/inner/src/sync/once_cell.rs new file mode 100644 index 00000000..e836948b --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/once_cell.rs @@ -0,0 
+1,418 @@ +// This file is 99% copy-paste from tokio and Loom. +// This means that it should be a true model, but it uses `shuttle_tokio::sync::Semaphore`s instead of `tokio::sync::Semaphore`s. +use super::{Semaphore, SemaphorePermit, TryAcquireError}; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::ops::Drop; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; + +// Loom's `UnsafeCell` verbatim. +// Not `pub` in Tokio, so reimplemented here. +#[derive(Debug)] +pub(crate) struct UnsafeCell(std::cell::UnsafeCell); + +impl UnsafeCell { + pub(crate) const fn new(data: T) -> UnsafeCell { + UnsafeCell(std::cell::UnsafeCell::new(data)) + } + + #[inline(always)] + pub(crate) fn with(&self, f: impl FnOnce(*const T) -> R) -> R { + f(self.0.get()) + } + + #[inline(always)] + pub(crate) fn with_mut(&self, f: impl FnOnce(*mut T) -> R) -> R { + f(self.0.get()) + } +} + +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. 
+pub struct OnceCell { + value_set: AtomicBool, + value: UnsafeCell>, + semaphore: Semaphore, +} + +impl Default for OnceCell { + fn default() -> OnceCell { + OnceCell::new() + } +} + +impl fmt::Debug for OnceCell { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OnceCell").field("value", &self.get()).finish() + } +} + +impl Clone for OnceCell { + fn clone(&self) -> OnceCell { + OnceCell::new_with(self.get().cloned()) + } +} + +impl PartialEq for OnceCell { + fn eq(&self, other: &OnceCell) -> bool { + self.get() == other.get() + } +} + +impl Eq for OnceCell {} + +impl Drop for OnceCell { + fn drop(&mut self) { + if self.initialized_mut() { + unsafe { + self.value.with_mut(|ptr| ptr::drop_in_place((*ptr).as_mut_ptr())); + }; + } + } +} + +impl From for OnceCell { + fn from(value: T) -> Self { + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore: Semaphore::new(0), // `new_closed` in tokio + } + } +} + +impl OnceCell { + /// Creates a new empty `OnceCell` instance. + pub fn new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::new(1), + } + } + + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. + /// + /// [`OnceCell::new`]: crate::sync::OnceCell::new + // Once https://github.com/rust-lang/rust/issues/73255 lands + // and tokio MSRV is bumped to the rustc version with it stablised, + // we can made this function available in const context, + // by creating `Semaphore::const_new_closed`. 
+ pub fn new_with(value: Option) -> Self { + if let Some(v) = value { + OnceCell::from(v) + } else { + OnceCell::new() + } + } + + pub const fn const_new_with(value: T) -> Self { + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore: Semaphore::const_new(0), // `const_new_closed` in tokio + } + } + + pub const fn const_new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::const_new(1), + } + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. + self.value_set.load(Ordering::Acquire) + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked(&self) -> &T { + &*self.value.with(|ptr| (*ptr).as_ptr()) + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked_mut(&mut self) -> &mut T { + &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) + } + + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. + self.value_set.store(true, Ordering::Release); + self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } + } + + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. 
+ pub fn get(&self) -> Option<&T> { + if self.initialized() { + Some(unsafe { self.get_unchecked() }) + } else { + None + } + } + + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. + /// + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. + pub fn get_mut(&mut self) -> Option<&mut T> { + if self.initialized_mut() { + Some(unsafe { self.get_unchecked_mut() }) + } else { + None + } + } + + /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. + /// + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. + /// + /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError + /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError + pub fn set(&self, value: T) -> Result<(), SetError> { + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); + } + + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. 
+ Err(SetError::AlreadyInitializedError(value)) + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> Fut, + Fut: Future, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { self.get_unchecked() } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. + let value = f().await; + + self.set_value(value, permit) + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } + } + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. 
+ /// + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_try_init(&self, f: F) -> Result<&T, E> + where + F: FnOnce() -> Fut, + Fut: Future>, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { Ok(self.get_unchecked()) } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. + let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), + } + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } + } + } + } + } + + /// Takes the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. + pub fn into_inner(mut self) -> Option { + if self.initialized_mut() { + // Set to uninitialized for the destructor of `OnceCell` to work properly + *self.value_set.get_mut() = false; + Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) + } else { + None + } + } + + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. 
+ pub fn take(&mut self) -> Option { + std::mem::take(self).into_inner() + } +} + +// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell` to send values +// (of type T) across threads. +unsafe impl Sync for OnceCell {} + +// Access to OnceCell's value is guarded by the semaphore permit +// and atomic operations on `value_set`, so as long as T itself is Send +// it's safe to send it to another thread +unsafe impl Send for OnceCell {} + +/// Errors that can be returned from [`OnceCell::set`]. +/// +/// [`OnceCell::set`]: crate::sync::OnceCell::set +#[derive(Debug, PartialEq, Eq)] +pub enum SetError { + /// The cell was already initialized when [`OnceCell::set`] was called. + /// + /// [`OnceCell::set`]: crate::sync::OnceCell::set + AlreadyInitializedError(T), + + /// The cell is currently being initialized. + InitializingError(T), +} + +impl fmt::Display for SetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SetError::AlreadyInitializedError(_) => write!(f, "AlreadyInitializedError"), + SetError::InitializingError(_) => write!(f, "InitializingError"), + } + } +} + +impl Error for SetError {} + +impl SetError { + /// Whether `SetError` is `SetError::AlreadyInitializedError`. 
+ pub fn is_already_init_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => true, + SetError::InitializingError(_) => false, + } + } + + /// Whether `SetError` is `SetError::InitializingError` + pub fn is_initializing_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => false, + SetError::InitializingError(_) => true, + } + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/oneshot.rs b/wrappers/tokio/impls/tokio/inner/src/sync/oneshot.rs new file mode 100644 index 00000000..7eddf298 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/oneshot.rs @@ -0,0 +1,145 @@ +//! A one-shot channel is used for sending a single message between +//! asynchronous tasks. The [`channel`] function is used to create a +//! [`Sender`] and [`Receiver`] handle pair that form the channel. + +use futures::channel::oneshot; +use shuttle::future; +use std::future::Future; +use std::pin::{pin, Pin}; +use std::task::{Context, Poll}; +use tracing::trace; + +/// Sends a value to the associated [`Receiver`]. +#[derive(Debug)] +pub struct Sender(oneshot::Sender); + +#[derive(Debug)] +pub struct Receiver(oneshot::Receiver); + +pub mod error { + //! Oneshot error types. + use std::fmt; + + /// Error returned by the `Future` implementation for `Receiver`. + #[derive(Debug, Eq, PartialEq, Clone)] + pub struct RecvError(pub(super) ()); + + /// Error returned by the `try_recv` function on `Receiver`. + #[derive(Debug, Eq, PartialEq, Clone)] + pub enum TryRecvError { + /// The send half of the channel has not yet sent a value. + Empty, + + /// The send half of the channel was dropped without sending a value. 
+ Closed, + } + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} + + // ===== impl TryRecvError ===== + + impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(fmt, "channel empty"), + TryRecvError::Closed => write!(fmt, "channel closed"), + } + } + } + + impl std::error::Error for TryRecvError {} +} +use self::error::{RecvError, TryRecvError}; + +/// Creates a new one-shot channel for sending single values across asynchronous +/// tasks. +pub fn channel() -> (Sender, Receiver) { + let (tx, rx) = oneshot::channel(); + (Sender(tx), Receiver(rx)) +} + +impl Sender { + /// Attempts to send a value on this channel, returning it back if it could + /// not be sent. + pub fn send(self, t: T) -> Result<(), T> { + trace!("Sending message on oneshot {:p}", &self.0); + let send_result = self.0.send(t); + shuttle::thread::yield_now(); + send_result + } + + /// Waits for the associated [`Receiver`] handle to close. + pub async fn closed(&mut self) { + trace!("sender closing oneshot {:p}", &self.0); + self.0.cancellation().await; + shuttle::future::yield_now().await; + } + + /// Returns `true` if the associated [`Receiver`] handle has been dropped. + pub fn is_closed(&self) -> bool { + self.0.is_canceled() + } + + /// Checks whether the oneshot channel has been closed, and if not, schedules the + /// `Waker` in the provided `Context` to receive a notification when the channel is + /// closed. + /// + /// Note that on multiple calls to poll, only the `Waker` from the `Context` passed + /// to the most recent call will be scheduled to receive a wakeup. 
+ pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.0.poll_canceled(cx) + } +} + +impl Receiver { + /// Prevents the associated [`Sender`] handle from sending a value. + pub fn close(&mut self) { + self.0.close(); + shuttle::thread::yield_now(); + } + + /// Attempts to receive a value. + pub fn try_recv(&mut self) -> Result { + let out = match self.0.try_recv() { + Ok(Some(v)) => Ok(v), + Ok(None) => Err(TryRecvError::Empty), + Err(_) => Err(TryRecvError::Closed), + }; + shuttle::thread::yield_now(); + out + } + + /// Blocking receive to call outside of asynchronous contexts. + pub fn blocking_recv(self) -> Result { + future::block_on(self) + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + tracing::trace!("dropping oneshot receiver {:p}", self); + } +} + +impl Future for Receiver { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let receiver = pin!(&mut self.0); + trace!("polling oneshot receiver {:p}", receiver); + let poll_result = receiver.poll(cx).map_err(|_| RecvError(())); + if poll_result.is_ready() { + // Force a yield + shuttle::thread::yield_now(); + } + poll_result + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/rwlock.rs b/wrappers/tokio/impls/tokio/inner/src/sync/rwlock.rs new file mode 100644 index 00000000..1ae441ae --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/rwlock.rs @@ -0,0 +1,580 @@ +//! An asynchronous reader-writer lock. + +use crate::sync::TryLockError; +use shuttle::future::batch_semaphore::{BatchSemaphore, Fairness, TryAcquireError}; +use std::cell::UnsafeCell; +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::{fmt, ops}; +use tracing::trace; + +const MAX_READERS: usize = usize::MAX >> 3; + +/// An asynchronous reader-writer lock. +/// +/// This type of lock allows a number of readers or at most one writer at any +/// point in time. 
The write portion of this lock typically allows modification +/// of the underlying data (exclusive access) and the read portion of this lock +/// typically allows for read-only access (shared access). +/// +/// In comparison, a `Mutex` does not distinguish between readers or writers +/// that acquire the lock, therefore causing any tasks waiting for the lock to +/// become available to yield. An `RwLock` will allow any number of readers to +/// acquire the lock as long as a writer is not holding the lock. +#[derive(Debug)] +pub struct RwLock { + // maximum number of concurrent readers + max_readers: usize, + + //semaphore to coordinate read and write access to T + sem: BatchSemaphore, + + //inner data T + inner: UnsafeCell, +} + +impl RwLock { + /// Creates a new instance of an `RwLock` which is unlocked. + pub fn new(value: T) -> Self + where + T: Sized, + { + Self::with_max_readers(value, MAX_READERS) + } + + /// Creates a new instance of an `RwLock` which is unlocked + /// and allows a maximum of `max_readers` concurrent readers. + pub fn with_max_readers(value: T, max_readers: usize) -> Self + where + T: Sized, + { + assert!( + max_readers <= MAX_READERS, + "a RwLock may not be created with more than {MAX_READERS} readers" + ); + let sem = BatchSemaphore::new(max_readers, Fairness::StrictlyFair); + let rwlock = RwLock { + max_readers, + sem, + inner: UnsafeCell::new(value), + }; + trace!("initialized RwLock {:p} with {} permits", &rwlock, max_readers); + rwlock + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. 
+ pub async fn read(&self) -> RwLockReadGuard<'_, T> { + let inner = self.sem.acquire(1); + inner.await.unwrap_or_else(|_| { + // The semaphore was closed, but we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + if !std::thread::panicking() { + unreachable!() + } + }); + + trace!("rwlock {:p} acquired ReadGuard", self); + RwLockReadGuard { + sem: &self.sem, + data: self.inner.get(), + _p: PhantomData, + } + } + + /// Blockingly locks this `RwLock` with shared read access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + pub fn blocking_read(&self) -> RwLockReadGuard<'_, T> { + shuttle::future::block_on(self.read()) + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// This method is identical to [`RwLock::read`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + pub async fn read_owned(self: Arc) -> OwnedRwLockReadGuard { + let inner = self.sem.acquire(1); + inner.await.unwrap_or_else(|_| { + // The semaphore was closed, but we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen.
+ if !std::thread::panicking() { + unreachable!() + } + }); + + trace!("rwlock {:p} acquired OwnedReadGuard", self); + OwnedRwLockReadGuard { + data: self.inner.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + } + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + pub fn try_read(&self) -> Result, TryLockError> { + match self.sem.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => { + if !std::thread::panicking() { + unreachable!() + } + } + } + + trace!("rwlock {:p} try_read acquired ReadGuard", self); + Ok(RwLockReadGuard { + sem: &self.sem, + data: self.inner.get(), + _p: PhantomData, + }) + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_read`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. 
+ pub fn try_read_owned(self: Arc) -> Result, TryLockError> { + match self.sem.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => { + if !std::thread::panicking() { + unreachable!() + } + } + } + + trace!("rwlock {:p} try_read acquired OwnedReadGuard", self); + Ok(OwnedRwLockReadGuard { + data: self.inner.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + }) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` + /// when dropped. + pub async fn write(&self) -> RwLockWriteGuard<'_, T> { + self.sem.acquire(self.max_readers).await.unwrap_or_else(|_| { + if !std::thread::panicking() { + unreachable!() + } + }); + + trace!("rwlock {:p} acquired WriteGuard", self); + RwLockWriteGuard { + permits_acquired: self.max_readers, + data: self.inner.get(), + sem: &self.sem, + _p: PhantomData, + } + } + + /// Blockingly locks this `RwLock` with exclusive write access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + pub fn blocking_write(&self) -> RwLockWriteGuard<'_, T> { + shuttle::future::block_on(self.write()) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// This method is identical to [`RwLock::write`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. 
Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + pub async fn write_owned(self: Arc) -> OwnedRwLockWriteGuard { + self.sem.acquire(self.max_readers).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + if !std::thread::panicking() { + unreachable!() + } + }); + + tracing::trace!("rwlock {:p} acquired OwnedWriteGuard", self,); + OwnedRwLockWriteGuard { + permits_acquired: self.max_readers, + data: self.inner.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + } + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + pub fn try_write(&self) -> Result, TryLockError> { + match self.sem.try_acquire(self.max_readers) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => { + if !std::thread::panicking() { + unreachable!() + } + } + } + + tracing::trace!("rwlock {:p} try_write acquired WriteGuard", self,); + Ok(RwLockWriteGuard { + permits_acquired: self.max_readers, + sem: &self.sem, + data: self.inner.get(), + _p: PhantomData, + }) + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_write`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. 
Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. + pub fn try_write_owned(self: Arc) -> Result, TryLockError> { + match self.sem.try_acquire(self.max_readers) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => { + if !std::thread::panicking() { + unreachable!() + } + } + } + + tracing::trace!("rwlock {:p} try_write acquired OwnedWriteGuard", self,); + Ok(OwnedRwLockWriteGuard { + permits_acquired: self.max_readers, + data: self.inner.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + }) + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `RwLock` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.inner.get() + } + } + + /// Consumes the lock, returning the underlying data. + pub fn into_inner(self) -> T + where + T: Sized, + { + self.inner.into_inner() + } +} + +impl From for RwLock { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl Default for RwLock +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} + +// As long as T: Send + Sync, it's fine to send and share RwLock between threads. +// If T were not Send, sending and sharing a RwLock would be bad, since you can access T through +// RwLock. +unsafe impl Send for RwLock where T: ?Sized + Send {} +unsafe impl Sync for RwLock where T: ?Sized + Send + Sync {} +// NB: These impls need to be explicit since we're storing a raw pointer. +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send`. 
+unsafe impl Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {} +unsafe impl Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} +// T is required to be `Send` because an OwnedRwLockReadGuard can be used to drop the value held in +// the RwLock, unlike RwLockReadGuard. +unsafe impl Send for OwnedRwLockReadGuard +where + T: ?Sized + Send + Sync, + U: ?Sized + Sync, +{ +} +unsafe impl Sync for OwnedRwLockReadGuard +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +unsafe impl Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for OwnedRwLockWriteGuard where T: ?Sized + Send + Sync {} + +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send` - but since this is also provides mutable access, we need to +// make sure that `T` is `Send` since its value can be sent across thread +// boundaries. +unsafe impl Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Send for OwnedRwLockWriteGuard where T: ?Sized + Send + Sync {} + +/// RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read`] method on +/// [`RwLock`]. +/// +/// [`read`]: method@crate::sync::RwLock::read +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct RwLockReadGuard<'a, T: ?Sized> { + sem: &'a BatchSemaphore, + data: *const T, + _p: PhantomData<&'a T>, +} + +impl ops::Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl fmt::Debug for RwLockReadGuard<'_, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl Drop for RwLockReadGuard<'_, T> { + fn drop(&mut self) { + self.sem.release(1); + } +} + +/// Owned RAII structure used to release the shared read access of a lock when +/// dropped. 
+/// +/// This structure is created by the [`read_owned`] method on +/// [`RwLock`]. +/// +/// [`read_owned`]: method@crate::sync::RwLock::read_owned +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct OwnedRwLockReadGuard { + // ManuallyDrop allows us to destructure into this field without running the destructor. + lock: ManuallyDrop>>, + data: *const U, + _p: PhantomData, +} + +impl ops::Deref for OwnedRwLockReadGuard { + type Target = U; + + fn deref(&self) -> &U { + unsafe { &*self.data } + } +} + +impl fmt::Debug for OwnedRwLockReadGuard +where + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl Drop for OwnedRwLockReadGuard { + fn drop(&mut self) { + self.lock.sem.release(1); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +pub struct RwLockWriteGuard<'a, T: ?Sized> { + permits_acquired: usize, + sem: &'a BatchSemaphore, + data: *mut T, + _p: PhantomData<&'a mut T>, +} + +impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + pub fn downgrade(self) -> RwLockReadGuard<'a, T> { + let RwLockWriteGuard { sem, data, .. } = self; + let to_release = self.permits_acquired - 1; + + tracing::trace!("rwlock {:p} downgrading to ReadGuard", &self); + + // NB: Forget to avoid drop impl from being called. 
+ std::mem::forget(self); + + // Release all but one of the permits held by the write guard + sem.release(to_release); + + RwLockReadGuard { + sem, + data, + _p: PhantomData, + } + } +} + +impl ops::Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl ops::DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl fmt::Debug for RwLockWriteGuard<'_, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl Drop for RwLockWriteGuard<'_, T> { + fn drop(&mut self) { + self.sem.release(self.permits_acquired); + } +} + +/// Owned RAII structure used to release the exclusive write access of a lock when +/// dropped. +pub struct OwnedRwLockWriteGuard { + permits_acquired: usize, + // ManuallyDrop allows us to destructure into this field without running the destructor. + lock: ManuallyDrop>>, + data: *mut T, + _p: PhantomData, +} + +impl OwnedRwLockWriteGuard { + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + pub fn downgrade(mut self) -> OwnedRwLockReadGuard { + let lock = unsafe { ManuallyDrop::take(&mut self.lock) }; + + let data = self.data; + let to_release = self.permits_acquired - 1; + + tracing::trace!("rwlock {:p} downgrading to OwnedReadGuard", &self); + + // NB: Forget to avoid drop impl from being called. 
+ std::mem::forget(self); + + // Release all but one of the permits held by the write guard + lock.sem.release(to_release); + + OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + } + } +} + +impl ops::Deref for OwnedRwLockWriteGuard { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl ops::DerefMut for OwnedRwLockWriteGuard { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl fmt::Debug for OwnedRwLockWriteGuard +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl Drop for OwnedRwLockWriteGuard { + fn drop(&mut self) { + self.lock.sem.release(self.permits_acquired); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/semaphore.rs b/wrappers/tokio/impls/tokio/inner/src/sync/semaphore.rs new file mode 100644 index 00000000..67f004a5 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/semaphore.rs @@ -0,0 +1,272 @@ +//! Counting semaphore performing asynchronous permit acquisition. + +use shuttle::future::batch_semaphore::{AcquireError, BatchSemaphore, Fairness, TryAcquireError}; +use std::sync::Arc; + +/// Counting semaphore performing asynchronous permit acquisition. +#[derive(Debug)] +pub struct Semaphore { + sem: BatchSemaphore, +} + +/// A permit from the semaphore. +#[must_use] +#[derive(Debug)] +pub struct SemaphorePermit<'a> { + sem: &'a Semaphore, + permits: u32, +} + +/// An owned permit from the semaphore. +#[must_use] +#[derive(Debug)] +pub struct OwnedSemaphorePermit { + sem: Arc, + permits: u32, +} + +impl Semaphore { + /// The maximum number of permits which a semaphore can hold. It is `usize::MAX >> 3`. + /// + /// Exceeding this limit typically results in a panic. + pub const MAX_PERMITS: usize = usize::MAX >> 3; + + /// Creates a new semaphore with the initial number of permits. 
+ pub fn new(num_permits: usize) -> Self { + let sem = BatchSemaphore::new(num_permits, Fairness::StrictlyFair); + Self { sem } + } + + /// Creates a new semaphore with the initial number of permits. + pub const fn const_new(num_permits: usize) -> Self { + let sem = BatchSemaphore::const_new(num_permits, Fairness::StrictlyFair); + Self { sem } + } + + /// Returns the current number of available permits. + pub fn available_permits(&self) -> usize { + self.sem.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + pub fn add_permits(&self, n: usize) { + self.sem.release(n); + } + + /// Acquires a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permit. + pub async fn acquire(&self) -> Result, AcquireError> { + self.acquire_many(1).await + } + + /// Acquires `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permits. + pub async fn acquire_many(&self, permits: u32) -> Result, AcquireError> { + self.sem.acquire(permits as usize).await?; + Ok(SemaphorePermit { sem: self, permits }) + } + + /// Tries to acquire a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, + /// this returns a [`SemaphorePermit`] representing the acquired permits. + pub fn try_acquire(&self) -> Result, TryAcquireError> { + self.try_acquire_many(1) + } + + /// Tries to acquire `n` permits from the semaphore. 
+ /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left. + /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits. + pub fn try_acquire_many(&self, permits: u32) -> Result, TryAcquireError> { + match self.sem.try_acquire(permits as usize) { + Ok(()) => Ok(SemaphorePermit { sem: self, permits }), + Err(e) => Err(e), + } + } + + /// Acquires an owned permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + pub async fn acquire_owned(self: Arc) -> Result { + self.acquire_many_owned(1).await + } + + /// Acquires `n` owned permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + pub async fn acquire_many_owned(self: Arc, permits: u32) -> Result { + self.sem.acquire(permits as usize).await?; + Ok(OwnedSemaphorePermit { sem: self, permits }) + } + + /// Tries to acquire an owned permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + pub fn try_acquire_owned(self: Arc) -> Result { + self.try_acquire_many_owned(1) + } + + /// Tries to acquire `n` owned permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. 
If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + pub fn try_acquire_many_owned(self: Arc, permits: u32) -> Result { + match self.sem.try_acquire(permits as usize) { + Ok(()) => Ok(OwnedSemaphorePermit { sem: self, permits }), + Err(e) => Err(e), + } + } + + /// Closes the semaphore. + /// + /// This prevents the semaphore from issuing new permits and notifies all pending waiters. + pub fn close(&self) { + self.sem.close(); + } + + /// Returns true if the semaphore is closed + pub fn is_closed(&self) -> bool { + self.sem.is_closed() + } +} + +impl SemaphorePermit<'_> { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } + + /// Merge two [`SemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + /// + /// # Panics + /// + /// This function panics if permits from different [`Semaphore`] instances + /// are merged. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + std::ptr::eq(self.sem, other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } + + /// Splits `n` permits from `self` and returns a new [`SemaphorePermit`] instance that holds `n` permits. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, returns `None`. 
+ pub fn split(&mut self, n: usize) -> Option { + let n = u32::try_from(n).ok()?; + + if n > self.permits { + return None; + } + + self.permits -= n; + + Some(Self { + sem: self.sem, + permits: n, + }) + } + + /// Returns the number of permits held by `self`. + pub fn num_permits(&self) -> usize { + self.permits as usize + } +} + +impl OwnedSemaphorePermit { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } + + /// Merge two [`OwnedSemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + Arc::ptr_eq(&self.sem, &other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } + + /// Splits `n` permits from `self` and returns a new [`OwnedSemaphorePermit`] instance that holds `n` permits. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, returns `None`. + /// + /// # Note + /// + /// It will clone the owned `Arc` to construct the new instance. + pub fn split(&mut self, n: usize) -> Option { + let n = u32::try_from(n).ok()?; + + if n > self.permits { + return None; + } + + self.permits -= n; + + Some(Self { + sem: self.sem.clone(), + permits: n, + }) + } + + /// Returns the [`Semaphore`] from which this permit was acquired. + pub fn semaphore(&self) -> &Arc { + &self.sem + } + + /// Returns the number of permits held by `self`. 
+ pub fn num_permits(&self) -> usize { + self.permits as usize + } +} + +impl Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/sync/watch.rs b/wrappers/tokio/impls/tokio/inner/src/sync/watch.rs new file mode 100644 index 00000000..d776e991 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/sync/watch.rs @@ -0,0 +1,1129 @@ +// +// Note: This code is mostly copied directly from the tokio sources, version 1.19.2 +// + +//! A single-producer, multi-consumer channel that only retains the *last* sent +//! value. +//! +//! This channel is useful for watching for changes to a value from multiple +//! points in the code base, for example, changes to configuration values. +//! +//! # Usage +//! +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and consumer halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to be sent by the [`Sender`] half. +//! +//! # Examples +//! +//! ```text +//! use crate::sync::watch; +//! +//! # async fn dox() -> Result<(), Box> { +//! let (tx, mut rx) = watch::channel("hello"); +//! +//! tokio::spawn(async move { +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); +//! } +//! }); +//! +//! tx.send("world")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Closing +//! +//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect +//! when all [`Receiver`] handles have been dropped. This indicates that there +//! is no further interest in the values being produced and work can be stopped. +//! +//! # Thread safety +//! +//! Both [`Sender`] and [`Receiver`] are thread safe.
They can be moved to other +//! threads and can be used in a concurrent environment. Clones of [`Receiver`] +//! handles may be moved to separate threads and also used concurrently. +//! +//! [`Sender`]: crate::sync::watch::Sender +//! [`Receiver`]: crate::sync::watch::Receiver +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow +//! [`channel`]: crate::sync::watch::channel +//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed +//! [`Sender::closed`]: crate::sync::watch::Sender::closed + +use crate::sync::notify::Notify; +use crate::sync::{RwLock, RwLockReadGuard}; +use std::mem; +use std::ops; +use std::panic; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +/// Receives values from the associated [`Sender`](struct@Sender). +/// +/// Instances are created by the [`channel`](fn@channel) function. +/// +/// To turn this receiver into a `Stream`, you can use the [`WatchStream`] +/// wrapper. +/// +/// [`WatchStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.WatchStream.html +#[derive(Debug)] +pub struct Receiver { + /// Pointer to the shared state + shared: Arc>, + + /// Last observed version + version: Version, +} + +/// Sends values to the associated [`Receiver`](struct@Receiver). +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Sender { + shared: Arc>, +} + +/// Returns a reference to the inner value. +/// +/// Outstanding borrows hold a read lock on the inner value. This means that +/// long lived borrows could cause the produce half to block. It is recommended +/// to keep the borrow as short lived as possible. +/// +/// The priority policy of the lock is dependent on the underlying lock +/// implementation, and this type does not guarantee that any particular policy +/// will be used. 
In particular, a producer which is waiting to acquire the lock +/// in `send` might or might not block concurrent calls to `borrow`, e.g.: +/// +///
Potential deadlock example +/// +/// ```text +/// // Task 1 (on thread A) | // Task 2 (on thread B) +/// let _ref1 = rx.borrow(); | +/// | // will block +/// | let _ = tx.send(()); +/// // may deadlock | +/// let _ref2 = rx.borrow(); | +/// ``` +///
+#[derive(Debug)] +pub struct Ref<'a, T> { + inner: RwLockReadGuard<'a, T>, +} + +#[derive(Debug)] +struct Shared { + /// The most recent value. + value: RwLock, + + /// The current version. + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + state: AtomicState, + + /// Tracks the number of `Receiver` instances. + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events. + notify_tx: Notify, +} + +pub mod error { + //! Watch error types. + + use std::fmt; + + /// Error produced when sending a value fails. + #[derive(Debug)] + pub struct SendError(pub T); + + // ===== impl SendError ===== + + impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for SendError {} + + /// Error produced when receiving a change notification. + #[derive(Debug, Clone)] + pub struct RecvError(pub(super) ()); + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} +} + +use self::state::{AtomicState, Version}; +mod state { + use std::sync::atomic::{AtomicUsize, Ordering}; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. 
+ #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(Ordering::SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, Ordering::SeqCst); + } + + /// Set the closed bit in the state. + pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, Ordering::SeqCst); + } + } +} + +/// Creates a new watch channel, returning the "send" and "receive" handles. +/// +/// All values sent by [`Sender`] will become visible to the [`Receiver`] handles. +/// Only the last value sent is made available to the [`Receiver`] half. All +/// intermediate values are dropped. 
+/// +/// # Examples +/// +/// ```text +/// use crate::sync::watch; +/// +/// # async fn dox() -> Result<(), Box> { +/// let (tx, mut rx) = watch::channel("hello"); +/// +/// tokio::spawn(async move { +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); +/// } +/// }); +/// +/// tx.send("world")?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [`Sender`]: struct@Sender +/// [`Receiver`]: struct@Receiver +pub fn channel(init: T) -> (Sender, Receiver) { + let shared = Arc::new(Shared { + value: RwLock::new(init), + state: AtomicState::new(), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), + }); + + let tx = Sender { shared: shared.clone() }; + + let rx = Receiver { + shared, + version: Version::initial(), + }; + + (tx, rx) +} + +impl Receiver { + fn from_shared(version: Version, shared: Arc>) -> Self { + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Ordering::SeqCst); + + Self { shared, version } + } + + /// Returns a reference to the most recently sent value. + /// + /// This method does not mark the returned value as seen, so future calls to + /// [`changed`] may return immediately even if you have already seen the + /// value with a call to `borrow`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + ///
Potential deadlock example + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx.borrow(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx.borrow(); | + /// ``` + ///
+ /// + /// [`changed`]: Receiver::changed + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// let (_, rx) = watch::channel("hello"); + /// assert_eq!(*rx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.blocking_read(); + Ref { inner } + } + + /// Returns a reference to the most recently sent value and mark that value + /// as seen. + /// + /// This method marks the value as seen, so [`changed`] will not return + /// immediately if the newest value is one previously returned by + /// `borrow_and_update`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + ///
Potential deadlock example + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx1.borrow_and_update(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx2.borrow_and_update(); | + /// ``` + ///
+ /// + /// [`changed`]: Receiver::changed + pub fn borrow_and_update(&mut self) -> Ref<'_, T> { + let inner = self.shared.value.blocking_read(); + self.version = self.shared.state.load().version(); + Ref { inner } + } + + /// Checks if this channel contains a message that this receiver has not yet + /// seen. The new value is not marked as seen. + /// + /// Although this method is called `has_changed`, it does not check new + /// messages for equality, so this call will return true even if the new + /// message is equal to the old message. + /// + /// Returns an error if the channel has been closed. + /// # Examples + /// + /// ```text + /// use crate::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// assert!(rx.has_changed().unwrap()); + /// assert_eq!(*rx.borrow_and_update(), "goodbye"); + /// + /// // The value has been marked as seen + /// assert!(!rx.has_changed().unwrap()); + /// + /// drop(tx); + /// // The `tx` handle has been dropped + /// assert!(rx.has_changed().is_err()); + /// } + /// ``` + pub fn has_changed(&self) -> Result { + // Load the version from the state + let state = self.shared.state.load(); + if state.is_closed() { + // The sender has dropped. + return Err(error::RecvError(())); + } + let new_version = state.version(); + + Ok(self.version != new_version) + } + + /// Waits for a change notification, then marks the newest value as seen. + /// + /// If the newest value in the channel has not yet been marked seen when + /// this method is called, the method marks that value seen and returns + /// immediately. If the newest value has already been marked seen, then the + /// method sleeps until a new message is sent by the [`Sender`] connected to + /// this `Receiver`, or until the [`Sender`] is dropped. + /// + /// This method returns an error if and only if the [`Sender`] is dropped. 
+ /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// `tokio::select!` statement and some other branch + /// completes first, then it is guaranteed that no values have been marked + /// seen by this call to `changed`. + /// + /// [`Sender`]: struct@Sender + /// + /// # Examples + /// + /// ```text + /// use crate::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// tx.send("goodbye").unwrap(); + /// }); + /// + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); + /// + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); + /// } + /// ``` + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + changed_impl(&self.shared, &mut self.version).await + } + + /// Waits for a value that satisfies the provided condition. + /// + /// This method will call the provided closure whenever something is sent on + /// the channel. Once the closure returns `true`, this method will return a + /// reference to the value that was passed to the closure. + /// + /// Before `wait_for` starts waiting for changes, it will call the closure + /// on the current value. If the closure returns `true` when given the + /// current value, then `wait_for` will immediately return a reference to + /// the current value. This is the case even if the current value is already + /// considered seen. + /// + /// The watch channel only keeps track of the most recent value, so if + /// several messages are sent faster than `wait_for` is able to call the + /// closure, then it may skip some updates. Whenever the closure is called, + /// it will be called with the most recent value. + /// + /// When this function returns, the value that was passed to the closure + /// when it returned `true` will be considered seen. 
+    ///
+    /// If the channel is closed, then `wait_for` will return a `RecvError`.
+    /// Once this happens, no more messages can ever be sent on the channel.
+    /// When an error is returned, it is guaranteed that the closure has been
+    /// called on the last value, and that it returned `false` for that value.
+    /// (If the closure returned `true`, then the last value would have been
+    /// returned instead of the error.)
+    ///
+    /// Like the `borrow` method, the returned borrow holds a read lock on the
+    /// inner value. This means that long-lived borrows could cause the producer
+    /// half to block. It is recommended to keep the borrow as short-lived as
+    /// possible. See the documentation of `borrow` for more information on
+    /// this.
+    ///
+    /// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
+    ///
+    /// # Examples
+    ///
+    /// ```ignore
+    /// use tokio::sync::watch;
+    ///
+    /// #[tokio::main]
+    ///
+    /// async fn main() {
+    ///     let (tx, _rx) = watch::channel("hello");
+    ///
+    ///     tx.send("goodbye").unwrap();
+    ///
+    ///     // here we subscribe to a second receiver
+    ///     // now in case of using `changed` we would have
+    ///     // to first check the current value and then wait
+    ///     // for changes or else `changed` would hang.
+    ///     let mut rx2 = tx.subscribe();
+    ///
+    ///     // in place of `changed` we use `wait_for`,
+    ///     // which would automatically check the current value
+    ///     // and wait for changes until the closure returns true.
+ /// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok()); + /// assert_eq!(*rx2.borrow(), "goodbye"); + /// } + /// ``` + pub async fn wait_for(&mut self, mut f: impl FnMut(&T) -> bool) -> Result, error::RecvError> { + let mut closed = false; + loop { + { + let inner = self.shared.value.blocking_read(); + + let new_version = self.shared.state.load().version(); + let has_changed = self.version != new_version; + self.version = new_version; + + if !closed || has_changed { + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&inner))); + match result { + Ok(true) => { + return Ok(Ref { inner }); + } + Ok(false) => { + // Skip the value. + } + Err(panicked) => { + // Drop the read-lock to avoid poisoning it. + drop(inner); + // Forward the panic to the caller. + panic::resume_unwind(panicked); + // Unreachable + } + } + } + } + + if closed { + return Err(error::RecvError(())); + } + + // Wait for the value to change. + closed = changed_impl(&self.shared, &mut self.version).await.is_err(); + } + } + + /// Returns `true` if receivers belong to the same channel. + /// + /// # Examples + /// + /// ```text + /// let (tx, rx) = crate::sync::watch::channel(true); + /// let rx2 = rx.clone(); + /// assert!(rx.same_channel(&rx2)); + /// + /// let (tx3, rx3) = crate::sync::watch::channel(true); + /// assert!(!rx3.same_channel(&rx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } +} + +fn maybe_changed(shared: &Shared, version: &mut Version) -> Option> { + // Load the version from the state + let state = shared.state.load(); + let new_version = state.version(); + + tracing::trace!("maybe_changed with version {:?} and {:?}", *version, new_version); + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } + + if state.is_closed() { + // All receivers have dropped. 
+ return Some(Err(error::RecvError(()))); + } + + None +} + +async fn changed_impl(shared: &Shared, version: &mut Version) -> Result<(), error::RecvError> { + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(shared, version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious + } +} + +impl Clone for Receiver { + fn clone(&self) -> Self { + let version = self.version; + let shared = self.shared.clone(); + + Self::from_shared(version, shared) + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + // No synchronization necessary as this is only used as a counter and + // not memory access. + tracing::trace!( + "dropping {:p} with count {:?}", + self, + self.shared.ref_count_rx.load(Ordering::SeqCst) + ); + if 1 == self.shared.ref_count_rx.fetch_sub(1, Ordering::SeqCst) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } + } +} + +impl Sender { + /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel has been closed, which happens when + /// every receiver has been dropped. + pub fn send(&self, value: T) -> Result<(), error::SendError> { + tracing::trace!( + "watch {:p} send value with receiver count {:?}", + &self.shared, + self.receiver_count() + ); + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.receiver_count() { + return Err(error::SendError(value)); + } + + self.send_replace(value); + Ok(()) + } + + /// Modifies the watched value **unconditionally** in-place, + /// notifying all receivers. + /// + /// This can useful for modifying the watched value, without + /// having to allocate a new instance. 
Additionally, this
+    /// method permits sending values even when there are no receivers.
+    ///
+    /// Prefer to use the more versatile function [`Self::send_if_modified()`]
+    /// if the value is only modified conditionally during the mutable borrow
+    /// to prevent unneeded change notifications for unmodified values.
+    ///
+    /// # Panics
+    ///
+    /// This function panics when the invocation of the `modify` closure panics.
+    /// No receivers are notified when panicking. All changes of the watched
+    /// value applied by the closure before panicking will be visible in
+    /// subsequent calls to `borrow`.
+    ///
+    /// # Examples
+    ///
+    /// ```text
+    /// use crate::sync::watch;
+    ///
+    /// struct State {
+    ///     counter: usize,
+    /// }
+    /// let (state_tx, state_rx) = watch::channel(State { counter: 0 });
+    /// state_tx.send_modify(|state| state.counter += 1);
+    /// assert_eq!(state_rx.borrow().counter, 1);
+    /// ```
+    pub fn send_modify<F>(&self, modify: F)
+    where
+        F: FnOnce(&mut T),
+    {
+        self.send_if_modified(|value| {
+            modify(value);
+            true
+        });
+    }
+
+    /// Modifies the watched value **conditionally** in-place,
+    /// notifying all receivers only if modified.
+    ///
+    /// This can be useful for modifying the watched value, without
+    /// having to allocate a new instance. Additionally, this
+    /// method permits sending values even when there are no receivers.
+    ///
+    /// The `modify` closure must return `true` if the value has actually
+    /// been modified during the mutable borrow. It should only return `false`
+    /// if the value is guaranteed to be unmodified despite the mutable
+    /// borrow.
+    ///
+    /// Receivers are only notified if the closure returned `true`. If the
+    /// closure has modified the value but returned `false` this results
+    /// in a *silent modification*, i.e. the modified value will be visible
+    /// in subsequent calls to `borrow`, but receivers will not receive
+    /// a change notification.
+    ///
+    /// Returns the result of the closure, i.e.
`true` if the value has + /// been modified and `false` otherwise. + /// + /// # Panics + /// + /// This function panics when the invocation of the `modify` closure panics. + /// No receivers are notified when panicking. All changes of the watched + /// value applied by the closure before panicking will be visible in + /// subsequent calls to `borrow`. + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, mut state_rx) = watch::channel(State { counter: 1 }); + /// let inc_counter_if_odd = |state: &mut State| { + /// if state.counter % 2 == 1 { + /// state.counter += 1; + /// return true; + /// } + /// false + /// }; + /// + /// assert_eq!(state_rx.borrow().counter, 1); + /// + /// assert!(!state_rx.has_changed().unwrap()); + /// assert!(state_tx.send_if_modified(inc_counter_if_odd)); + /// assert!(state_rx.has_changed().unwrap()); + /// assert_eq!(state_rx.borrow_and_update().counter, 2); + /// + /// assert!(!state_rx.has_changed().unwrap()); + /// assert!(!state_tx.send_if_modified(inc_counter_if_odd)); + /// assert!(!state_rx.has_changed().unwrap()); + /// assert_eq!(state_rx.borrow_and_update().counter, 2); + /// ``` + pub fn send_if_modified(&self, modify: F) -> bool + where + F: FnOnce(&mut T) -> bool, + { + { + // Acquire the write lock and update the value. + let mut lock = self.shared.value.blocking_write(); + + // Update the value and catch possible panic inside func. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| modify(&mut lock))); + match result { + Ok(modified) => { + if !modified { + // Abort, i.e. don't notify receivers if unmodified + return false; + } + // Continue if modified + } + Err(panicked) => { + // Drop the lock to avoid poisoning it. + drop(lock); + // Forward the panic to the caller. + panic::resume_unwind(panicked); + // Unreachable + } + } + + self.shared.state.increment_version(); + + // Release the write lock. 
+ // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + } + self.shared.notify_rx.notify_waiters(); + true + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| mem::swap(old, &mut value)); + + value + } + + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// let (tx, _) = watch::channel("hello"); + /// assert_eq!(*tx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.blocking_read(); + Ref { inner } + } + + /// Checks if the channel has been closed. This happens when all receivers + /// have dropped. + /// + /// # Examples + /// + /// ```text + /// let (tx, rx) = crate::sync::watch::channel(()); + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.receiver_count() == 0 + } + + /// Completes when all receivers have dropped. 
+ /// + /// This allows the producer to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); + + if self.receiver_count() == 0 { + return; + } + + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } + } + + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ```text + /// use crate::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. 
+ /// + /// ```text + /// use crate::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver { + let shared = self.shared.clone(); + let version = shared.state.load().version(); + + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) + } + + /// Returns the number of receivers that currently exist. + /// + /// # Examples + /// + /// ```text + /// use crate::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx1) = watch::channel("hello"); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = rx1.clone(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + self.shared.ref_count_rx.load(Ordering::SeqCst) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.shared.state.set_closed(); + self.shared.notify_rx.notify_waiters(); + } +} + +// ===== impl Ref ===== + +impl ops::Deref for Ref<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.inner + } +} + +#[cfg(test)] +mod tests { + use crate::sync::watch; + use futures::FutureExt; + use shuttle::future; + use shuttle::thread; + + // test for https://github.com/tokio-rs/tokio/issues/3168 + #[test] + fn watch_spurious_wakeup() { + shuttle::check_random( + || { + let (send, mut recv) = watch::channel(0i32); + + 
send.send(1).unwrap(); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let mut recv = recv_thread.join().unwrap(); + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + }); + + recv.changed().now_or_never(); + + send_thread.join().unwrap(); + }, + 500_000, + ); + } + + #[test] + fn watch_borrow() { + shuttle::check_random( + || { + let (send, mut recv) = watch::channel(0i32); + + assert!(send.borrow().eq(&0)); + assert!(recv.borrow().eq(&0)); + + send.send(1).unwrap(); + assert!(send.borrow().eq(&1)); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let recv = recv_thread.join().unwrap(); + assert!(recv.borrow().eq(&3)); + assert!(send.borrow().eq(&3)); + + send.send(2).unwrap(); + + thread::spawn(move || { + assert!(recv.borrow().eq(&2)); + }); + assert!(send.borrow().eq(&2)); + }, + 500_000, + ); + } + + /// This test exposes an issue when [`shuttle::sync::RwLock`] is used in the [`watch`] implementation. + /// + /// The issue is that the concurrent evaluation of select guards can break internal invariants of `RwLock`. + /// Concretely, the read in `borrow_and_update` puts the caller on the `RwLock` waiter list. But then + /// the other select arm can wake up the caller while the watch send holds the write lock. This breaks + /// the expectation in `RwLock` that waiters are only woken up when they are granted the lock, and + /// results in the following panic. 
+ /// + /// ```text + /// resumed a waiting Read thread while the lock was in state Write(main-thread(2)) + /// ``` + /// + /// See commit `751a0433`, which switched from [`shuttle::sync::RwLock`] to [`crate::sync::RwLock`]. + #[test] + fn watch_with_select() { + use crate::sync::mpsc; + + shuttle::check_random( + || { + future::block_on(async move { + let (watch_tx, mut watch_rx) = watch::channel(()); + let (mpsc_tx, mut mpsc_rx) = mpsc::unbounded_channel(); + + let h1 = future::spawn(async move { + tokio::select! { + biased; + _ = mpsc_rx.recv() => {} + _ = async { watch_rx.borrow_and_update() } => {} + } + }); + + let h2 = future::spawn(async move { + let _ = watch_tx.send(()); + }); + + let h3 = future::spawn(async move { + let _ = mpsc_tx.send(()); + }); + + futures::future::join_all([h1, h2, h3]).await; + }); + }, + 10_000, + ); + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/task.rs b/wrappers/tokio/impls/tokio/inner/src/task.rs new file mode 100644 index 00000000..66ef97b1 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/task.rs @@ -0,0 +1,605 @@ +use crate::runtime::Handle; +use pin_project::pin_project; +use shuttle::current::remove_label_for_task; +use shuttle::current::{me, set_label_for_task, ChildLabelFn, TaskName}; +use shuttle::future::spawn_local; +use shuttle::scheduler::TaskId; +use std::any::Any; +use std::error::Error; +use std::fmt::{Display, Formatter}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub use shuttle::current::TaskId as Id; +pub use shuttle::future::yield_now; + +#[doc(hidden)] +#[deprecated = "Moved to shuttle_tokio_impl_inner::task::coop::consume_budget"] +pub use coop::consume_budget; +#[doc(hidden)] +#[deprecated = "Moved to shuttle_tokio_impl_inner::task::coop::unconstrained"] +pub use coop::unconstrained; +#[doc(hidden)] +#[deprecated = "Moved to shuttle_tokio_impl_inner::task::coop::Unconstrained"] +pub use coop::Unconstrained; + +// 
TODO: Implement. Only exists in order to get compilation to pass, should not actually be used. +pub mod futures { + pub use tokio::task::futures::TaskLocalFuture; +} + +/// Returns the [`Id`] of the currently running task. +pub fn id() -> Id { + shuttle::current::get_current_task().unwrap() +} + +/// Returns the [`Id`] of the currently running task, or `None` if called outside +/// of a task. +pub fn try_id() -> Option { + shuttle::current::get_current_task() +} + +/// Spawns a future onto the runtime +pub fn spawn(future: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + let rt = crate::runtime::Handle::current(); + rt.spawn(future) +} + +// TODO: See comment in `runtime::Handle::spawn_blocking` +pub fn spawn_blocking(func: F) -> JoinHandle +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let rt = crate::runtime::Handle::current(); + rt.spawn_blocking(func) +} + +/// A wrapper around the `JoinHandle` found in Shuttle in order to implement the full `JoinError` API. +#[derive(Debug)] +#[pin_project] +pub struct JoinHandle { + #[pin] + inner: shuttle::future::JoinHandle, +} + +impl JoinHandle { + pub(crate) fn new(inner: shuttle::future::JoinHandle) -> Self { + Self { inner } + } + + pub fn abort(&self) { + self.inner.abort(); + } + + pub fn is_finished(&self) -> bool { + self.inner.is_finished() + } + + pub fn abort_handle(&self) -> AbortHandle { + AbortHandle { + inner: self.inner.abort_handle(), + } + } +} + +/// Task failed to execute to completion. +#[derive(Debug)] +pub struct JoinError { + repr: Repr, +} + +impl Display for JoinError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.repr.fmt(f) + } +} + +impl Error for JoinError {} + +// NOTE: Remember to reimplement `Drop` here if this is ever fully moved out of Shuttle. 
+ +impl Future for JoinHandle { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.inner.poll(cx).map_err(std::convert::Into::into) + } +} + +impl From for JoinError { + fn from(error: shuttle::future::JoinError) -> Self { + match error { + shuttle::future::JoinError::Cancelled => Self { repr: Repr::Cancelled }, + } + } +} + +impl JoinError { + /// Returns true if the error was caused by the task being cancelled. + pub fn is_cancelled(&self) -> bool { + matches!(self.repr, Repr::Cancelled) + } + + /// Returns true if the error was caused by the task panicking. + pub fn is_panic(&self) -> bool { + matches!(self.repr, Repr::Panic) + } + + pub fn into_panic(self) -> Box { + unimplemented!() + } + + pub fn try_into_panic(self) -> Result, JoinError> { + unimplemented!() + } + + pub fn id(&self) -> Id { + unimplemented!() + } +} + +#[allow(unused)] +#[derive(Debug)] +enum Repr { + /// Task was aborted + Cancelled, + /// Task panicked + Panic, +} + +impl Display for Repr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Repr::Cancelled => write!(f, "task was cancelled"), + Repr::Panic => write!(f, "task panicked"), + } + } +} + +pub mod coop { + use pin_project::pin_project; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + /// Future for the [`unconstrained`](unconstrained) method. 
+ #[must_use = "Unconstrained does nothing unless polled"]
+ #[pin_project]
+ pub struct Unconstrained {
+ #[pin]
+ inner: F,
+ }
+
+ impl Future for Unconstrained
+ where
+ F: Future,
+ {
+ type Output = ::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ let inner = self.project().inner;
+ inner.poll(cx)
+ }
+ }
+
+ /// Under Shuttle this is always a thin wrapper around the `inner` which just
+ /// forwards the `poll`.
+ pub fn unconstrained(inner: F) -> Unconstrained {
+ Unconstrained { inner }
+ }
+
+ /// Under Shuttle this is always just a `yield_now`.
+ pub async fn consume_budget() {
+ shuttle::future::yield_now().await;
+ }
+
+ /// Not faithfully modelled; always returns `true`.
+ /// Always returning `false` would be equally correct. The reason for always returning
+ /// `true` is that an implementation which follows the "do work until done or budget exhausted"
+ /// becomes "do work until done" with `true` and "do work once" with `false`. Both versions would
+ /// require additional implementation to re-add the "budget exhausted" logic, but the `false` case
+ /// would also require an override of `has_budget_remaining` to re-add the "do work until done" part. 
+ pub fn has_budget_remaining() -> bool { + true + } +} + +#[derive(Debug, Clone)] +pub struct AbortHandle { + inner: shuttle::future::AbortHandle, +} + +impl Drop for AbortHandle { + fn drop(&mut self) {} +} + +impl AbortHandle { + pub fn abort(&self) { + self.inner.abort(); + } + + pub fn is_finished(&self) -> bool { + self.inner.is_finished() + } +} + +pub use join_set::JoinSet; + +pub mod join_set { + use crate::task::{AbortHandle, Handle, JoinError, JoinHandle}; + use ::futures::stream::{FuturesUnordered, StreamExt}; + use std::fmt; + use std::future::Future; + + pub struct JoinSet { + inner: FuturesUnordered>, + } + + #[must_use = "builders do nothing unless used to spawn a task"] + pub struct Builder<'a, T> { + joinset: &'a mut JoinSet, + builder: super::Builder<'a>, + } + + impl JoinSet { + pub fn build_task(&mut self) -> Builder<'_, T> { + Builder { + builder: super::Builder::new(), + joinset: self, + } + } + + /// Create a new `JoinSet`. + pub fn new() -> Self { + Self { + inner: FuturesUnordered::new(), + } + } + + /// Returns the number of tasks currently in the `JoinSet`. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns whether the `JoinSet` is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + } + + impl fmt::Debug for JoinSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinSet").field("len", &self.len()).finish() + } + } + + impl Drop for JoinSet { + fn drop(&mut self) { + for jh in &self.inner { + jh.abort(); + } + } + } + + impl Default for JoinSet { + fn default() -> Self { + Self::new() + } + } + + impl JoinSet { + /// Spawn the provided task on the `JoinSet`, returning an [`AbortHandle`] + /// that can be used to remotely cancel the task. + /// + /// The provided future will start running in the background immediately + /// when this method is called, even if you don't await anything on this + /// `JoinSet`. 
+ /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`AbortHandle`]: crate::task::AbortHandle + #[track_caller] + pub fn spawn(&mut self, task: F) -> AbortHandle + where + F: Future, + F: Send + 'static, + T: Send, + { + self.insert(crate::spawn(task)) + } + + #[track_caller] + pub fn spawn_on(&mut self, task: F, handle: &Handle) -> AbortHandle + where + F: Future, + F: Send + 'static, + T: Send, + { + self.insert(handle.spawn(task)) + } + + #[track_caller] + pub fn spawn_local(&mut self, task: F) -> AbortHandle + where + F: Future, + F: 'static, + { + self.insert(JoinHandle::new(crate::task::spawn_local(task))) + } + + fn insert(&mut self, jh: JoinHandle) -> AbortHandle { + let abort = jh.abort_handle(); + self.inner.push(jh); + + abort + } + + pub async fn join_next(&mut self) -> Option> { + self.inner.next().await + } + } + + impl std::iter::FromIterator for JoinSet + where + F: Future, + F: Send + 'static, + T: Send + 'static, + { + fn from_iter>(iter: I) -> Self { + let mut set = Self::new(); + iter.into_iter().for_each(|task| { + set.spawn(task); + }); + set + } + } + + impl<'a, T: 'static> Builder<'a, T> { + /// Assigns a name to the task which will be spawned. + pub fn name(self, name: &'a str) -> Self { + let builder = self.builder.name(name); + Self { builder, ..self } + } + + /// Spawn the provided task with this builder's settings and store it in the + /// [`JoinSet`], returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. + /// + /// # Returns + /// + /// An [`AbortHandle`] that can be used to remotely cancel the task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. 
+ /// + /// [`AbortHandle`]: crate::task::AbortHandle + #[track_caller] + pub fn spawn(self, future: F) -> std::io::Result + where + F: Future, + F: Send + 'static, + T: Send, + { + Ok(self.joinset.insert(self.builder.spawn(future)?)) + } + + /// Spawn the provided task on the provided [runtime handle] with this + /// builder's settings, and store it in the [`JoinSet`]. + /// + /// # Returns + /// + /// An [`AbortHandle`] that can be used to remotely cancel the task. + /// + /// + /// [`AbortHandle`]: crate::task::AbortHandle + /// [runtime handle]: crate::runtime::Handle + #[track_caller] + pub fn spawn_on(self, future: F, handle: &Handle) -> std::io::Result + where + F: Future, + F: Send + 'static, + T: Send, + { + Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?)) + } + + /// Spawn the blocking code on the blocking threadpool with this builder's + /// settings, and store it in the [`JoinSet`]. + /// + /// # Returns + /// + /// An [`AbortHandle`] that can be used to remotely cancel the task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`JoinSet`]: crate::task::JoinSet + /// [`AbortHandle`]: crate::task::AbortHandle + #[track_caller] + pub fn spawn_blocking(self, f: F) -> std::io::Result + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send, + { + Ok(self.joinset.insert(self.builder.spawn_blocking(f)?)) + } + + /// Spawn the blocking code on the blocking threadpool of the provided + /// runtime handle with this builder's settings, and store it in the + /// [`JoinSet`]. + /// + /// # Returns + /// + /// An [`AbortHandle`] that can be used to remotely cancel the task. 
+ /// + /// [`JoinSet`]: crate::task::JoinSet + /// [`AbortHandle`]: crate::task::AbortHandle + #[track_caller] + pub fn spawn_blocking_on(self, f: F, handle: &Handle) -> std::io::Result + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send, + { + Ok(self.joinset.insert(self.builder.spawn_blocking_on(f, handle)?)) + } + } + + // Manual `Debug` impl so that `Builder` is `Debug` regardless of whether `T` is + // `Debug`. + impl<'a, T> fmt::Debug for Builder<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("join_set::Builder") + .field("joinset", &self.joinset) + .field("builder", &self.builder) + .finish() + } + } +} + +// `Builder` is unstable in Tokio +#[derive(Default, Debug)] +pub struct Builder<'a> { + name: Option<&'a str>, +} + +/// Clears the `ChildLabelFn` label for the `task_id` on drop. Used so that we don't keep naming subsequent spawns as well. +struct RestoreChildLabelFnOnDrop { + task_id: TaskId, + child_fn: Option, +} + +impl Drop for RestoreChildLabelFnOnDrop { + fn drop(&mut self) { + if let Some(func) = &self.child_fn { + set_label_for_task(self.task_id, func.clone()); + } else { + remove_label_for_task::(self.task_id); + } + } +} + +impl<'a> Builder<'a> { + /// Creates a new task builder. + pub fn new() -> Self { + Self::default() + } + + /// Assigns a name to the task which will be spawned. + pub fn name(&self, name: &'a str) -> Self { + Self { name: Some(name) } + } + + /// Makes the next spawned tasks become named if there is no current `ChildLabelFnOnDrop`. + /// If the next task will be named, a `Some(RestoreChildLabelFnOnDrop)` will be returned such that only the next spawn is named. 
+ #[must_use] + fn handle_naming(&self) -> Option { + let me = me(); + + if let Some(name) = self.name { + let old_fn = remove_label_for_task::(me); + let old_fn_cloned = old_fn.clone(); + let name = TaskName::from(name); + set_label_for_task( + me, + #[allow(clippy::arc_with_non_send_sync)] + ChildLabelFn(Arc::new(move |task_id, labels| { + // If there is a `ChildLabelFn` set, then execute that. + if let Some(func) = &old_fn_cloned { + func.0(task_id, labels); + } + + // Update the name + labels.insert(name.clone()); + })), + ); + + // If we have set the `ChildLabelFn`, then we should restore it after. If not then we would name any subsequent spawns as well. + Some(RestoreChildLabelFnOnDrop { + task_id: me, + child_fn: old_fn, + }) + } else { + None + } + } + + #[track_caller] + pub fn spawn(self, future: Fut) -> io::Result> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let handle = Handle::current(); + self.spawn_on(future, &handle) + } + + /// Spawn a task with this builder's settings on the provided [runtime + /// handle]. + /// + /// See [`Handle::spawn`] for more details. + /// + /// [runtime handle]: crate::runtime::Handle + /// [`Handle::spawn`]: crate::runtime::Handle::spawn + #[track_caller] + pub fn spawn_on(self, future: Fut, handle: &Handle) -> io::Result> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let _drop_guard = self.handle_naming(); + Ok(handle.spawn(future)) + } + + /// Spawns blocking code on the blocking threadpool. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// See [`task::spawn_blocking`](crate::task::spawn_blocking) + /// for more details. 
+ #[track_caller] + pub fn spawn_blocking(self, function: Function) -> io::Result> + where + Function: FnOnce() -> Output + Send + 'static, + Output: Send + 'static, + { + let handle = Handle::current(); + self.spawn_blocking_on(function, &handle) + } + + /// Spawns blocking code on the provided [runtime handle]'s blocking threadpool. + /// + /// See [`Handle::spawn_blocking`] for more details. + /// + /// [runtime handle]: crate::runtime::Handle + /// [`Handle::spawn_blocking`]: crate::runtime::Handle::spawn_blocking + #[track_caller] + pub fn spawn_blocking_on( + self, + function: Function, + handle: &Handle, + ) -> io::Result> + where + Function: FnOnce() -> Output + Send + 'static, + Output: Send + 'static, + { + let _drop_guard = self.handle_naming(); + Ok(handle.spawn_blocking(function)) + } +} diff --git a/wrappers/tokio/impls/tokio/inner/src/time.rs b/wrappers/tokio/impls/tokio/inner/src/time.rs new file mode 100644 index 00000000..04b2b1f9 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/src/time.rs @@ -0,0 +1,413 @@ +//! Shuttle stubs for tokio time utilities + +use pin_project::{pin_project, pinned_drop}; +use shuttle::current::{get_current_task, with_labels_for_task, Labels, TaskId}; +use std::collections::BTreeMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll, Waker}; +pub use std::time::Duration; +pub use tokio::time::Instant; +use tracing::trace; + +pub fn pause() {} + +// We don't attempt to model time directly. Instead, we provide hooks for +// test frameworks to selectively force timeouts. This is achieved by +// remembering the tag (if any) of the task that requested a timeout. + +// The following global table records (1) the set of all currently active +// timeouts, as well as (2) the set of all registered expiry triggers. + +thread_local! 
{ + static TIMEOUT_TABLE: Map = Map::default(); +} + +#[derive(Default)] +struct Map { + inner: Mutex, +} + +#[derive(Default)] +struct InnerMap { + // Timeouts that are currently active. Each timeout is assigned a unique slot, which + // can be used to look up the timeout in the BTreeMap. + entries: BTreeMap, + // The next available slot + next_slot: usize, + // The set of triggers that have already expired + #[allow(clippy::type_complexity)] + triggers: Vec bool>>, +} + +#[derive(Debug)] +struct TimeoutEntry { + task_id: TaskId, + state: TimeoutState, +} + +#[derive(Debug)] +enum TimeoutState { + Waiting(Option), + Expired, +} + +// SAFETY: We are running single-threaded, and all accesses to the InnerMap are guarded +// by a std::sync::Mutex. +unsafe impl Send for Map {} +unsafe impl Sync for Map {} + +/// Expire all current and future timeouts requested by tasks whose tags match the +/// given predicate. +pub fn trigger_timeouts(trigger: F) +where + F: Fn(&Labels) -> bool + 'static, +{ + let wakers = TIMEOUT_TABLE.with(|table| { + let mut map = table.inner.lock().unwrap(); + let mut wakers = vec![]; + for TimeoutEntry { task_id, state } in map.entries.values_mut() { + let task_name = format!("{task_id:?}"); + with_labels_for_task(*task_id, |labels| { + if trigger(labels) { + trace!("triggering timeout for task {}", task_name); + match state { + TimeoutState::Expired => {} + TimeoutState::Waiting(waker) => { + if let Some(waker) = waker { + wakers.push(waker.clone()); + } + } + } + *state = TimeoutState::Expired; + } + }); + } + map.triggers.push(Box::new(trigger)); + drop(map); + wakers + }); + for waker in wakers { + waker.wake(); + } +} + +/// Clear all timeout triggers +pub fn clear_triggers() { + TIMEOUT_TABLE.with(|table| { + table.inner.lock().unwrap().triggers.clear(); + }); +} + +// The Future returned by a call to `timeout(future)`. 
+#[pin_project(PinnedDrop)] +pub struct Timeout { + slot: usize, + #[pin] + future: F, +} + +impl Future for Timeout +where + F: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // We want expiry triggers to take effect as soon as possible. So we check + // if a timeout has been expired before polling the underlying future. + // If the future returns `Pending`, we poll it again in case we + // invoked the trigger while the future was being polled. + let slot = self.slot; + let timed_out = TIMEOUT_TABLE.with(|table| { + let map = table.inner.lock().unwrap(); + let entry = map.entries.get(&slot).unwrap(); + matches!(entry.state, TimeoutState::Expired) + }); + if timed_out { + return Poll::Ready(Err(error::Elapsed::new())); + } + let this = self.project(); + match this.future.poll(cx) { + Poll::Ready(r) => Poll::Ready(Ok(r)), + Poll::Pending => TIMEOUT_TABLE.with(|table| { + let mut map = table.inner.lock().unwrap(); + let TimeoutEntry { state, .. } = map.entries.get_mut(&slot).unwrap(); + match state { + TimeoutState::Expired => Poll::Ready(Err(error::Elapsed::new())), + TimeoutState::Waiting(waker) => { + *waker = Some(cx.waker().clone()); + Poll::Pending + } + } + }), + } + } +} + +#[pinned_drop] +impl PinnedDrop for Timeout { + fn drop(self: Pin<&mut Self>) { + // On drop, remove the associated entry from the `TIMEOUT_TABLE` + TIMEOUT_TABLE.with(|table| { + let mut map = table.inner.lock().unwrap(); + if let Some(e) = map.entries.remove(&self.slot) { + trace!("removed entry {:?} at slot {}", e, self.slot); + } + }); + } +} + +fn timeout_inner(future: F) -> Timeout +where + F: Future, +{ + let slot = { + let task_id = get_current_task().expect("TaskId should be defined"); + // We save the debug name of the task up-front, to avoid a double-borrow of the + // underlying Label storage in Shuttle (needed for `with_labels_for_task`). 
+ let task_name = format!("{task_id:?}"); + with_labels_for_task(task_id, |labels| { + TIMEOUT_TABLE.with(|table| { + let mut map = table.inner.lock().unwrap(); + let state = if map.triggers.iter().any(|trigger| trigger(labels)) { + TimeoutState::Expired + } else { + TimeoutState::Waiting(None) + }; + let slot = map.next_slot; + map.next_slot += 1; + trace!( + "Registering {}timeout for task {} in slot {}", + match &state { + TimeoutState::Expired => "(expired) ", + TimeoutState::Waiting(_) => "", + }, + task_name, + slot + ); + let entry = TimeoutEntry { task_id, state }; + map.entries.insert(slot, entry); + slot + }) + }) + }; + Timeout { slot, future } +} + +/// Requires a `Future` to complete before the specified duration has elapsed. +pub fn timeout(_duration: Duration, future: T) -> Timeout +where + T: Future, +{ + timeout_inner(future) +} + +/// Requires a `Future` to complete before the specified instant in time. +pub fn timeout_at(_deadline: Instant, future: T) -> Timeout +where + T: Future, +{ + timeout_inner(future) +} + +// Lifted from Tokio +pub(crate) fn far_future() -> Instant { + // Roughly 30 years from now. + // API does not provide a way to obtain max `Instant` + // or convert specific date in the future to instant. + // 1000 years overflows on macOS, 100 years overflows on FreeBSD. 
+ Instant::now() + Duration::from_secs(86400 * 365 * 30) +} + +#[pin_project] +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Sleep(#[pin] Instant); + +/// Waits until duration has elapsed +pub fn sleep(duration: Duration) -> Sleep { + match Instant::now().checked_add(duration) { + Some(deadline) => Sleep(deadline), + None => Sleep(far_future()), + } +} + +/// Waits until the given deadline +pub fn sleep_until(deadline: Instant) -> Sleep { + Sleep(deadline) +} + +impl Future for Sleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + // If it is going to resolve in more than a year, return `Poll::Pending` + if self.0 >= Instant::now() + Duration::from_secs(86400 * 365) { + Poll::Pending + } else { + shuttle::thread::yield_now(); + Poll::Ready(()) + } + } +} + +impl Sleep { + /// Returns the instant at which the future will complete. + pub fn deadline(&self) -> Instant { + self.0 + } + + /// Returns `true` if `Sleep` has elapsed. + /// + /// A `Sleep` instance is elapsed when the requested duration has elapsed. + pub fn is_elapsed(&self) -> bool { + Instant::now() >= self.0 + } + + /// Resets the `Sleep` instance to a new deadline. + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + let mut me = self.project(); + *me.0 = deadline; + } +} + +/// Advances time +pub async fn advance(_duration: Duration) { + shuttle::future::yield_now().await; +} + +/// Resumes time. +pub fn resume() {} + +pub mod error { + /// Errors returned by `Timeout`. + /// + /// This error is returned when a timeout expires before the function was able + /// to finish. + #[derive(Debug, Default, PartialEq, Eq)] + pub struct Elapsed(()); + + // ===== impl Elapsed ===== + + impl Elapsed { + // Note that this is not `pub` in Tokio. We expose it to enable modelling of timeouts. 
+ pub fn new() -> Self {
+ Elapsed(())
+ }
+ }
+
+ impl std::fmt::Display for Elapsed {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ "deadline has elapsed".fmt(fmt)
+ }
+ }
+
+ impl std::error::Error for Elapsed {}
+
+ impl From for std::io::Error {
+ fn from(_err: Elapsed) -> std::io::Error {
+ std::io::ErrorKind::TimedOut.into()
+ }
+ }
+}
+
+pub use tokio::time::MissedTickBehavior;
+
+/// Interval returned by [`interval`] and [`interval_at`].
+#[derive(Debug)]
+pub struct Interval {
+ /// The strategy `Interval` should use when a tick is missed.
+ missed_tick_behavior: MissedTickBehavior,
+ count: usize,
+ period: Duration,
+}
+
+/// Creates new [`Interval`] that yields with interval of `period`. The first
+/// tick completes immediately.
+pub fn interval(period: Duration) -> Interval {
+ interval_at(Instant::now(), period)
+}
+
+/// Creates new [`Interval`] that yields with interval of `period` with the
+/// first tick completing at `start`. The default [`MissedTickBehavior`] is
+/// [`Burst`](MissedTickBehavior::Burst), but this can be configured
+/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior).
+///
+/// An interval will tick indefinitely. At any time, the [`Interval`] value can
+/// be dropped. This cancels the interval.
+pub fn interval_at(_start: Instant, period: Duration) -> Interval {
+ // By default, Shuttle will allow interval ticks to be generated forever (or at
+ // least `usize::MAX` times per interval). To override this behavior, set the
+ // environment variable `SHUTTLE_INTERVAL_TICKS` to the number of times each interval
+ // should tick. (Setting this to 0 means intervals never generate ticks.) 
+ let count = std::env::var("SHUTTLE_INTERVAL_TICKS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(usize::MAX); + Interval { + missed_tick_behavior: Default::default(), + count, + period, + } +} + +static HAS_WARNED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); + +impl Interval { + /// Completes when the next instant in the interval has been reached. + pub async fn tick(&mut self) -> Instant { + if self.count > 0 { + shuttle::future::yield_now().await; + self.count -= 1; + Instant::now() + } else { + futures::future::pending().await + } + } + + /// Polls for the next instant in the interval to be reached. + pub fn poll_tick(&mut self, _cx: &mut Context<'_>) -> Poll { + if self.count > 0 { + shuttle::thread::yield_now(); + self.count -= 1; + Poll::Ready(Instant::now()) + } else { + Poll::Pending + } + } + + /// Returns the [`MissedTickBehavior`] strategy currently being used. + pub fn missed_tick_behavior(&self) -> MissedTickBehavior { + self.missed_tick_behavior + } + + /// Sets the [`MissedTickBehavior`] strategy that should be used. + pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) { + self.missed_tick_behavior = behavior; + } + + /// Resets the interval to complete one period after the current time. + /// TODO make this work right + pub fn reset(&mut self) {} + + /// Disables the interval + pub fn disable(&mut self) { + self.count = 0; + } + + /// Returns the period of the interval. 
+ pub fn period(&self) -> Duration { + use std::sync::atomic::Ordering; + if !HAS_WARNED.load(Ordering::SeqCst) { + if std::env::var("SHUTTLE_SILENCE_WARNINGS").is_err() { + tracing::warn!("`period` suggests code dependent on real time, which means that it will behave\n\ + nondeterministically under Shuttle, meaning that failing schedules would not be replayable.\n\ + The suggested solution is to have some different handling under Shuttle, by using `cfg(feature = \"shuttle\")`.\n + If you do not wish to see this warning again, then it can be turned off by setting the SHUTTLE_SILENCE_WARNINGS environment variable."); + } + HAS_WARNED.store(true, Ordering::SeqCst); + } + self.period + } +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/examples.rs b/wrappers/tokio/impls/tokio/inner/tests/examples.rs new file mode 100644 index 00000000..28934d24 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/examples.rs @@ -0,0 +1,55 @@ +use futures::FutureExt; +use shuttle::check_dfs; +use shuttle::future; +use shuttle_tokio_impl_inner::sync::Mutex; +use test_log::test; + +/// This function illustrates the danger of partially executing a set of futures +/// (e.g., in a `select!` statement), and then awaiting one of the futures without +/// dropping the other futures first. The other futures could still hold resources +/// that the awaited future needs, resulting in a deadlock. +async fn footgun(drop_a: bool) { + let lock = Mutex::new(()); + + let mut a = async { + let _guard = lock.lock().await; + future::yield_now().await; + } + .boxed(); + + let mut b = async { + let _guard = lock.lock().await; + } + .boxed(); + + let mut c = async {}.boxed(); + + // task `a` acquires the mutex and yields + // task `b` becomes a waiter for the mutex + // task `c` completes the select + tokio::select! 
{ + biased; + () = &mut a => {} + () = &mut b => {} + () = &mut c => {} + } + + if drop_a { + // dropping `a` releases the mutex so that `b` can acquire it + drop(a); + } + + // if `a` was not dropped, then `b` cannot acquire the mutex and we deadlock + b.await; +} + +#[test] +#[should_panic(expected = "deadlock")] +fn footgun_deadlock() { + check_dfs(|| future::block_on(footgun(false)), None); +} + +#[test] +fn footgun_averted() { + check_dfs(|| future::block_on(footgun(true)), None); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/macros.rs b/wrappers/tokio/impls/tokio/inner/tests/macros.rs new file mode 100644 index 00000000..dece2cd9 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/macros.rs @@ -0,0 +1,56 @@ +// We import as `some_other_name` to ensure that renaming to something arbitrary works. +use shuttle_tokio_impl_inner as some_other_name; +// We import as `tokio` to ensure that renaming to tokio works (there could in theory have been a collision). +use shuttle_tokio_impl_inner as tokio; +use shuttle_tokio_impl_inner::sync::Mutex; +use std::sync::Arc; + +#[shuttle_tokio_impl_inner::test] +async fn litmus() {} + +#[some_other_name::test] +async fn litmus_renamed() {} + +// Test that it works with parameters +#[some_other_name::test(flavor = "multi_thread", worker_threads = 1)] +async fn litmus_params() {} + +// Simple happy case test +#[some_other_name::test] +async fn should_succeed() { + let lock = Arc::new(Mutex::new(0usize)); + let lock_clone = Arc::clone(&lock); + + let jh = tokio::spawn(async move { + let mut counter = lock_clone.lock().await; + *counter += 1; + }); + + { + let mut counter = lock.lock().await; + *counter += 1; + } + jh.await.unwrap(); + assert!(*lock.lock().await == 2); +} + +// Simple failure case test +#[should_panic(expected = "Failed")] +#[tokio::test] +async fn should_fail() { + panic!("Failed"); +} + +// Test that returning a result works. 
+#[tokio::test] +async fn return_value_ok() -> Result<(), &'static str> { + Ok(()) +} + +#[should_panic(expected = "failure :(")] +#[tokio::test] +async fn return_value_err() -> Result<(), &'static str> { + Err("failure :(") +} + +// TODO: Make the `test` macro compatible with `proptest!` and make a test here which tests that. diff --git a/wrappers/tokio/impls/tokio/inner/tests/mpsc.rs b/wrappers/tokio/impls/tokio/inner/tests/mpsc.rs new file mode 100644 index 00000000..3b31a363 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/mpsc.rs @@ -0,0 +1,794 @@ +use crate::mpsc::error::TryRecvError; +use assert_matches::assert_matches; +use shuttle::future; +use shuttle::sync::Arc; +use shuttle::{check_dfs, check_random}; +use shuttle_tokio_impl_inner::sync::{mpsc, oneshot}; +use std::sync::atomic::Ordering; +use test_log::test; +use tracing::trace; + +#[test] +fn async_mpsc_send_recv_unbounded() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + + future::spawn(async move { + assert!(tx.send(1).is_ok()); + assert!(tx.send(2).is_ok()); + }); + + assert_eq!(Some(1), rx.recv().await); + assert_eq!(Some(2), rx.recv().await); + assert_eq!(None, rx.recv().await); + }); + }, + None, + ); +} + +#[test] +fn mpsc_unbounded_len() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + + assert_eq!(0, rx.len()); + + assert!(tx.send(1).is_ok()); + assert_eq!(1, rx.len()); + assert!(tx.send(2).is_ok()); + assert_eq!(2, rx.len()); + + assert_eq!(Some(1), rx.recv().await); + assert_eq!(1, rx.len()); + assert_eq!(Some(2), rx.recv().await); + assert_eq!(0, rx.len()); + drop(tx); + assert_eq!(None, rx.recv().await); + assert_eq!(0, rx.len()); + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_commutative_senders() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + let tx2 = tx.clone(); + + future::spawn(async move { + tx.send(5).unwrap(); + 
}); + future::spawn(async move { + tx2.send(6).unwrap(); + }); + let mut val = rx.recv().await.unwrap(); + val += rx.recv().await.unwrap(); + assert_eq!(val, 11); + }); + }, + None, + ); +} + +fn ignore_result
(_: A) {} + +#[test] +#[should_panic(expected = "expected panic: sends can happen in any order")] +fn async_mpsc_loom_non_commutative_senders1() { + check_dfs( + || { + future::block_on(async { + let (s, mut r) = mpsc::unbounded_channel(); + let s2 = s.clone(); + future::spawn(async move { + ignore_result(s.send(5)); + }); + future::spawn(async move { + ignore_result(s2.send(6)); + }); + let val = r.recv().await; + assert_eq!(val, Some(5), "expected panic: sends can happen in any order"); + ignore_result(r.recv().await); + }); + }, + None, + ); +} + +#[test] +#[should_panic(expected = "expected panic: sends can happen in any order")] +fn async_mpsc_loom_non_commutative_senders2() { + check_dfs( + || { + future::block_on(async { + let (s, mut r) = mpsc::unbounded_channel(); + let s2 = s.clone(); + future::spawn(async move { + ignore_result(s.send(5)); + }); + future::spawn(async move { + ignore_result(s2.send(6)); + }); + let val = r.recv().await; + assert_eq!(val, Some(6), "expected panic: sends can happen in any order"); + ignore_result(r.recv().await); + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_drop_sender_unbounded() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel::(); + future::spawn(async move { + drop(tx); + }); + assert!(rx.recv().await.is_none()); + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_drop_receiver_unbounded() { + check_dfs( + || { + let (tx, rx) = mpsc::unbounded_channel(); + drop(rx); + assert!(tx.send(1).is_err()); + }, + None, + ); +} + +#[test] +fn async_mpsc_buffering_behavior() { + check_dfs( + || { + future::block_on(async move { + let (send, mut recv) = mpsc::unbounded_channel(); + let handle = future::spawn(async move { + send.send(1u8).unwrap(); + send.send(2).unwrap(); + send.send(3).unwrap(); + drop(send); + }); + + // wait for the thread to join so we ensure the sender is dropped + handle.await.unwrap(); + + // values sent before the sender disconnects are still available 
afterwards + assert_eq!(Some(1), recv.recv().await); + assert_eq!(Some(2), recv.recv().await); + assert_eq!(Some(3), recv.recv().await); + // but after the values are exhausted, recv() returns None + assert!(recv.recv().await.is_none()); + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_bounded_sum() { + check_dfs( + || { + future::block_on(async move { + let (tx, mut rx) = mpsc::channel::(5); + future::spawn(async move { + for _ in 0..3 { + tx.send(1).await.unwrap(); + } + }); + let handle = future::spawn(async move { + let mut sum = 0; + for _ in 0..3 { + trace!("... waiting for value"); + sum += rx.recv().await.unwrap(); + } + sum + }); + let r = handle.await.unwrap(); + assert_eq!(r, 3); + }); + }, + None, + ); +} + +#[test] +fn mpsc_bounded_len() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::channel(5); + + assert_eq!(0, rx.len()); + + tx.send(1).await.unwrap(); + assert_eq!(1, rx.len()); + tx.send(2).await.unwrap(); + assert_eq!(2, rx.len()); + + assert_eq!(Some(1), rx.recv().await); + assert_eq!(1, rx.len()); + assert_eq!(Some(2), rx.recv().await); + assert_eq!(0, rx.len()); + drop(tx); + assert_eq!(None, rx.recv().await); + assert_eq!(0, rx.len()); + }); + }, + None, + ); +} + +// Sending on a bounded channel doesn't block the sender if the channel isn't filled +#[test] +fn async_mpsc_bounded_sender_buffered() { + check_dfs( + || { + future::block_on(async { + let (tx, _rx) = mpsc::channel::(5); + let handle = future::spawn(async move { + for _ in 0..5 { + tx.send(1).await.unwrap(); + } + 42 + }); + let r = handle.await.unwrap(); + assert_eq!(r, 42); + }); + }, + None, + ); +} + +// Sending on a bounded channel blocks the sender when the channel becomes full +#[test] +#[should_panic(expected = "deadlock")] +fn async_mpsc_bounded_sender_blocked() { + check_dfs( + || { + future::block_on(async { + let (tx, _rx) = mpsc::channel::(10); + let handle = future::spawn(async move { + for _ in 0..11 { + tx.send(1).await.unwrap(); + } 
+ 42 + }); + let r = handle.await.unwrap(); + assert_eq!(r, 42); + }); + }, + None, + ); +} + +async fn mpsc_senders_with_blocking_inner(num_senders: usize, channel_size: usize) { + assert!(num_senders >= channel_size); + let num_receives = num_senders - channel_size; + let (tx, mut rx) = mpsc::channel::(channel_size); + let senders = (0..num_senders) + .map(move |i| { + let tx = tx.clone(); + future::spawn(async move { + tx.send(i).await.unwrap(); + }) + }) + .collect::>(); + + // Receive enough messages to ensure no sender will block + for _ in 0..num_receives { + let _ = rx.recv().await.unwrap(); + } + for sender in senders { + sender.await.unwrap(); + } +} + +#[test] +fn async_mpsc_some_senders_with_blocking() { + check_dfs( + || { + future::block_on(async { + mpsc_senders_with_blocking_inner(3, 1).await; + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_many_senders_with_blocking() { + shuttle::check_random( + || { + future::block_on(async { + mpsc_senders_with_blocking_inner(1000, 500).await; + }); + }, + 1000, + ); +} + +#[test] +fn async_mpsc_many_senders_drop_receiver() { + const NUM_SENDERS: usize = 3; + const CHANNEL_SIZE: usize = 1; + check_dfs( + || { + future::block_on(async { + let (tx, rx) = mpsc::channel::(CHANNEL_SIZE); + let senders = (0..NUM_SENDERS) + .map(move |i| { + let tx = tx.clone(); + future::spawn(async move { + let _ = tx.send(i).await; + }) + }) + .collect::>(); + + // Drop the receiver; this will unblock any waiting senders + drop(rx); + + // Make sure all senders finish + for sender in senders { + sender.await.unwrap(); + } + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_multiple_messages() { + const NUM_MESSAGES: usize = 4; + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + let (ctx, crx) = oneshot::channel(); + + let h1 = future::spawn(async move { + for i in 0..NUM_MESSAGES { + tx.send(i).unwrap(); + } + crx.await.unwrap(); + }); + + let h2 = future::spawn(async move { + for i 
in 0..NUM_MESSAGES { + let n = rx.recv().await.unwrap(); + assert_eq!(i, n); + } + ctx.send(()).unwrap(); + }); + + futures::future::join_all([h1, h2]).await; + }); + }, + None, + ); +} + +#[test] +fn async_mpsc_select() { + check_dfs( + || { + let lock = std::sync::Arc::new(std::sync::Mutex::new(0)); + let lock2 = lock.clone(); + future::block_on(async move { + let (tx1, mut rx1) = mpsc::unbounded_channel(); + let (tx2, mut rx2) = mpsc::unbounded_channel(); + + let h1 = future::spawn(async move { + loop { + tokio::select! { + biased; + + msg = rx1.recv() => { + if let Some(v) = msg { + *lock2.lock().unwrap() += v; + } else { + break; + } + } + + _stop = rx2.recv() => { + // It is possible to check the branch above, then have a message be sent on `tx1` and `tx2`, then + // check this branch. + // We have to consume these messages before breaking. + while let Some(v) = rx1.recv().await { + *lock2.lock().unwrap() += v; + } + + break; + } + } + } + }); + + tx1.send(10).unwrap(); + tx1.send(32).unwrap(); + tx2.send(()).unwrap(); + drop(tx1); + h1.await.unwrap(); + }); + let value = *lock.lock().unwrap(); + assert_eq!(value, 42, "{}", value); + }, + None, + ); +} + +#[test] +fn mpsc_drain_after_close() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + tx.send(()).unwrap(); + rx.close(); + assert!(!rx.is_empty()); + rx.recv() + .await + .expect("must be able to receive already sent message after closing receiver"); + assert!(rx.is_empty()); + assert!(rx.recv().await.is_none()); + }); + }, + None, + ); +} + +#[test] +fn mpsc_send_after_close() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = mpsc::unbounded_channel(); + rx.close(); + tx.send(()) + .expect_err("shouldn't be able to send after closing receiver"); + assert!(rx.is_empty()); + assert!(rx.recv().await.is_none()); + }); + }, + None, + ); +} + +/// This test captures a pattern that exposed a bug in mpsc. 
+/// +/// The bug has been further reduced to [`mpsc_drain_after_close`] and [`mpsc_send_after_close`], +/// but it doesn't hurt to keep this test. +#[test] +fn mpsc_close_bug() { + check_random( + || { + future::block_on(async { + async fn send(v: u64, tx: std::sync::Arc)>>) -> bool { + let (res_tx, res_rx) = oneshot::channel(); + if tx.send((v, res_tx)).is_ok() { + trace!("send {} ok", v); + // bug caused potential deadlock here + res_rx.await.unwrap(); + trace!("send {} response", v); + true + } else { + trace!("send {} failed", v); + false + } + } + + let (tx, mut rx) = mpsc::unbounded_channel(); + let tx = std::sync::Arc::new(tx); + + let sender1 = future::spawn(send(1, tx.clone())); + let sender2 = future::spawn(send(2, tx)); + + let receiver = future::spawn(async move { + let (v, res_tx) = rx.recv().await.unwrap(); + trace!("first recv {}", v); + res_tx.send(()).unwrap(); + rx.close(); + if let Some((v, res_tx)) = rx.recv().await { + trace!("second recv {}", v); + assert!(rx.is_empty()); + res_tx.send(()).unwrap(); + true + } else { + trace!("second recv failed"); + false + } + }); + + let send_1_ok = sender1.await.unwrap(); + let send_2_ok = sender2.await.unwrap(); + let received_both = receiver.await.unwrap(); + + assert!(send_1_ok || send_2_ok); + assert!(!received_both || (send_1_ok && send_2_ok)); + }); + }, + 10_000, + ); +} + +/// This test captures a pattern that exposed a bug in mpsc. +/// +/// Similar to [`mpsc_close_bug`], but dropping instead of closing receiver. +/// +/// The bug has been further reduced to [`drop_unreceived_messages_when_receiver_drops`]. 
+#[test] +fn mpsc_drop_receiver_bug() { + // values collected across executions, to ensure `check_random` explored interesting executions + let send_results = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let send_results_ = send_results.clone(); + + check_random( + move || { + future::block_on(async { + /// Possible return values: + /// * `Ok(Ok(()))`: successfully sent message and awaited response + /// * `Ok(Err(()))`: successfully sent messages but failed awaiting response + /// * `Err(())`: failed to send message + async fn send( + v: u64, + tx: std::sync::Arc)>>, + ) -> Result, ()> { + let (res_tx, res_rx) = oneshot::channel(); + tx.send((v, res_tx)).map_err(|_| ())?; + // bug caused potential deadlock here, due to `res_tx` not getting dropped + // when `rx` is dropped below + Ok(res_rx.await.map_err(|_| ())) + } + + let (tx, mut rx) = mpsc::unbounded_channel(); + let tx = std::sync::Arc::new(tx); + + let sender1 = future::spawn(send(1, tx.clone())); + let sender2 = future::spawn(send(2, tx)); + + let receiver = future::spawn(async move { + let (v, res_tx) = rx.recv().await.unwrap(); + res_tx.send(()).unwrap(); + drop(rx); + v + }); + + let send_1_res = sender1.await.unwrap(); + let send_2_res = sender2.await.unwrap(); + let received_value = receiver.await.unwrap(); + + if received_value == 1 { + assert_eq!(send_1_res, Ok(Ok(()))); + assert_matches!(send_2_res, Err(()) | Ok(Err(()))); + } else { + assert_eq!(send_2_res, Ok(Ok(()))); + assert_matches!(send_1_res, Err(()) | Ok(Err(()))); + } + + let mut send_results = send_results_.lock().unwrap(); + send_results.insert(send_1_res); + send_results.insert(send_2_res); + }); + }, + 10_000, + ); + + // `Ok(Err(()))` is most important, but let's just check all possible values + let send_results = Arc::into_inner(send_results).unwrap().into_inner().unwrap(); + assert_eq!(send_results, [Ok(Ok(())), Ok(Err(())), Err(())].into()); +} + +#[test] +fn drop_unreceived_messages_when_receiver_drops() { 
+ check_dfs( + || { + future::block_on(async { + let (tx, rx) = mpsc::unbounded_channel(); + + let send_task = future::spawn(async move { + let (otx, orx) = oneshot::channel::<()>(); + if tx.send(otx).is_ok() { + // used to deadlock due to `otx` not getting dropped when `rx` is dropped below + assert!(orx.await.is_err()); + } + }); + + drop(rx); + + send_task.await.unwrap(); + }); + }, + None, + ); +} + +// Sanity test for `is_closed()`. Tests that `close` closes the channel. +#[test] +fn mpsc_close_closes_channel() { + check_dfs( + || { + future::block_on(async { + let (_tx, mut rx) = mpsc::unbounded_channel::<()>(); + rx.close(); + assert!(rx.is_closed()); + + let (_tx, mut rx) = mpsc::channel::<()>(1); + rx.close(); + assert!(rx.is_closed()); + }); + }, + None, + ); +} + +// Sanity test for `is_closed()`, testing that dropping all senders closes the channel. +#[test] +fn mpsc_dropping_senders_closed_channel() { + check_dfs( + || { + future::block_on(async { + let (unbounded_rx, bounded_rx) = { + let (_tx, rx) = mpsc::unbounded_channel::<()>(); + let (_tx, rx2) = mpsc::channel::<()>(1); + (rx, rx2) + }; + // All senders should be dropped here, meaning the channels should be closed. 
+ assert!(unbounded_rx.is_closed()); + assert!(bounded_rx.is_closed()); + }); + }, + None, + ); +} + +// Tests that `recv`, `blocking_recv` and `try_recv` work correctly on channel drop +#[test] +fn mpsc_recv_and_friends_correct_on_sender_drop_unbounded() { + // There should be schedules which hit `TryRecvError::Disconnected` and `TryRecvError::Empty` + let has_seen_empties = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let has_seen_disconnects = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let empties = has_seen_empties.clone(); + let disconnects = has_seen_disconnects.clone(); + + check_dfs( + move || { + let empties = empties.clone(); + let disconnects = disconnects.clone(); + + // The whole idea of these is to have one thread `recv`/`blocking_recv`/`try_recv` and then drop the sender. + + future::block_on(async { + let jh = { + let (_tx, mut rx) = mpsc::unbounded_channel::<()>(); + future::spawn(async move { + let msg = rx.recv().await; + assert!(msg.is_none()); + assert!(rx.is_closed()); + }) + }; + jh.await.unwrap(); + + // The same as above, but with blocking_recv + let jh = { + let (_tx, mut rx) = mpsc::unbounded_channel::<()>(); + shuttle::thread::spawn(move || { + let msg = rx.blocking_recv(); + assert!(msg.is_none()); + assert!(rx.is_closed()); + }) + }; + jh.join().unwrap(); + + // Similar, but with `try_recv`. 
+ let jh = { + let (_tx, mut rx) = mpsc::unbounded_channel::<()>(); + future::spawn(async move { + let is_closed = rx.is_closed(); + + match rx.try_recv().unwrap_err() { + TryRecvError::Disconnected => { + disconnects.clone().store(true, Ordering::SeqCst); + assert!(rx.is_closed()); + } + TryRecvError::Empty => { + empties.clone().store(true, Ordering::SeqCst); + assert!(!is_closed); + } + } + }) + }; + jh.await.unwrap(); + }); + }, + None, + ); + + assert!(has_seen_disconnects.load(Ordering::SeqCst)); + assert!(has_seen_empties.load(Ordering::SeqCst)); +} + +// Same test as above, but on bounded channels +#[test] +fn mpsc_recv_and_friends_correct_on_sender_drop_bounded() { + // There should be schedules which hit `TryRecvError::Disconnected` and `TryRecvError::Empty` + let has_seen_empties = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let has_seen_disconnects = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let empties = has_seen_empties.clone(); + let disconnects = has_seen_disconnects.clone(); + + check_dfs( + move || { + let empties = empties.clone(); + let disconnects = disconnects.clone(); + + // The whole idea of these is to have one thread `recv`/`blocking_recv`/`try_recv` and then drop the sender. + + future::block_on(async { + let jh = { + let (_tx, mut rx) = mpsc::channel::<()>(1); + future::spawn(async move { + let msg = rx.recv().await; + assert!(msg.is_none()); + assert!(rx.is_closed()); + }) + }; + jh.await.unwrap(); + + // The same as above, but with blocking_recv + let jh = { + let (_tx, mut rx) = mpsc::channel::<()>(1); + shuttle::thread::spawn(move || { + let msg = rx.blocking_recv(); + assert!(msg.is_none()); + assert!(rx.is_closed()); + }) + }; + jh.join().unwrap(); + + // Similar, but with `try_recv`. 
+ let jh = { + let (_tx, mut rx) = mpsc::channel::<()>(1); + future::spawn(async move { + let is_closed = rx.is_closed(); + + match rx.try_recv().unwrap_err() { + TryRecvError::Disconnected => { + disconnects.clone().store(true, Ordering::SeqCst); + assert!(rx.is_closed()); + } + TryRecvError::Empty => { + empties.clone().store(true, Ordering::SeqCst); + assert!(!is_closed); + } + } + }) + }; + jh.await.unwrap(); + }); + }, + None, + ); + + assert!(has_seen_disconnects.load(Ordering::SeqCst)); + assert!(has_seen_empties.load(Ordering::SeqCst)); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/mutex.rs b/wrappers/tokio/impls/tokio/inner/tests/mutex.rs new file mode 100644 index 00000000..6a8ae785 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/mutex.rs @@ -0,0 +1,144 @@ +use shuttle::future; +use shuttle::scheduler::PctScheduler; +use shuttle::{check_dfs, Runner}; +use shuttle_tokio_impl_inner::sync::Mutex; +use std::collections::BTreeMap; +use std::sync::Arc; +use test_log::test; + +#[test] +fn async_mutex_basic_test() { + check_dfs( + || { + future::block_on(async { + let lock = Arc::new(Mutex::new(0usize)); + let lock_clone = Arc::clone(&lock); + + future::spawn(async move { + let mut counter = lock_clone.lock().await; + *counter += 1; + }); + + let mut counter = lock.lock().await; + *counter += 1; + }); + }, + None, + ); + + // TODO would be cool if we were allowed to smuggle the lock out of the run, + // TODO so we can assert invariants about it *after* execution ends +} + +async fn deadlock() { + let lock1 = Arc::new(Mutex::new(0usize)); + let lock2 = Arc::new(Mutex::new(0usize)); + let lock1_clone = Arc::clone(&lock1); + let lock2_clone = Arc::clone(&lock2); + + future::spawn(async move { + let _l1 = lock1_clone.lock().await; + let _l2 = lock2_clone.lock().await; + }); + + let _l2 = lock2.lock().await; + let _l1 = lock1.lock().await; +} + +#[test] +#[should_panic(expected = "deadlock")] +fn async_mutex_deadlock() { + check_dfs(|| 
future::block_on(async { deadlock().await }), None); +} + +#[test] +#[should_panic(expected = "racing increments")] +fn async_mutex_concurrent_increment_buggy() { + let scheduler = PctScheduler::new(2, 100); + let runner = Runner::new(scheduler, Default::default()); + runner.run(|| { + future::block_on(async { + let lock = Arc::new(Mutex::new(0usize)); + + let tasks = (0..2) + .map(|_| { + let lock = Arc::clone(&lock); + future::spawn(async move { + let curr = *lock.lock().await; + *lock.lock().await = curr + 1; + }) + }) + .collect::>(); + + for t in tasks { + t.await.unwrap(); + } + + let counter = *lock.lock().await; + assert_eq!(counter, 2, "racing increments"); + }); + }); +} + +// Create a BTreeMap of usize -> Node where each Node contains some value (usize), initially 0 +// Spin up tasks that grab locks on individual nodes from the tree in some order +// Once all nodes are locked, the task increments the value of each node and then releases them +async fn owned_mutex_increment(num_locks: usize, mut locks: Vec>) -> Vec { + #[derive(Debug)] + struct Node(usize); + + let mut map = BTreeMap::new(); + for i in 0usize..num_locks { + map.insert(i, Arc::new(Mutex::new(Node(0)))); + } + let map = Arc::new(map); + let mut handles = Vec::new(); + for lock in locks.drain(..) 
{ + let map = Arc::clone(&map); + handles.push(future::spawn(async move { + let mut nodes = Vec::new(); + for m in lock { + let node = map.get(&m).unwrap().clone(); + nodes.push(node.lock_owned().await); + } + for mut n in nodes { + n.0 += 1; + } + })); + } + for h in handles { + h.await.unwrap(); + } + let mut values = Vec::new(); + for (_, n) in map.iter() { + let v = n.lock().await.0; + values.push(v); + } + values +} + +#[test] +fn async_mutex_owned_1() { + check_dfs( + || { + future::block_on(async move { + let v = owned_mutex_increment(5usize, vec![vec![0, 2, 4], vec![1, 3]]).await; + assert_eq!(v, [1, 1, 1, 1, 1]); + }); + }, + None, + ); +} + +#[test] +fn async_mutex_owned_2() { + check_dfs( + || { + future::block_on(async move { + let v = owned_mutex_increment(5usize, vec![vec![0, 2, 4], vec![1, 2, 3]]).await; + assert_eq!(v, [1, 1, 2, 1, 1]); + }); + }, + None, + ); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/notify.rs b/wrappers/tokio/impls/tokio/inner/tests/notify.rs new file mode 100644 index 00000000..b79c836f --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/notify.rs @@ -0,0 +1,302 @@ +use shuttle::future; +use shuttle_tokio_impl_inner::sync::notify::Notify; +use std::collections::VecDeque; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, +}; +use test_log::test; + +fn check_dfs(f: F) +where + F: Fn() + Send + Sync + 'static, +{ + use shuttle::scheduler::DfsScheduler; + let mut config = shuttle::Config::default(); + config.max_steps = shuttle::MaxSteps::FailAfter(1_000); + + let scheduler = DfsScheduler::new(None, true); + let runner = shuttle::Runner::new(scheduler, config); + runner.run(f); +} + +#[test] +fn notify_mpsc_channel() { + // Example from the tokio docs for Notify + + // Unbound multi-producer single-consumer (mpsc) channel. 
+ // + // No wakeups can be lost when using this channel because the call to + // `notify_one()` will store a permit in the `Notify`, which the following call + // to `notified()` will consume. + struct Channel { + values: Mutex>, + notify: Notify, + } + + impl Channel { + fn new() -> Self { + Self { + values: Mutex::new(VecDeque::new()), + notify: Notify::new(), + } + } + pub fn send(&self, value: T) { + let mut values = self.values.lock().unwrap(); + values.push_back(value); + drop(values); + + // Notify the consumer a value is available + self.notify.notify_one(); + } + + // This is a single-consumer channel, so several concurrent calls to + // `recv` are not allowed. + pub async fn recv(&self) -> T { + loop { + // Drain values + { + let mut values = self.values.lock().unwrap(); + if let Some(value) = values.pop_front() { + return value; + } + drop(values); + } + + // Wait for values to be available + self.notify.notified().await; + } + } + } + + check_dfs(|| { + future::block_on(async { + let tx1 = Arc::new(Channel::new()); + let tx2 = tx1.clone(); + let rx = tx1.clone(); + future::spawn(async move { + tx1.send(1); + }); + future::spawn(async move { + tx2.send(2); + }); + let mut v = rx.recv().await; + v += rx.recv().await; + assert_eq!(v, 3); + }); + }); +} + +#[test] +fn notify_mpmc_channel() { + // Example from the tokio docs for Notify + // A multi-producer, multi-consumer channel + + struct Channel { + messages: Mutex>, + notify_on_sent: Notify, + } + + impl Channel { + fn new() -> Self { + Self { + messages: Mutex::new(VecDeque::new()), + notify_on_sent: Notify::new(), + } + } + + pub fn send(&self, msg: T) { + let mut locked_queue = self.messages.lock().unwrap(); + locked_queue.push_back(msg); + drop(locked_queue); + + // Send a notification to one of the calls currently + // waiting in a call to `recv`. 
+ self.notify_on_sent.notify_one(); + } + + pub fn try_recv(&self) -> Option { + let mut locked_queue = self.messages.lock().unwrap(); + locked_queue.pop_front() + } + + pub async fn recv(&self) -> T { + let mut future = self.notify_on_sent.notified(); + let mut future = unsafe { std::pin::Pin::new_unchecked(&mut future) }; + + loop { + // Make sure that no wakeup is lost if we get + // `None` from `try_recv`. + future.as_mut().enable(); + + if let Some(msg) = self.try_recv() { + return msg; + } + + // Wait for a call to `notify_one`. + // + // This uses `.as_mut()` to avoid consuming the future, + // which lets us call `Pin::set` below. + future.as_mut().await; + + // Reset the future in case another call to + // `try_recv` got the message before us. + future.set(self.notify_on_sent.notified()); + } + } + } + + check_dfs(|| { + future::block_on(async { + let counter = Arc::new(AtomicUsize::new(0)); + let counter1 = counter.clone(); + let counter2 = counter.clone(); + + let tx1 = Arc::new(Channel::new()); + let tx2 = tx1.clone(); + let rx1 = tx1.clone(); + let rx2 = tx1.clone(); + future::spawn(async move { + tx1.send(1); + }); + future::spawn(async move { + tx2.send(2); + }); + let r1 = future::spawn(async move { + let x = rx1.recv().await; + counter1.fetch_add(x, Ordering::SeqCst); + }); + let r2 = future::spawn(async move { + let x = rx2.recv().await; + counter2.fetch_add(x, Ordering::SeqCst); + }); + r1.await.unwrap(); + r2.await.unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 3); + }); + }); +} + +fn notify_mpmc_channel_2_test(do_enable: bool) { + // Example from the tokio docs + + // Unbound multi-producer multi-consumer (mpmc) channel. + // + // The call to `enable` is important because otherwise if you have two + // calls to `recv` and two calls to `send` in parallel, the following could + // happen: + // + // 1. Both calls to `try_recv` return `None`. + // 2. Both new elements are added to the vector. + // 3. 
The `notify_one` method is called twice, adding only a single + // permit to the `Notify`. + // 4. Both calls to `recv` reach the `Notified` future. One of them + // consumes the permit, and the other sleeps forever. + // + // By adding the `Notified` futures to the list by calling `enable` before + // `try_recv`, the `notify_one` calls in step three would remove the + // futures from the list and mark them notified instead of adding a permit + // to the `Notify`. This ensures that both futures are woken. + struct Channel { + messages: Mutex>, + notify_on_sent: Notify, + do_enable: bool, + } + + impl Channel { + fn new(do_enable: bool) -> Self { + Self { + messages: Mutex::new(VecDeque::new()), + notify_on_sent: Notify::new(), + do_enable, + } + } + + pub fn send(&self, msg: T) { + let mut locked_queue = self.messages.lock().unwrap(); + locked_queue.push_back(msg); + drop(locked_queue); + + // Send a notification to one of the calls currently + // waiting in a call to `recv`. + self.notify_on_sent.notify_one(); + } + + pub fn try_recv(&self) -> Option { + let mut locked_queue = self.messages.lock().unwrap(); + locked_queue.pop_front() + } + + pub async fn recv(&self) -> T { + let mut future = self.notify_on_sent.notified(); + let mut future = unsafe { std::pin::Pin::new_unchecked(&mut future) }; + + loop { + if self.do_enable { + // Make sure that no wakeup is lost if we get `None` from `try_recv`. + future.as_mut().enable(); + } + + if let Some(msg) = self.try_recv() { + return msg; + } + + // Force context switch before the future is polled + shuttle::future::yield_now().await; + + // Wait for a call to `notify_one`. + // + // This uses `.as_mut()` to avoid consuming the future, + // which lets us call `Pin::set` below. + future.as_mut().await; + + // Reset the future in case another call to + // `try_recv` got the message before us. 
+ future.set(self.notify_on_sent.notified()); + } + } + } + + check_dfs(move || { + future::block_on(async { + let counter = Arc::new(AtomicUsize::new(0)); + let counter1 = counter.clone(); + let counter2 = counter.clone(); + + let tx1 = Arc::new(Channel::new(do_enable)); + let tx2 = tx1.clone(); + let rx1 = tx1.clone(); + let rx2 = tx1.clone(); + + let mut h = vec![]; + + h.push(future::spawn(async move { + let x = rx1.recv().await; + counter1.fetch_add(x, Ordering::SeqCst); + })); + h.push(future::spawn(async move { + let x = rx2.recv().await; + counter2.fetch_add(x, Ordering::SeqCst); + })); + h.push(future::spawn(async move { + tx1.send(1); + })); + h.push(future::spawn(async move { + tx2.send(2); + })); + futures::future::join_all(h).await; + assert_eq!(counter.load(Ordering::SeqCst), 3); + }); + }); +} + +#[test] +fn notify_mpmc_no_deadlock() { + notify_mpmc_channel_2_test(true); +} + +#[test] +#[should_panic(expected = "deadlock")] +fn notify_mpmc_deadlock() { + notify_mpmc_channel_2_test(false); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/oneshot.rs b/wrappers/tokio/impls/tokio/inner/tests/oneshot.rs new file mode 100644 index 00000000..ac7c185e --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/oneshot.rs @@ -0,0 +1,126 @@ +use shuttle::{check_dfs, check_random, future}; +use shuttle_tokio_impl_inner::sync::{mpsc, oneshot}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +#[macro_export] +macro_rules! 
assert_pending { + ($e:expr) => {{ + use core::task::Poll::*; + match $e { + Pending => {} + Ready(v) => panic!("ready; value = {:?}", v), + } + }}; + ($e:expr, $($msg:tt)+) => {{ + use core::task::Poll::*; + match $e { + Pending => {} + Ready(v) => { + panic!("ready; value = {:?}; {}", v, format_args!($($msg)+)) + } + } + }}; +} + +#[test] +fn async_oneshot_send_recv() { + check_dfs( + || { + future::block_on(async { + let (tx, mut rx) = oneshot::channel(); + + let h1 = future::spawn(async move { + let result = rx.try_recv(); + if let Err(e) = result { + assert_eq!(e, oneshot::error::TryRecvError::Empty); + let result = rx.await; + assert_eq!(result, Ok(1)); + } else { + assert_eq!(result, Ok(1)); + } + }); + + let h2 = future::spawn(async move { + assert!(!tx.is_closed()); + assert!(tx.send(1).is_ok()); + }); + + futures::future::join_all(vec![h1, h2]).await; + }); + }, + None, + ); +} + +#[test] +fn async_oneshot_commands() { + // Test a common use case for oneshot channels: as a way to acknowledge completion of an activity. + // A client asks the server to do some work (increment a shared counter) and reply when done. 
+ check_dfs( + || { + future::block_on(async { + let counter = Arc::new(AtomicUsize::new(0)); + let counter2 = counter.clone(); + let (tx, mut rx) = mpsc::channel(1); + + // Client + let h1 = future::spawn(async move { + let (ctx, crx) = oneshot::channel(); + tx.send(ctx).await.unwrap(); + crx.await.unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 3); + }); + + // Server + let h2 = future::spawn(async move { + let ctx = rx.recv().await.unwrap(); + for _ in 0..3 { + counter2.fetch_add(1, Ordering::SeqCst); + } + ctx.send(()).unwrap(); + }); + futures::future::join_all(vec![h1, h2]).await; + }); + }, + None, + ); +} + +#[test] +#[should_panic(expected = "assertion failed")] +fn async_oneshot_yield() { + check_dfs( + || { + let (mut tx1, mut rx1) = oneshot::channel::<()>(); + let (tx2, mut rx2) = oneshot::channel::<()>(); + future::block_on(async { + let h = future::spawn(async move { + rx1.close(); + rx2.close(); + }); + tx1.closed().await; // wait for first channel to be closed + assert!(tx2.is_closed()); // Some execution should fail this assertion + h.await.unwrap(); + }); + }, + None, + ); +} + +// If there is no synchronization point in `try_recv` (as was the case previously), then +// this test will hang forever. 
+#[test] +fn try_recv_loop() { + check_random( + || { + let (tx1, mut rx1) = oneshot::channel::<()>(); + future::block_on(async { + let try_recv_task = future::spawn(async move { while rx1.try_recv().is_err() {} }); + tx1.send(()).unwrap(); + try_recv_task.await.unwrap(); + }); + }, + 100, + ); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/runtime.rs b/wrappers/tokio/impls/tokio/inner/tests/runtime.rs new file mode 100644 index 00000000..ee31d137 --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/runtime.rs @@ -0,0 +1,45 @@ +use shuttle_tokio_impl_inner::runtime::{self, Handle}; +use shuttle_tokio_impl_inner::sync::mpsc; +use test_log::test; + +async fn mpsc_senders_with_blocking_inner(num_senders: usize, channel_size: usize) { + let rt = Handle::current(); + + assert!(num_senders >= channel_size); + let num_receives = num_senders - channel_size; + let (tx, mut rx) = mpsc::channel::(channel_size); + let senders = (0..num_senders) + .map(move |i| { + let tx = tx.clone(); + rt.spawn(async move { + tx.send(i).await.unwrap(); + }) + }) + .collect::>(); + + // Receive enough messages to ensure no sender will block + for _ in 0..num_receives { + let _ = rx.recv().await.unwrap(); + } + for sender in senders { + sender.await.unwrap(); + } +} + +#[test] +fn runtime_mpsc_many_senders_with_blocking() { + shuttle::check_random( + || { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .build() + .unwrap(); + + rt.block_on(async { + mpsc_senders_with_blocking_inner(1000, 500).await; + }); + }, + 1000, + ); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/rwlock.rs b/wrappers/tokio/impls/tokio/inner/tests/rwlock.rs new file mode 100644 index 00000000..eafd33fe --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/rwlock.rs @@ -0,0 +1,202 @@ +use shuttle::current::{me, set_label_for_task}; +use shuttle::{check_dfs, check_random, future}; +use shuttle_tokio_impl_inner::sync::{mpsc, RwLock}; +use 
shuttle_tokio_impl_inner::time::{clear_triggers, timeout, trigger_timeouts}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use test_log::test; +use tracing::trace; + +#[test] +fn async_rwlock_reader_concurrency() { + let saw_concurrent_reads = Arc::new(AtomicBool::new(false)); + { + let saw_concurrent_reads = Arc::clone(&saw_concurrent_reads); + check_random( + move || { + let saw_concurrent_reads2 = Arc::clone(&saw_concurrent_reads); + future::block_on(async move { + let rwlock = Arc::new(RwLock::new(0usize)); + let readers = Arc::new(AtomicUsize::new(0)); + + { + let rwlock = Arc::clone(&rwlock); + let readers = Arc::clone(&readers); + + let saw_concurrent_reads3 = Arc::clone(&saw_concurrent_reads2); + + future::spawn(async move { + let counter = rwlock.read().await; + assert_eq!(*counter, 0); + + readers.fetch_add(1, Ordering::SeqCst); + + future::yield_now().await; + + if readers.load(Ordering::SeqCst) == 2 { + saw_concurrent_reads3.store(true, Ordering::SeqCst); + } + + readers.fetch_sub(1, Ordering::SeqCst); + }); + } + + let counter = rwlock.read().await; + assert_eq!(*counter, 0); + + readers.fetch_add(1, Ordering::SeqCst); + + future::yield_now().await; + + if readers.load(Ordering::SeqCst) == 2 { + saw_concurrent_reads2.store(true, Ordering::SeqCst); + } + + readers.fetch_sub(1, Ordering::SeqCst); + }); + }, + 100, + ); + } + + assert!(saw_concurrent_reads.load(Ordering::SeqCst)); +} + +async fn deadlock() { + let lock1 = Arc::new(RwLock::new(0usize)); + let lock2 = Arc::new(RwLock::new(0usize)); + let lock1_clone = Arc::clone(&lock1); + let lock2_clone = Arc::clone(&lock2); + + future::spawn(async move { + let _l1 = lock1_clone.read().await; + let _l2 = lock2_clone.read().await; + }); + + let _l2 = lock2.write().await; + let _l1 = lock1.write().await; +} + +#[test] +#[should_panic(expected = "deadlock")] +fn async_rwlock_deadlock() { + check_dfs( + || { + future::block_on(async { + 
deadlock().await; + }); + }, + None, + ); +} + +#[test] +fn async_rwlock_two_writers() { + check_dfs( + || { + future::block_on(async move { + let lock = Arc::new(RwLock::new(1)); + let lock2 = lock.clone(); + + future::spawn(async move { + let mut w = lock.write_owned().await; + *w += 1; + let v = *w; + let r = w.downgrade(); + assert_eq!(*r, v); + }); + + future::spawn(async move { + let mut w = lock2.write_owned().await; + *w += 1; + }); + }); + }, + None, + ); +} + +// Check that multiple readers are allowed to read at a time +// This test should never deadlock. +#[test] +fn async_rwlock_allows_multiple_readers() { + check_dfs( + || { + future::block_on(async move { + let lock1 = Arc::new(RwLock::new(1)); + let lock2 = lock1.clone(); + + let (s1, mut r1) = mpsc::unbounded_channel::(); + let (s2, mut r2) = mpsc::unbounded_channel::(); + + future::spawn(async move { + let w = lock1.read().await; + s1.send(*w).unwrap(); // Send value to other thread + let r = r2.recv().await; // Wait for value from other thread + assert_eq!(r, Some(1)); + }); + + future::spawn(async move { + let w = lock2.read().await; + s2.send(*w).unwrap(); + let r = r1.recv().await; + assert_eq!(r, Some(1)); + }); + }); + }, + None, + ); +} + +/// This test exposed a deadlock bug in the following situation. A task is holding a read lock. +/// Another task wants to acquire a write lock and gets blocked. We cancel that task. Yet another +/// task gets stuck acquiring a read lock, although the lock should be available. +/// +/// The underlying bug in `BatchSemaphore` was fixed in https://github.com/awslabs/shuttle/pull/167. +#[test] +fn canceling_blocked_write_deadlock_bug() { + check_dfs( + || { + future::block_on(async move { + #[derive(Debug, Clone)] + struct TimeoutLabel; + + clear_triggers(); + + let lock = Arc::new(RwLock::new(())); + let lock1 = lock.clone(); + let lock2 = lock.clone(); + + // Acquire a read lock so that acquiring a write lock below gets blocked. 
+ // This corresponds to a leaked lock in the production code that initially hit the bug. + let _read_guard = lock.read().await; + + let w = future::spawn(async move { + set_label_for_task(me(), TimeoutLabel); + trace!("acquiring write lock"); + let _write_result = timeout(Duration::from_secs(1), lock1.write()).await; + assert!(_write_result.is_err()); + trace!("acquiring write lock timed out"); + }); + + let t = future::spawn(async { + trace!("trigger timeout"); + trigger_timeouts(move |labels| labels.get::().is_some()); + }); + + let r = future::spawn(async move { + trace!("acquiring read lock"); + // This is the lock acquire where we could get suck. + let _read_guard = lock2.read().await; + trace!("acquired read lock"); + }); + + w.await.unwrap(); + t.await.unwrap(); + r.await.unwrap(); + }); + }, + None, + ); +} diff --git a/wrappers/tokio/impls/tokio/inner/tests/semaphore.rs b/wrappers/tokio/impls/tokio/inner/tests/semaphore.rs new file mode 100644 index 00000000..0298644c --- /dev/null +++ b/wrappers/tokio/impls/tokio/inner/tests/semaphore.rs @@ -0,0 +1,318 @@ +use shuttle::{check_dfs, future}; +use shuttle_tokio_impl_inner::sync::Semaphore; +use std::collections::HashSet; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::sync::Mutex; +use test_log::test; + +#[test] +fn try_acquire() { + check_dfs( + || { + let sem = Arc::new(Semaphore::new(1)); + { + let p1 = sem.clone().try_acquire_owned(); + assert!(p1.is_ok()); + let p2 = sem.clone().try_acquire_owned(); + assert!(p2.is_err()); + } + let p3 = sem.try_acquire_owned(); + assert!(p3.is_ok()); + }, + None, + ); +} + +#[test] +fn try_acquire_many() { + shuttle::check_dfs( + || { + let sem = Arc::new(Semaphore::new(42)); + { + let p1 = sem.clone().try_acquire_many_owned(42); + assert!(p1.is_ok()); + let p2 = sem.clone().try_acquire_owned(); + assert!(p2.is_err()); + } + let p3 = sem.clone().try_acquire_many_owned(32); + assert!(p3.is_ok()); + let p4 = 
sem.clone().try_acquire_many_owned(10);
            assert!(p4.is_ok());
            assert!(sem.try_acquire_owned().is_err());
        },
        None,
    );
}

// A blocked `acquire_owned` is woken when the outstanding permit is dropped.
#[test]
fn semaphore_acquire() {
    check_dfs(
        || {
            future::block_on(async move {
                let sem = Arc::new(Semaphore::new(1));
                let p1 = sem.clone().try_acquire_owned().unwrap();
                let sem_clone = sem.clone();
                let j = future::spawn(async move {
                    let _p2 = sem_clone.acquire_owned().await;
                });
                drop(p1);
                j.await.unwrap();
            });
        },
        None,
    );
}

// Harness: spawn one task per entry of `counts`; task `i` acquires `counts[i]`
// permits from a semaphore holding `num_permits` permits, sets bit `i` in a
// shared bitmask while holding them, and records `(i, mask_of_other_holders)`
// in `states`, so the recorded pairs show which tasks ever held permits
// concurrently.
//
// NOTE(review): the generic arguments of `counts` and `states` were stripped
// in transit (the text read `counts: Vec, states: &Arc>>`). They are restored
// here from the call sites (`vec![3, 3, 3]`) and the inserted `(i, v)` tuples
// of two `usize`s.
async fn semtest(num_permits: usize, counts: Vec<usize>, states: &Arc<Mutex<HashSet<(usize, usize)>>>) {
    let s = Arc::new(Semaphore::new(num_permits));
    let r = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];
    for (i, &c) in counts.iter().enumerate() {
        let s = s.clone();
        let r = r.clone();
        let states = states.clone();
        let val = 1usize << i;
        handles.push(future::spawn(async move {
            let permit = s.acquire_many(c as u32).await.unwrap();
            let v = r.fetch_add(val, Ordering::SeqCst);
            future::yield_now().await;
            let _ = r.fetch_sub(val, Ordering::SeqCst);
            states.lock().unwrap().insert((i, v));
            drop(permit);
        }));
    }
    for h in handles {
        h.await.unwrap();
    }
}

// With 5 permits and three tasks wanting 3 each, no two tasks can ever hold
// permits simultaneously, so every task records an empty holder mask.
#[test]
fn semtest_1() {
    let states = Arc::new(Mutex::new(HashSet::new()));
    let states2 = states.clone();
    check_dfs(
        move || {
            let states2 = states2.clone();
            future::block_on(async move {
                semtest(5, vec![3, 3, 3], &states2).await;
            });
        },
        None,
    );

    let states = Arc::try_unwrap(states).unwrap().into_inner().unwrap();
    assert_eq!(states, HashSet::from([(0, 0), (1, 0), (2, 0)]));
}

// With 5 permits and demands of 3, 3, 2, only the 2-permit task can overlap
// with one of the others; DFS exploration must observe those overlaps.
#[test]
fn semtest_2() {
    let states = Arc::new(Mutex::new(HashSet::new()));
    let states2 = states.clone();
    check_dfs(
        move || {
            let states2 = states2.clone();
            future::block_on(async move {
                semtest(5, vec![3, 3, 2], &states2).await;
            });
        },
        None,
    );

    let states = Arc::try_unwrap(states).unwrap().into_inner().unwrap();
    assert_eq!(
        states,
        HashSet::from([(0, 0), (1, 0), (2,
1), (2, 2)])
    );
}

// Adding a permit to an empty semaphore wakes a blocked `acquire_owned`.
#[test]
fn add_permits() {
    check_dfs(
        || {
            future::block_on(async {
                let semaphore = Arc::new(Semaphore::new(0));
                let semaphore2 = semaphore.clone();
                let waiter = future::spawn(async move {
                    let _permit = semaphore2.acquire_owned().await;
                });
                semaphore.add_permits(1);
                waiter.await.unwrap();
            });
        },
        None,
    );
}

// `forget` consumes a permit without ever returning it to the semaphore.
#[test]
fn forget() {
    check_dfs(
        || {
            let semaphore = Arc::new(Semaphore::new(1));
            {
                let permit = semaphore.clone().try_acquire_owned().unwrap();
                assert_eq!(semaphore.available_permits(), 0);
                permit.forget();
                assert_eq!(semaphore.available_permits(), 0);
            }
            // Leaving the scope must not restore the forgotten permit.
            assert_eq!(semaphore.available_permits(), 0);
            assert!(semaphore.try_acquire_owned().is_err());
        },
        None,
    );
}

// Merging combines two borrowed permits into one; all three permits are
// returned to the semaphore when the merged permit is dropped.
#[test]
fn merge() {
    check_dfs(
        || {
            let semaphore = Arc::new(Semaphore::new(3));
            {
                let mut first = semaphore.try_acquire().unwrap();
                assert_eq!(semaphore.available_permits(), 2);
                let second = semaphore.try_acquire_many(2).unwrap();
                assert_eq!(semaphore.available_permits(), 0);
                first.merge(second);
                assert_eq!(semaphore.available_permits(), 0);
            }
            assert_eq!(semaphore.available_permits(), 3);
        },
        None,
    );
}

// Merging permits taken from two distinct semaphores must panic.
#[test]
#[should_panic(expected = "merging permits from different semaphore instances")]
fn merge_unrelated_permits() {
    check_dfs(
        || {
            let first_sem = Arc::new(Semaphore::new(3));
            let second_sem = Arc::new(Semaphore::new(3));
            let mut first = first_sem.try_acquire().unwrap();
            let second = second_sem.try_acquire().unwrap();
            first.merge(second);
        },
        None,
    );
}

// Same as `merge`, but with owned permits.
#[test]
fn merge_owned() {
    check_dfs(
        || {
            let semaphore = Arc::new(Semaphore::new(3));
            {
                let mut first = semaphore.clone().try_acquire_owned().unwrap();
                assert_eq!(semaphore.available_permits(), 2);
                let second = semaphore.clone().try_acquire_many_owned(2).unwrap();
                assert_eq!(semaphore.available_permits(), 0);
                first.merge(second);
                assert_eq!(semaphore.available_permits(), 0);
            }
            assert_eq!(semaphore.available_permits(), 3);
        },
        None,
    );
}

// Same as `merge_unrelated_permits`, but with owned permits.
#[test]
#[should_panic(expected = "merging permits from different semaphore instances")]
fn merge_unrelated_owned_permits() {
    check_dfs(
        || {
            let
sem1 = Arc::new(Semaphore::new(3));
            let sem2 = Arc::new(Semaphore::new(3));
            let mut p1 = sem1.try_acquire_owned().unwrap();
            let p2 = sem2.try_acquire_owned().unwrap();
            p1.merge(p2);
        },
        None,
    );
}

// `split` carves permits off an existing borrowed permit. Splitting more than
// the permit holds yields `None`, zero-permit splits are allowed, and dropping
// a permit returns exactly the permits it still holds.
#[test]
fn split() {
    check_dfs(
        || {
            let semaphore = Semaphore::new(5);
            let mut three = semaphore.try_acquire_many(3).unwrap();
            assert_eq!(semaphore.available_permits(), 2);
            assert_eq!(three.num_permits(), 3);

            // Carve one permit off: 3 -> 2 + 1; nothing returns to the semaphore.
            let mut one = three.split(1).unwrap();
            assert_eq!(semaphore.available_permits(), 2);
            assert_eq!(three.num_permits(), 2);
            assert_eq!(one.num_permits(), 1);

            // A zero-permit split is legal and holds nothing.
            let empty = three.split(0).unwrap();
            assert_eq!(empty.num_permits(), 0);

            drop(three);
            assert_eq!(semaphore.available_permits(), 4);

            // Move the single permit into a fresh permit, draining `one`.
            let moved = one.split(1).unwrap();
            assert_eq!(one.num_permits(), 0);
            assert_eq!(moved.num_permits(), 1);
            assert!(one.split(1).is_none());

            drop(one);
            assert_eq!(semaphore.available_permits(), 4);
            drop(empty);
            assert_eq!(semaphore.available_permits(), 4);
            drop(moved);
            assert_eq!(semaphore.available_permits(), 5);
        },
        None,
    );
}

// Same as `split`, but for owned permits.
#[test]
fn split_owned() {
    check_dfs(
        || {
            let semaphore = Arc::new(Semaphore::new(5));
            let mut three = semaphore.clone().try_acquire_many_owned(3).unwrap();
            assert_eq!(semaphore.available_permits(), 2);
            assert_eq!(three.num_permits(), 3);

            let mut one = three.split(1).unwrap();
            assert_eq!(semaphore.available_permits(), 2);
            assert_eq!(three.num_permits(), 2);
            assert_eq!(one.num_permits(), 1);

            let empty = three.split(0).unwrap();
            assert_eq!(empty.num_permits(), 0);

            drop(three);
            assert_eq!(semaphore.available_permits(), 4);

            let moved = one.split(1).unwrap();
            assert_eq!(one.num_permits(), 0);
            assert_eq!(moved.num_permits(), 1);
            assert!(one.split(1).is_none());

            drop(one);
            assert_eq!(semaphore.available_permits(), 4);
            drop(empty);
            assert_eq!(semaphore.available_permits(), 4);
            drop(moved);
            assert_eq!(semaphore.available_permits(), 5);
        },
        None,
    );
}

/*
#[test]
fn acquire_many() {
    check_dfs(
        || {
            future::block_on(async move {
                let semaphore = Arc::new(Semaphore::new(42));
                let permit32 =
semaphore.clone().try_acquire_many_owned(32).unwrap();
                let (sender, receiver) = shuttle::sync::oneshot::channel();
                let join_handle = future::spawn(async move {
                    let _permit10 = semaphore.clone().acquire_many_owned(10).await.unwrap();
                    sender.send(()).unwrap();
                    let _permit32 = semaphore.acquire_many_owned(32).await.unwrap();
                });
                receiver.await.unwrap();
                drop(permit32);
                join_handle.await.unwrap();
            });
        },
        None,
    );
}

*/
diff --git a/wrappers/tokio/impls/tokio/inner/tests/time.rs b/wrappers/tokio/impls/tokio/inner/tests/time.rs
new file mode 100644
index 00000000..d519679d
--- /dev/null
+++ b/wrappers/tokio/impls/tokio/inner/tests/time.rs
@@ -0,0 +1,143 @@
use shuttle_tokio_impl_inner::time::{clear_triggers, sleep, sleep_until, Duration, Instant};
use test_log::test;

// Spawn a task that sleeps for `duration` and panics if it ever wakes; the
// tests below use this to check whether the scheduler can fire the sleep.
fn sleep_test(duration: Duration) {
    shuttle::check_dfs(
        move || {
            shuttle::future::block_on(async move {
                let _handle = shuttle::future::spawn(async move {
                    sleep(duration).await;
                    panic!();
                });
            });
        },
        None,
    );
}

// As `sleep_test`, but with an absolute deadline instead of a duration.
fn sleep_until_test(deadline: Instant) {
    shuttle::check_dfs(
        move || {
            shuttle::future::block_on(async move {
                let _handle = shuttle::future::spawn(async move {
                    sleep_until(deadline).await;
                    panic!();
                });
            });
        },
        None,
    );
}

// A finite sleep can be scheduled to fire, reaching the `panic!()`.
#[test]
#[should_panic(expected = "explicit panic")]
fn sleep_panic() {
    sleep_test(Duration::from_secs(10_0000));
}

// A `Duration::MAX` sleep never fires, so the panic is unreachable.
#[test]
fn sleep_forever() {
    sleep_test(Duration::MAX);
}

// A near-future absolute deadline can fire, reaching the `panic!()`.
#[test]
#[should_panic(expected = "explicit panic")]
fn sleep_until_panic() {
    sleep_until_test(Instant::now() + Duration::from_micros(500));
}

// Test that `deadline`, `is_elapsed` and `reset` work as expected.
+#[test] +fn sleep_reset() { + let old_deadline = Instant::now() - Duration::from_secs(5); + let mut sleep = sleep_until(old_deadline); + assert_eq!(sleep.deadline(), old_deadline); + assert!(sleep.is_elapsed()); + + let new_deadline = Instant::now() + Duration::from_secs(100); + let pinned = std::pin::Pin::new(&mut sleep); + pinned.reset(new_deadline); + assert_eq!(sleep.deadline(), new_deadline); + assert!(!sleep.is_elapsed()); +} + +mod timeout_tests { + use super::*; + use futures::future::join_all; + use shuttle::current::{me, set_label_for_task, Labels}; + use shuttle::future; + use shuttle_tokio_impl_inner::sync::Mutex; + use shuttle_tokio_impl_inner::time::{timeout, trigger_timeouts}; + use std::sync::Arc; + use test_log::test; + + #[derive(Clone, Debug)] + struct Label(usize); + + // Create an instance of the Dining Philosophers problem with `count` philosophers. + // The i'th philosopher grabs forks i and i+1 (modulo count). + async fn dining_philosophers(count: usize, trigger: bool) { + clear_triggers(); + let forks = (0..count).map(|_| Arc::new(Mutex::new(0))).collect::>(); + + let mut handles = Vec::new(); + for i in 0..count { + let left_fork = forks[i].clone(); + let right_fork = forks[(i + 1) % count].clone(); + let h = future::spawn(async move { + _ = set_label_for_task(me(), Label(i)); + let l = timeout(Duration::from_secs(1), left_fork.lock()).await; + let r = timeout(Duration::from_secs(1), right_fork.lock()).await; + l.is_err() || r.is_err() + }); + handles.push(h); + } + + if trigger { + // Trigger timeout on the middle philosopher + trigger_timeouts(move |labels: &Labels| labels.get::