diff --git a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs index 9a50a2f3f8c..cc8901c43f0 100644 --- a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs +++ b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs @@ -22,6 +22,12 @@ pub fn ensures(_attr: TokenStream, tokens: TokenStream) -> TokenStream { tokens } +#[cfg(not(feature = "prusti"))] +#[proc_macro_attribute] +pub fn async_invariant(_attr: TokenStream, tokens: TokenStream) -> TokenStream { + tokens +} + #[cfg(not(feature = "prusti"))] #[proc_macro_attribute] pub fn after_expiry(_attr: TokenStream, tokens: TokenStream) -> TokenStream { @@ -130,6 +136,12 @@ pub fn body_variant(_tokens: TokenStream) -> TokenStream { TokenStream::new() } +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + tokens +} + // ---------------------- // --- PRUSTI ENABLED --- @@ -148,6 +160,12 @@ pub fn ensures(attr: TokenStream, tokens: TokenStream) -> TokenStream { rewrite_prusti_attributes(SpecAttributeKind::Ensures, attr.into(), tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro_attribute] +pub fn async_invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { + rewrite_prusti_attributes(SpecAttributeKind::AsyncInvariant, attr.into(), tokens.into()).into() +} + #[cfg(feature = "prusti")] #[proc_macro_attribute] pub fn after_expiry(attr: TokenStream, tokens: TokenStream) -> TokenStream { @@ -273,5 +291,11 @@ pub fn body_variant(tokens: TokenStream) -> TokenStream { prusti_specs::body_variant(tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + prusti_specs::suspension_point(tokens.into()).into() +} + // Ensure that you've also crated a transparent `#[cfg(not(feature = "prusti"))]` // version of your new macro above! diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index e1d1f17648a..1a580241cbf 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -66,6 +66,12 @@ pub use prusti_contracts_proc_macros::terminates; /// A macro to annotate body variant of a loop to prove termination pub use prusti_contracts_proc_macros::body_variant; +/// A macro to annotate suspension points inside async constructs +pub use prusti_contracts_proc_macros::suspension_point; + +/// A macro for writing async invariants on an async function +pub use prusti_contracts_proc_macros::async_invariant; + #[cfg(not(feature = "prusti"))] mod private { use core::marker::PhantomData; @@ -339,6 +345,10 @@ pub fn old(arg: T) -> T { arg } +pub fn suspension_point_on_exit_marker(_label: u32, _closures: T) {} + +pub fn suspension_point_on_entry_marker(_label: u32, _closures: T) {} + /// Universal quantifier. /// /// This is a Prusti-internal representation of the `forall` syntax. diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 3403260c345..76074b2e238 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -70,6 +70,7 @@ fn extract_prusti_attributes( let tokens = match attr_kind { SpecAttributeKind::Requires | SpecAttributeKind::Ensures + | SpecAttributeKind::AsyncInvariant | SpecAttributeKind::AfterExpiry | SpecAttributeKind::AssertOnExpiry | SpecAttributeKind::RefineSpec => { @@ -81,6 +82,10 @@ fn extract_prusti_attributes( assert!(iter.next().is_none(), "Unexpected shape of an attribute."); group.stream() } + // these cannot appear here, as postconditions are only marked as async + // at a later stage + SpecAttributeKind::AsyncEnsures => + unreachable!("SpecAttributeKind::AsyncEnsures should not appear at this stage"), // Nothing to do for attributes without arguments. SpecAttributeKind::Pure | SpecAttributeKind::Terminates @@ -125,6 +130,39 @@ pub fn rewrite_prusti_attributes( // Collect the remaining Prusti attributes, removing them from `item`. prusti_attributes.extend(extract_prusti_attributes(&mut item)); + // in the case of an async fn, mark the postconditions as such, + // since they are handled differently + if let untyped::AnyFnItem::Fn(ref fn_item) = &item { + if fn_item.sig.asyncness.is_some() { + prusti_attributes = prusti_attributes + .into_iter() + .map(|(attr, tt)| match attr { + SpecAttributeKind::Ensures => (SpecAttributeKind::AsyncEnsures, tt), + _ => (attr, tt) + } + ) + .collect(); + } + } + + // make sure async invariants can only be attached to async fn's + if matches!(outer_attr_kind, SpecAttributeKind::AsyncInvariant) { + // TODO: should the error be reported to the function's or the attribute's span? + let untyped::AnyFnItem::Fn(ref fn_item) = &item else { + return syn::Error::new( + item.span(), + "async_invariant attached to non-function item" + ).to_compile_error(); + }; + if fn_item.sig.asyncness.is_none() { + return syn::Error::new( + item.span(), + "async_invariant attached to non-async function" + ) + .to_compile_error(); + } + } + // make sure to also update the check in the predicate! handling method if prusti_attributes .iter() @@ -162,6 +200,8 @@ fn generate_spec_and_assertions( let rewriting_result = match attr_kind { SpecAttributeKind::Requires => generate_for_requires(attr_tokens, item), SpecAttributeKind::Ensures => generate_for_ensures(attr_tokens, item), + SpecAttributeKind::AsyncEnsures => generate_for_async_ensures(attr_tokens, item), + SpecAttributeKind::AsyncInvariant => generate_for_async_invariant(attr_tokens, item), SpecAttributeKind::AfterExpiry => generate_for_after_expiry(attr_tokens, item), SpecAttributeKind::AssertOnExpiry => generate_for_assert_on_expiry(attr_tokens, item), SpecAttributeKind::Pure => generate_for_pure(attr_tokens, item), @@ -215,6 +255,55 @@ fn generate_for_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> Generat )) } +/// Generate spec items and attributes to typecheck and later retrieve "ensures" annotations, +/// but for async fn, as their postconditions will need to be moved to their generator method +/// instead of the future-constructor. +fn generate_for_async_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { + let mut rewriter = rewriter::AstRewriter::new(); + // parse the postcondition + let expr = parse_prusti(attr)?; + // generate a spec item for the postcondition itself, + // which will be attached to the future's implementation + let spec_id = rewriter.generate_spec_id(); + let spec_id_str = spec_id.to_string(); + let spec_item = + rewriter.generate_spec_item_fn(rewriter::SpecItemType::Postcondition, spec_id, expr.clone(), item)?; + + // and generate one wrapped in a `Poll` for the stub + let stub_spec_id = rewriter.generate_spec_id(); + let stub_spec_id_str = stub_spec_id.to_string(); + let stub_spec_item = rewriter.generate_async_ensures_item_fn(stub_spec_id, expr, item)?; + + Ok(( + vec![spec_item, stub_spec_item], + vec![ + parse_quote_spanned! {item.span()=> + #[prusti::async_post_spec_id_ref = #spec_id_str] + }, + parse_quote_spanned! {item.span()=> + #[prusti::async_stub_post_spec_id_ref = #stub_spec_id_str] + }, + ] + )) +} + + +/// Generate spec items and attributes to typecheck and later retrieve invariant annotations +/// on async functions. +fn generate_for_async_invariant(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { + let mut rewriter = rewriter::AstRewriter::new(); + let spec_id = rewriter.generate_spec_id(); + let spec_id_str = spec_id.to_string(); + let spec_item = + rewriter.process_assertion(rewriter::SpecItemType::Precondition, spec_id, attr, item)?; + Ok(( + vec![spec_item], + vec![parse_quote_spanned! {item.span()=> + #[prusti::async_inv_spec_id_ref = #spec_id_str] + }] + )) +} + /// Generate spec items and attributes to typecheck and later retrieve "after_expiry" annotations. fn generate_for_after_expiry(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { let mut rewriter = rewriter::AstRewriter::new(); @@ -433,6 +522,146 @@ pub fn prusti_refutation(tokens: TokenStream) -> TokenStream { generate_expression_closure(&AstRewriter::process_prusti_refutation, tokens) } + +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + // parse the expression inside the suspension point, + // making sure it is just an await (possibly with attributes) + let expr: syn::Expr = handle_result!(syn::parse2(tokens)); + let expr_span = expr.span(); + let syn::Expr::Await(mut await_expr) = expr else { + return syn::Error::new( + expr.span(), + "suspension-point must contain a single await-expression" + ).to_compile_error(); + }; + let future = &await_expr.base; + + let mut label: Option = None; + let mut on_exit: Vec = Vec::new(); + let mut on_entry: Vec = Vec::new(); + + // extract label, on-exit, and on-entry conditions from the attributes + for attr in await_expr.attrs { + let attr_span = attr.span(); + if !matches!(attr.style, syn::AttrStyle::Outer) { + return syn::Error::new( + attr_span, + "only outer attributes allowed in suspension-point", + ).to_compile_error(); + } + if !(attr.path.segments.len() == 1 + || (attr.path.segments.len() == 2 && attr.path.segments[0].ident == "prusti_contracts")) + { + return syn::Error::new( + attr_span, + "invalid attribute in suspension-point", + ).to_compile_error(); + } + let name = attr.path.segments[attr.path.segments.len() - 1] + .ident + .to_string(); + // labels + if name == "label" { + let [proc_macro2::TokenTree::Group(group)] = + &attr.tokens.into_iter().collect::>()[..] + else { + return syn::Error::new( + attr_span, + "expected group with a single positive integer as label", + ).to_compile_error(); + }; + let [proc_macro2::TokenTree::Literal(lit)] = + &group.stream().into_iter().collect::>()[..] + else { + return syn::Error::new( + attr_span, + "expected single positive integer as label", + ).to_compile_error(); + }; + let lbl_num: u32 = lit + .to_string() + .parse() + .expect("expected single positive integer as label"); + if lbl_num == 0 { + return syn::Error::new( + attr_span, + "suspension-point label must be a positive integer", + ).to_compile_error(); + } + if label.replace(lbl_num).is_some() { + return syn::Error::new( + attr_span, + "multiple labels provided for suspension-point", + ).to_compile_error(); + } + // on-exit conditions + } else if name == "on_exit" { + on_exit.push(attr.tokens); + // on-entry conditions + } else if name == "on_entry" { + on_entry.push(attr.tokens); + // all other attributes are not permitted + } else { + // TODO: more precise error message with span? + panic!("invalid attribute in suspension-point"); + } + } + + let label = label.unwrap_or_else(|| panic!("suspension-point must have a label")); + + // generate spec-items for each condition + let create_spec_item = |tokens: TokenStream| -> syn::Result { + let expr = parse_prusti(tokens)?; + Ok(parse_quote_spanned! {expr_span=> + || -> bool { + let val: bool = #expr; + val + } + }) + }; + + let on_exit: syn::Result> = on_exit + .into_iter() + .map(create_spec_item) + .collect(); + let on_exit = handle_result!(on_exit); + + let on_entry: syn::Result> = on_entry + .into_iter() + .map(create_spec_item) + .collect(); + let on_entry = handle_result!(on_entry); + + // return just the await-expression + let on_exit_closures = match on_exit.len() { + 1 => { + let on_exit = &on_exit[0]; + quote_spanned! { expr_span => (#on_exit,) } + } + _ => quote_spanned! { expr_span => (#(#on_exit),*) }, + }; + let on_entry_closures = match on_entry.len() { + 1 => { + let on_entry = &on_entry[0]; + quote_spanned! { expr_span => (#on_entry,) } + } + _ => quote_spanned! { expr_span => (#(#on_entry),*) }, + }; + + await_expr.attrs = Vec::new(); + quote_spanned! { expr_span => + { + let fut = #future; + #[allow(unused_parens)] + ::prusti_contracts::suspension_point_on_exit_marker(#label, #on_exit_closures); + let res = fut.await; + #[allow(unused_parens)] + ::prusti_contracts::suspension_point_on_entry_marker(#label, #on_entry_closures); + res + } + } +} + /// Generates the TokenStream encoding an expression using prusti syntax /// Used for body invariants, assertions, and assumptions fn generate_expression_closure( @@ -855,6 +1084,8 @@ fn extract_prusti_attributes_for_types( let tokens = match attr_kind { SpecAttributeKind::Requires => unreachable!("requires on type"), SpecAttributeKind::Ensures => unreachable!("ensures on type"), + SpecAttributeKind::AsyncEnsures => unreachable!("ensures on type"), + SpecAttributeKind::AsyncInvariant => unreachable!("async-invariant on type"), SpecAttributeKind::AfterExpiry => unreachable!("after_expiry on type"), SpecAttributeKind::AssertOnExpiry => unreachable!("assert_on_expiry on type"), SpecAttributeKind::RefineSpec => unreachable!("refine_spec on type"), @@ -900,6 +1131,8 @@ fn generate_spec_and_assertions_for_types( let rewriting_result = match attr_kind { SpecAttributeKind::Requires => unreachable!(), SpecAttributeKind::Ensures => unreachable!(), + SpecAttributeKind::AsyncEnsures => unreachable!(), + SpecAttributeKind::AsyncInvariant => unreachable!(), SpecAttributeKind::AfterExpiry => unreachable!(), SpecAttributeKind::AssertOnExpiry => unreachable!(), SpecAttributeKind::Pure => unreachable!(), diff --git a/prusti-contracts/prusti-specs/src/rewriter.rs b/prusti-contracts/prusti-specs/src/rewriter.rs index 9595485ab59..f84a54315cb 100644 --- a/prusti-contracts/prusti-specs/src/rewriter.rs +++ b/prusti-contracts/prusti-specs/src/rewriter.rs @@ -145,6 +145,62 @@ impl AstRewriter { Ok(syn::Item::Fn(spec_item)) } + /// Wrap an expression for an async postcondition in `Poll` and create the appropriate function + /// Can only be used for async postconditions and must be called after `generate_spec_item_fn` + /// to create the postcondition for the poll implementation + pub fn generate_async_ensures_item_fn( + &mut self, + spec_id: SpecificationId, + expr: TokenStream, + item: &T, + ) -> syn::Result { + // NOTE: we don't need to check for `result` as a parameter again, as this will only + // be called after `generate_spec_item_fn` for the same expression + let item_span = expr.span(); + let item_name = syn::Ident::new( + &format!("prusti_{}_item_{}_{}", SpecItemType::Postcondition, item.sig().ident, spec_id), + item_span, + ); + let spec_id_str = spec_id.to_string(); + + let mut spec_item: syn::ItemFn = parse_quote_spanned! {item_span=> + #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case)] + #[prusti::spec_only] + #[prusti::spec_id = #spec_id_str] + fn #item_name() -> bool { + let val: bool = if let ::std::task::Poll::Ready(result) = result { + #expr + } else { + true + }; + val + } + }; + spec_item.sig.generics = item.sig().generics.clone(); + spec_item.sig.inputs = item.sig().inputs.clone(); + + // note that the result-arg's type should not be the async function's return type `T` + // but `Poll` + let result_arg: syn::FnArg = { + // analogous to `generate_result_arg` + let item_span = item.span(); + let output_ty = match &item.sig().output { + syn::ReturnType::Default => parse_quote_spanned!(item_span=> ()), + syn::ReturnType::Type(_, ty) => ty.clone(), + }; + let output_ty = parse_quote_spanned!(item_span=> ::std::task::Poll<#output_ty>); + let fn_arg = syn::FnArg::Typed(syn::PatType { + attrs: Vec::new(), + pat: Box::new(parse_quote_spanned!(item_span=> result)), + colon_token: syn::Token![:](item.sig().output.span()), + ty: output_ty, + }); + fn_arg + }; + spec_item.sig.inputs.push(result_arg); + Ok(syn::Item::Fn(spec_item)) + } + /// Parse an assertion into a Rust expression pub fn process_assertion( &mut self, diff --git a/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs b/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs index f286cbd317b..7f730496336 100644 --- a/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs +++ b/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs @@ -18,6 +18,8 @@ pub enum SpecAttributeKind { Terminates = 10, PrintCounterexample = 11, Verified = 12, + AsyncEnsures = 13, + AsyncInvariant = 14, } impl TryFrom for SpecAttributeKind { @@ -37,6 +39,7 @@ impl TryFrom for SpecAttributeKind { "model" => Ok(SpecAttributeKind::Model), "print_counterexample" => Ok(SpecAttributeKind::PrintCounterexample), "verified" => Ok(SpecAttributeKind::Verified), + "async_invariant" => Ok(SpecAttributeKind::AsyncInvariant), _ => Err(name), } } diff --git a/prusti-contracts/prusti-specs/src/specifications/common.rs b/prusti-contracts/prusti-specs/src/specifications/common.rs index dca66437e79..4baaab873b3 100644 --- a/prusti-contracts/prusti-specs/src/specifications/common.rs +++ b/prusti-contracts/prusti-specs/src/specifications/common.rs @@ -56,6 +56,9 @@ pub struct SpecificationId(Uuid); pub enum SpecIdRef { Precondition(SpecificationId), Postcondition(SpecificationId), + AsyncPostcondition(SpecificationId), + AsyncStubPostcondition(SpecificationId), + AsyncInvariant(SpecificationId), Purity(SpecificationId), Pledge { lhs: Option, diff --git a/prusti-interface/src/specs/mod.rs b/prusti-interface/src/specs/mod.rs index cf5414e1d6f..b953c3ab104 100644 --- a/prusti-interface/src/specs/mod.rs +++ b/prusti-interface/src/specs/mod.rs @@ -17,7 +17,7 @@ use prusti_rustc_interface::{ def_id::{DefId, LocalDefId}, intravisit, FnRetTy, }, - middle::{hir::map::Map, ty}, + middle::{hir::map::{self as hir_map, Map}, ty}, span::Span, }; use std::{convert::TryInto, fmt::Debug}; @@ -33,10 +33,11 @@ use typed::SpecIdRef; use crate::specs::{ external::ExternSpecResolver, - typed::{ProcedureSpecification, ProcedureSpecificationKind, SpecGraph, SpecificationItem}, + typed::{ProcedureSpecification, ProcedureSpecificationKind, SpecGraph, SpecificationItem, ProcedureKind}, }; use prusti_specs::specifications::common::SpecificationId; + #[derive(Debug)] struct ProcedureSpecRefs { spec_id_refs: Vec, @@ -86,6 +87,9 @@ pub struct SpecCollector<'a, 'tcx> { prusti_refutations: Vec, ghost_begin: Vec, ghost_end: Vec, + + /// Map from future constructors to their poll methods + async_parent: FxHashMap, } impl<'a, 'tcx> SpecCollector<'a, 'tcx> { @@ -103,6 +107,7 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { prusti_refutations: vec![], ghost_begin: vec![], ghost_end: vec![], + async_parent: FxHashMap::default(), } } @@ -150,6 +155,12 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { self.env, ); } + // both async postconditions and async invariants are not added to the method + // they were attached to (which ends up being the future constructor) but to + // the poll method, which is determined elsewhere + SpecIdRef::AsyncPostcondition(_) + | SpecIdRef::AsyncStubPostcondition(_) + | SpecIdRef::AsyncInvariant(_) => {}, SpecIdRef::Purity(spec_id) => { spec.add_purity(*self.spec_functions.get(spec_id).unwrap(), self.env); } @@ -191,6 +202,50 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { def_spec.proc_specs.insert(local_id.to_def_id(), spec); } } + + // attach async specifications to the future's poll-methods instead of the future + // constructor + for (local_id, parent_id) in self.async_parent.iter() { + // look up parent's spec and skip this method if the parent doesn't have one + let Some(parent_spec) = self.procedure_specs.get(parent_id) else { + continue; + }; + + // mark parent as an async constructor + def_spec + .proc_specs + .get_mut(&parent_id.to_def_id()) + .expect("async parent must have entry in DefSpecMap") + .set_proc_kind(ProcedureKind::AsyncConstructor); + + // the spec is then essentially inherited from the parent, + // but for now, only postconditions and trusted are allowed + let mut spec = SpecGraph::new(ProcedureSpecification::empty(local_id.to_def_id())); + spec.set_proc_kind(ProcedureKind::AsyncPoll); + spec.set_kind(parent_spec.into()); + spec.set_trusted(parent_spec.trusted); + for spec_id_ref in &parent_spec.spec_id_refs { + match spec_id_ref { + SpecIdRef::AsyncPostcondition(spec_id) => { + spec.add_postcondition(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + SpecIdRef::AsyncStubPostcondition(spec_id) => { + spec.add_async_stub_postcondition(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + SpecIdRef::AsyncInvariant(spec_id) => { + spec.add_async_invariant(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + // all other spec items should stay attached to the original method + _ => {}, + } + } + // poll methods should not already be covered by the existing code, + // as they cannot be annotated by the user + let old = def_spec.proc_specs.insert(local_id.to_def_id(), spec); + if old.is_some() { + panic!("async fn poll method {local_id:?} already had a spec"); + } + } } fn determine_extern_specs(&self, def_spec: &mut typed::DefSpecificationMap) { @@ -381,6 +436,21 @@ fn get_procedure_spec_ids(def_id: DefId, attrs: &[ast::Attribute]) -> Option intravisit::Visitor<'tcx> for SpecCollector<'a, 'tcx> { let def_id = local_id.to_def_id(); let attrs = self.env.query.get_local_attributes(local_id); + // if this is a closure representing an async fn's poll, locate and store its parent + if self.env.tcx().generator_is_async(def_id) { + let hir_id = self.env.tcx().hir().local_def_id_to_hir_id(local_id); + let parent = self.env.tcx().hir() + .find_parent(hir_id) + .expect("expected async-fn-generator to have a parent"); + // we can get the parent's LocalDefId via its associated body + let (parent_def_id, _) = hir_map::associated_body(parent) + .expect("async-fn-generator parent should have a body"); + assert!(self.env.tcx().asyncness(parent_def_id.to_def_id()).is_async(), "found async generator with non-async parent"); + assert!(!self.env.tcx().generator_is_async(parent_def_id.to_def_id()), "found async generator whose parent is also async generator"); + let old = self.async_parent.insert(local_id, parent_def_id); + if old.is_some() { + panic!("parent {:?} of async-generator {:?} already has parent", old.unwrap(), local_id); + } + } + // Collect spec functions if let Some(raw_spec_id) = read_prusti_attr("spec_id", attrs) { let spec_id: SpecificationId = parse_spec_id(raw_spec_id, def_id); diff --git a/prusti-interface/src/specs/typed.rs b/prusti-interface/src/specs/typed.rs index 2041266c555..7a9f1a7344c 100644 --- a/prusti-interface/src/specs/typed.rs +++ b/prusti-interface/src/specs/typed.rs @@ -82,6 +82,9 @@ impl DefSpecificationMap { if let Some(posts) = spec.posts.extract_with_selective_replacement() { specs.extend(posts); } + if let Some(posts) = spec.async_stub_posts.extract_with_selective_replacement() { + specs.extend(posts); + } if let Some(Some(term)) = spec.terminates.extract_with_selective_replacement() { specs.push(term.to_def_id()); } @@ -202,6 +205,13 @@ impl DefSpecificationMap { } } +#[derive(Debug, Copy, Clone, TyEncodable, TyDecodable)] +pub enum ProcedureKind { + Method, + AsyncConstructor, + AsyncPoll, +} + #[derive(Debug, Clone, TyEncodable, TyDecodable)] pub struct ProcedureSpecification { // DefId of fn signature to which the spec was attached. @@ -210,10 +220,13 @@ pub struct ProcedureSpecification { pub kind: SpecificationItem, pub pres: SpecificationItem>, pub posts: SpecificationItem>, + pub async_stub_posts: SpecificationItem>, + pub async_invariants: SpecificationItem>, pub pledges: SpecificationItem>, pub trusted: SpecificationItem, pub terminates: SpecificationItem>, pub purity: SpecificationItem>, // for type-conditional spec refinements + pub proc_kind: ProcedureKind, } impl ProcedureSpecification { @@ -225,10 +238,13 @@ impl ProcedureSpecification { kind: SpecificationItem::Inherent(ProcedureSpecificationKind::Impure), pres: SpecificationItem::Empty, posts: SpecificationItem::Empty, + async_stub_posts: SpecificationItem::Empty, + async_invariants: SpecificationItem::Empty, pledges: SpecificationItem::Empty, trusted: SpecificationItem::Inherent(false), terminates: SpecificationItem::Inherent(None), purity: SpecificationItem::Inherent(None), + proc_kind: ProcedureKind::Method, } } } @@ -453,6 +469,46 @@ impl SpecGraph { } } + /// Attaches the async stub postcondition `post` to this [SpecGraph]. + /// + /// If this postcondition has a constraint it will be attached to the corresponding + /// constrained spec **and** the base spec, otherwise just to the base spec. + pub fn add_async_stub_postcondition<'tcx>(&mut self, post: LocalDefId, env: &Environment<'tcx>) { + match self.get_constraint(post, env) { + None => { + self.base_spec.async_stub_posts.push(post.to_def_id()); + self.specs_with_constraints + .values_mut() + .for_each(|s| s.async_stub_posts.push(post.to_def_id())); + } + Some(obligation) => { + self.get_constrained_spec_mut(obligation) + .async_stub_posts + .push(post.to_def_id()); + } + } + } + + /// Attaches the async invariant `inv` to this [SpecGraph]. + /// + /// If this postcondition has a constraint it will be attached to the corresponding + /// constrained spec **and** the base spec, otherwise just to the base spec. + pub fn add_async_invariant<'tcx>(&mut self, inv: LocalDefId, env: &Environment<'tcx>) { + match self.get_constraint(inv, env) { + None => { + self.base_spec.async_invariants.push(inv.to_def_id()); + self.specs_with_constraints + .values_mut() + .for_each(|s| s.async_invariants.push(inv.to_def_id())); + } + Some(obligation) => { + self.get_constrained_spec_mut(obligation) + .async_invariants + .push(inv.to_def_id()); + } + } + } + pub fn add_purity<'tcx>(&mut self, purity: LocalDefId, env: &Environment<'tcx>) { match self.get_constraint(purity, env) { None => { @@ -502,6 +558,14 @@ impl SpecGraph { .for_each(|s| s.kind.set(kind)); } + /// Sets the [ProcedureKind] for the base spec and all constrained specs. + pub fn set_proc_kind(&mut self, proc_kind: ProcedureKind) { + self.base_spec.proc_kind = proc_kind; + self.specs_with_constraints + .values_mut() + .for_each(|s| s.proc_kind = proc_kind); + } + /// Lazily gets/creates a constrained spec. /// If the constrained spec does not yet exist, the base spec serves as a template for /// the newly created constrained spec. @@ -777,11 +841,14 @@ impl Refinable for ProcedureSpecification { source: self.source, pres: self.pres.refine(replace_empty(&EMPTYL, &other.pres)), posts: self.posts.refine(replace_empty(&EMPTYL, &other.posts)), + async_stub_posts: self.async_stub_posts.refine(replace_empty(&EMPTYL, &other.async_stub_posts)), + async_invariants: self.async_invariants.refine(replace_empty(&EMPTYL, &other.async_invariants)), pledges: self.pledges.refine(replace_empty(&EMPTYP, &other.pledges)), kind: self.kind.refine(&other.kind), trusted: self.trusted.refine(&other.trusted), terminates: self.terminates.refine(&other.terminates), purity: self.purity.refine(&other.purity), + proc_kind: self.proc_kind, } } }