diff --git a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs index 9a50a2f3f8c..cc8901c43f0 100644 --- a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs +++ b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs @@ -22,6 +22,12 @@ pub fn ensures(_attr: TokenStream, tokens: TokenStream) -> TokenStream { tokens } +#[cfg(not(feature = "prusti"))] +#[proc_macro_attribute] +pub fn async_invariant(_attr: TokenStream, tokens: TokenStream) -> TokenStream { + tokens +} + #[cfg(not(feature = "prusti"))] #[proc_macro_attribute] pub fn after_expiry(_attr: TokenStream, tokens: TokenStream) -> TokenStream { @@ -130,6 +136,12 @@ pub fn body_variant(_tokens: TokenStream) -> TokenStream { TokenStream::new() } +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + tokens +} + // ---------------------- // --- PRUSTI ENABLED --- @@ -148,6 +160,12 @@ pub fn ensures(attr: TokenStream, tokens: TokenStream) -> TokenStream { rewrite_prusti_attributes(SpecAttributeKind::Ensures, attr.into(), tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro_attribute] +pub fn async_invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { + rewrite_prusti_attributes(SpecAttributeKind::AsyncInvariant, attr.into(), tokens.into()).into() +} + #[cfg(feature = "prusti")] #[proc_macro_attribute] pub fn after_expiry(attr: TokenStream, tokens: TokenStream) -> TokenStream { @@ -273,5 +291,11 @@ pub fn body_variant(tokens: TokenStream) -> TokenStream { prusti_specs::body_variant(tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + prusti_specs::suspension_point(tokens.into()).into() +} + // Ensure that you've also crated a transparent `#[cfg(not(feature = "prusti"))]` // version of your new macro above! diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index e1d1f17648a..1a580241cbf 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -66,6 +66,12 @@ pub use prusti_contracts_proc_macros::terminates; /// A macro to annotate body variant of a loop to prove termination pub use prusti_contracts_proc_macros::body_variant; +/// A macro to annotate suspension points inside async constructs +pub use prusti_contracts_proc_macros::suspension_point; + +/// A macro for writing async invariants on an async function +pub use prusti_contracts_proc_macros::async_invariant; + #[cfg(not(feature = "prusti"))] mod private { use core::marker::PhantomData; @@ -339,6 +345,10 @@ pub fn old(arg: T) -> T { arg } +pub fn suspension_point_on_exit_marker(_label: u32, _closures: T) {} + +pub fn suspension_point_on_entry_marker(_label: u32, _closures: T) {} + /// Universal quantifier. /// /// This is a Prusti-internal representation of the `forall` syntax. 
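For context, a minimal usage sketch of the two new user-facing macros exported above. The function, its contract bodies, and the awaited future are illustrative assumptions; only the attribute names (`label`, `on_exit`, `on_entry`), the positive-integer label, and the requirement that `suspension_point!` wraps a single await-expression follow from the implementation in the files below. With the `prusti` feature disabled, all of these macros are pass-through.

use prusti_contracts::*;

// Hypothetical example: an async fn carrying an async invariant and a
// postcondition, awaiting some future at an annotated suspension point.
#[requires(x > 0)]
#[ensures(result >= x)]          // on an async fn this is later treated as an async postcondition
#[async_invariant(x > 0)]        // per the checks below, only accepted on async fns
async fn add_async(x: i32, fut: impl std::future::Future<Output = i32>) -> i32 {
    // The macro expands to calls of the marker functions
    // `suspension_point_on_exit_marker` / `suspension_point_on_entry_marker`
    // around the await, carrying the label and the condition closures.
    let y = suspension_point!(
        #[label(1)]
        #[on_exit(x > 0)]
        #[on_entry(x > 0)]
        fut.await
    );
    x + y
}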
diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 3403260c345..76074b2e238 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -70,6 +70,7 @@ fn extract_prusti_attributes( let tokens = match attr_kind { SpecAttributeKind::Requires | SpecAttributeKind::Ensures + | SpecAttributeKind::AsyncInvariant | SpecAttributeKind::AfterExpiry | SpecAttributeKind::AssertOnExpiry | SpecAttributeKind::RefineSpec => { @@ -81,6 +82,10 @@ fn extract_prusti_attributes( assert!(iter.next().is_none(), "Unexpected shape of an attribute."); group.stream() } + // these cannot appear here, as postconditions are only marked as async + // at a later stage + SpecAttributeKind::AsyncEnsures => + unreachable!("SpecAttributeKind::AsyncEnsures should not appear at this stage"), // Nothing to do for attributes without arguments. SpecAttributeKind::Pure | SpecAttributeKind::Terminates @@ -125,6 +130,39 @@ pub fn rewrite_prusti_attributes( // Collect the remaining Prusti attributes, removing them from `item`. prusti_attributes.extend(extract_prusti_attributes(&mut item)); + // in the case of an async fn, mark the postconditions as such, + // since they are handled differently + if let untyped::AnyFnItem::Fn(ref fn_item) = &item { + if fn_item.sig.asyncness.is_some() { + prusti_attributes = prusti_attributes + .into_iter() + .map(|(attr, tt)| match attr { + SpecAttributeKind::Ensures => (SpecAttributeKind::AsyncEnsures, tt), + _ => (attr, tt) + } + ) + .collect(); + } + } + + // make sure async invariants can only be attached to async fn's + if matches!(outer_attr_kind, SpecAttributeKind::AsyncInvariant) { + // TODO: should the error be reported to the function's or the attribute's span? + let untyped::AnyFnItem::Fn(ref fn_item) = &item else { + return syn::Error::new( + item.span(), + "async_invariant attached to non-function item" + ).to_compile_error(); + }; + if fn_item.sig.asyncness.is_none() { + return syn::Error::new( + item.span(), + "async_invariant attached to non-async function" + ) + .to_compile_error(); + } + } + // make sure to also update the check in the predicate! handling method if prusti_attributes .iter() @@ -162,6 +200,8 @@ fn generate_spec_and_assertions( let rewriting_result = match attr_kind { SpecAttributeKind::Requires => generate_for_requires(attr_tokens, item), SpecAttributeKind::Ensures => generate_for_ensures(attr_tokens, item), + SpecAttributeKind::AsyncEnsures => generate_for_async_ensures(attr_tokens, item), + SpecAttributeKind::AsyncInvariant => generate_for_async_invariant(attr_tokens, item), SpecAttributeKind::AfterExpiry => generate_for_after_expiry(attr_tokens, item), SpecAttributeKind::AssertOnExpiry => generate_for_assert_on_expiry(attr_tokens, item), SpecAttributeKind::Pure => generate_for_pure(attr_tokens, item), @@ -215,6 +255,55 @@ fn generate_for_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> Generat )) } +/// Generate spec items and attributes to typecheck and later retrieve "ensures" annotations, +/// but for async fn, as their postconditions will need to be moved to their generator method +/// instead of the future-constructor. 
+fn generate_for_async_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { + let mut rewriter = rewriter::AstRewriter::new(); + // parse the postcondition + let expr = parse_prusti(attr)?; + // generate a spec item for the postcondition itself, + // which will be attached to the future's implementation + let spec_id = rewriter.generate_spec_id(); + let spec_id_str = spec_id.to_string(); + let spec_item = + rewriter.generate_spec_item_fn(rewriter::SpecItemType::Postcondition, spec_id, expr.clone(), item)?; + + // and generate one wrapped in a `Poll` for the stub + let stub_spec_id = rewriter.generate_spec_id(); + let stub_spec_id_str = stub_spec_id.to_string(); + let stub_spec_item = rewriter.generate_async_ensures_item_fn(stub_spec_id, expr, item)?; + + Ok(( + vec![spec_item, stub_spec_item], + vec![ + parse_quote_spanned! {item.span()=> + #[prusti::async_post_spec_id_ref = #spec_id_str] + }, + parse_quote_spanned! {item.span()=> + #[prusti::async_stub_post_spec_id_ref = #stub_spec_id_str] + }, + ] + )) +} + + +/// Generate spec items and attributes to typecheck and later retrieve invariant annotations +/// on async functions. +fn generate_for_async_invariant(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { + let mut rewriter = rewriter::AstRewriter::new(); + let spec_id = rewriter.generate_spec_id(); + let spec_id_str = spec_id.to_string(); + let spec_item = + rewriter.process_assertion(rewriter::SpecItemType::Precondition, spec_id, attr, item)?; + Ok(( + vec![spec_item], + vec![parse_quote_spanned! {item.span()=> + #[prusti::async_inv_spec_id_ref = #spec_id_str] + }] + )) +} + /// Generate spec items and attributes to typecheck and later retrieve "after_expiry" annotations. fn generate_for_after_expiry(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult { let mut rewriter = rewriter::AstRewriter::new(); @@ -433,6 +522,146 @@ pub fn prusti_refutation(tokens: TokenStream) -> TokenStream { generate_expression_closure(&AstRewriter::process_prusti_refutation, tokens) } + +pub fn suspension_point(tokens: TokenStream) -> TokenStream { + // parse the expression inside the suspension point, + // making sure it is just an await (possibly with attributes) + let expr: syn::Expr = handle_result!(syn::parse2(tokens)); + let expr_span = expr.span(); + let syn::Expr::Await(mut await_expr) = expr else { + return syn::Error::new( + expr.span(), + "suspension-point must contain a single await-expression" + ).to_compile_error(); + }; + let future = &await_expr.base; + + let mut label: Option = None; + let mut on_exit: Vec = Vec::new(); + let mut on_entry: Vec = Vec::new(); + + // extract label, on-exit, and on-entry conditions from the attributes + for attr in await_expr.attrs { + let attr_span = attr.span(); + if !matches!(attr.style, syn::AttrStyle::Outer) { + return syn::Error::new( + attr_span, + "only outer attributes allowed in suspension-point", + ).to_compile_error(); + } + if !(attr.path.segments.len() == 1 + || (attr.path.segments.len() == 2 && attr.path.segments[0].ident == "prusti_contracts")) + { + return syn::Error::new( + attr_span, + "invalid attribute in suspension-point", + ).to_compile_error(); + } + let name = attr.path.segments[attr.path.segments.len() - 1] + .ident + .to_string(); + // labels + if name == "label" { + let [proc_macro2::TokenTree::Group(group)] = + &attr.tokens.into_iter().collect::>()[..] 
+ else { + return syn::Error::new( + attr_span, + "expected group with a single positive integer as label", + ).to_compile_error(); + }; + let [proc_macro2::TokenTree::Literal(lit)] = + &group.stream().into_iter().collect::>()[..] + else { + return syn::Error::new( + attr_span, + "expected single positive integer as label", + ).to_compile_error(); + }; + let lbl_num: u32 = lit + .to_string() + .parse() + .expect("expected single positive integer as label"); + if lbl_num == 0 { + return syn::Error::new( + attr_span, + "suspension-point label must be a positive integer", + ).to_compile_error(); + } + if label.replace(lbl_num).is_some() { + return syn::Error::new( + attr_span, + "multiple labels provided for suspension-point", + ).to_compile_error(); + } + // on-exit conditions + } else if name == "on_exit" { + on_exit.push(attr.tokens); + // on-entry conditions + } else if name == "on_entry" { + on_entry.push(attr.tokens); + // all other attributes are not permitted + } else { + // TODO: more precise error message with span? + panic!("invalid attribute in suspension-point"); + } + } + + let label = label.unwrap_or_else(|| panic!("suspension-point must have a label")); + + // generate spec-items for each condition + let create_spec_item = |tokens: TokenStream| -> syn::Result { + let expr = parse_prusti(tokens)?; + Ok(parse_quote_spanned! {expr_span=> + || -> bool { + let val: bool = #expr; + val + } + }) + }; + + let on_exit: syn::Result> = on_exit + .into_iter() + .map(create_spec_item) + .collect(); + let on_exit = handle_result!(on_exit); + + let on_entry: syn::Result> = on_entry + .into_iter() + .map(create_spec_item) + .collect(); + let on_entry = handle_result!(on_entry); + + // return just the await-expression + let on_exit_closures = match on_exit.len() { + 1 => { + let on_exit = &on_exit[0]; + quote_spanned! { expr_span => (#on_exit,) } + } + _ => quote_spanned! { expr_span => (#(#on_exit),*) }, + }; + let on_entry_closures = match on_entry.len() { + 1 => { + let on_entry = &on_entry[0]; + quote_spanned! { expr_span => (#on_entry,) } + } + _ => quote_spanned! { expr_span => (#(#on_entry),*) }, + }; + + await_expr.attrs = Vec::new(); + quote_spanned! 
{ expr_span => + { + let fut = #future; + #[allow(unused_parens)] + ::prusti_contracts::suspension_point_on_exit_marker(#label, #on_exit_closures); + let res = fut.await; + #[allow(unused_parens)] + ::prusti_contracts::suspension_point_on_entry_marker(#label, #on_entry_closures); + res + } + } +} + /// Generates the TokenStream encoding an expression using prusti syntax /// Used for body invariants, assertions, and assumptions fn generate_expression_closure( @@ -855,6 +1084,8 @@ fn extract_prusti_attributes_for_types( let tokens = match attr_kind { SpecAttributeKind::Requires => unreachable!("requires on type"), SpecAttributeKind::Ensures => unreachable!("ensures on type"), + SpecAttributeKind::AsyncEnsures => unreachable!("ensures on type"), + SpecAttributeKind::AsyncInvariant => unreachable!("async-invariant on type"), SpecAttributeKind::AfterExpiry => unreachable!("after_expiry on type"), SpecAttributeKind::AssertOnExpiry => unreachable!("assert_on_expiry on type"), SpecAttributeKind::RefineSpec => unreachable!("refine_spec on type"), @@ -900,6 +1131,8 @@ fn generate_spec_and_assertions_for_types( let rewriting_result = match attr_kind { SpecAttributeKind::Requires => unreachable!(), SpecAttributeKind::Ensures => unreachable!(), + SpecAttributeKind::AsyncEnsures => unreachable!(), + SpecAttributeKind::AsyncInvariant => unreachable!(), SpecAttributeKind::AfterExpiry => unreachable!(), SpecAttributeKind::AssertOnExpiry => unreachable!(), SpecAttributeKind::Pure => unreachable!(), diff --git a/prusti-contracts/prusti-specs/src/rewriter.rs b/prusti-contracts/prusti-specs/src/rewriter.rs index 9595485ab59..f84a54315cb 100644 --- a/prusti-contracts/prusti-specs/src/rewriter.rs +++ b/prusti-contracts/prusti-specs/src/rewriter.rs @@ -145,6 +145,62 @@ impl AstRewriter { Ok(syn::Item::Fn(spec_item)) } + /// Wrap an expression for an async postcondition in `Poll` and create the appropriate function + /// Can only be used for async postconditions and must be called after `generate_spec_item_fn` + /// to create the postcondition for the poll implementation + pub fn generate_async_ensures_item_fn( + &mut self, + spec_id: SpecificationId, + expr: TokenStream, + item: &T, + ) -> syn::Result { + // NOTE: we don't need to check for `result` as a parameter again, as this will only + // be called after `generate_spec_item_fn` for the same expression + let item_span = expr.span(); + let item_name = syn::Ident::new( + &format!("prusti_{}_item_{}_{}", SpecItemType::Postcondition, item.sig().ident, spec_id), + item_span, + ); + let spec_id_str = spec_id.to_string(); + + let mut spec_item: syn::ItemFn = parse_quote_spanned! 
{item_span=> + #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case)] + #[prusti::spec_only] + #[prusti::spec_id = #spec_id_str] + fn #item_name() -> bool { + let val: bool = if let ::std::task::Poll::Ready(result) = result { + #expr + } else { + true + }; + val + } + }; + spec_item.sig.generics = item.sig().generics.clone(); + spec_item.sig.inputs = item.sig().inputs.clone(); + + // note that the result-arg's type should not be the async function's return type `T` + // but `Poll` + let result_arg: syn::FnArg = { + // analogous to `generate_result_arg` + let item_span = item.span(); + let output_ty = match &item.sig().output { + syn::ReturnType::Default => parse_quote_spanned!(item_span=> ()), + syn::ReturnType::Type(_, ty) => ty.clone(), + }; + let output_ty = parse_quote_spanned!(item_span=> ::std::task::Poll<#output_ty>); + let fn_arg = syn::FnArg::Typed(syn::PatType { + attrs: Vec::new(), + pat: Box::new(parse_quote_spanned!(item_span=> result)), + colon_token: syn::Token![:](item.sig().output.span()), + ty: output_ty, + }); + fn_arg + }; + spec_item.sig.inputs.push(result_arg); + Ok(syn::Item::Fn(spec_item)) + } + /// Parse an assertion into a Rust expression pub fn process_assertion( &mut self, diff --git a/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs b/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs index f286cbd317b..7f730496336 100644 --- a/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs +++ b/prusti-contracts/prusti-specs/src/spec_attribute_kind.rs @@ -18,6 +18,8 @@ pub enum SpecAttributeKind { Terminates = 10, PrintCounterexample = 11, Verified = 12, + AsyncEnsures = 13, + AsyncInvariant = 14, } impl TryFrom for SpecAttributeKind { @@ -37,6 +39,7 @@ impl TryFrom for SpecAttributeKind { "model" => Ok(SpecAttributeKind::Model), "print_counterexample" => Ok(SpecAttributeKind::PrintCounterexample), "verified" => Ok(SpecAttributeKind::Verified), + "async_invariant" => Ok(SpecAttributeKind::AsyncInvariant), _ => Err(name), } } diff --git a/prusti-contracts/prusti-specs/src/specifications/common.rs b/prusti-contracts/prusti-specs/src/specifications/common.rs index dca66437e79..4baaab873b3 100644 --- a/prusti-contracts/prusti-specs/src/specifications/common.rs +++ b/prusti-contracts/prusti-specs/src/specifications/common.rs @@ -56,6 +56,9 @@ pub struct SpecificationId(Uuid); pub enum SpecIdRef { Precondition(SpecificationId), Postcondition(SpecificationId), + AsyncPostcondition(SpecificationId), + AsyncStubPostcondition(SpecificationId), + AsyncInvariant(SpecificationId), Purity(SpecificationId), Pledge { lhs: Option, diff --git a/prusti-encoder/src/encoder_traits/impure_function_enc.rs b/prusti-encoder/src/encoder_traits/impure_function_enc.rs index 0eff53dce5c..c4428323537 100644 --- a/prusti-encoder/src/encoder_traits/impure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/impure_function_enc.rs @@ -1,9 +1,16 @@ -use prusti_rustc_interface::middle::mir; +use prusti_rustc_interface::middle::{ty, mir}; +use prusti_interface::specs::typed::ProcedureKind; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; -use vir::{MethodIdent, UnknownArity, ViperIdent}; +use vir::{self, MethodIdent, UnknownArity, ViperIdent}; use crate::encoders::{ - lifted::func_def_ty_params::LiftedTyParamsEnc, ImpureEncVisitor, MirImpureEnc, MirLocalDefEnc, MirSpecEnc + lifted::func_def_ty_params::LiftedTyParamsEnc, + ImpureEncVisitor, + MirImpureEnc, + MirLocalDefEnc, + MirSpecEnc, + rust_ty_predicates::RustTyPredicatesEnc, 
+ predicate, }; use super::function_enc::FunctionEnc; @@ -125,6 +132,7 @@ where vcx.alloc(vir::CfgBlockLabelData::Start), vcx.alloc_slice(&start_stmts), vcx.mk_goto_stmt(vcx.alloc(vir::CfgBlockLabelData::BasicBlock(0))), + &[], )); let mut visitor = ImpureEncVisitor { @@ -132,6 +140,7 @@ where vcx, deps, def_id, + substs, local_decls: &body.local_decls, //ssa_analysis, fpcs_analysis, @@ -151,16 +160,24 @@ where vcx.alloc(vir::CfgBlockLabelData::End), &[], vcx.alloc(vir::TerminatorStmtData::Exit), + &[], )); Some(vcx.alloc_slice(&visitor.encoded_blocks)) } else { None }; + let proc_kind = crate::encoders::with_proc_spec( + def_id, + |def_spec| def_spec.proc_kind + ) + .unwrap_or(ProcedureKind::Method); + let spec = deps - .require_local::((def_id, substs, None, false))?; - let (spec_pres, spec_posts) = (spec.pres, spec.posts); + .require_local::((def_id, substs, None, false, false))?; + let (spec_pres, spec_posts, spec_async_invs) = (spec.pres, spec.posts, spec.async_invariants); + // TODO: fix capacity? let mut pres = Vec::with_capacity(arg_count - 1); let mut args = Vec::with_capacity(arg_count + substs.len()); for arg_idx in 0..arg_count { @@ -173,10 +190,70 @@ where args.extend(param_ty_decls.iter()); pres.extend(spec_pres); + // in the case of an async body, we additionally require that the ghost fields + // capturing the initial upvar state are equal to the upvar fields + // as well as that all invariants hold initially + if matches!(proc_kind, ProcedureKind::AsyncPoll) { + let gen_ty = vcx.tcx().type_of(def_id).skip_binder(); + let fields = { + let gen_ty = deps.require_ref::(gen_ty)?; + gen_ty.generic_predicate.expect_structlike().snap_data.field_access + }; + let upvar_tys = { + let ty::TyKind::Generator(_, args, _) = gen_ty.kind() else { + panic!("expected generator TyKind to be Generator"); + }; + args.as_generator().upvar_tys() + }; + let n_upvars = upvar_tys.len(); + assert_eq!(fields.len(), 2 * upvar_tys.len() + 1); + let gen_snap = local_defs.locals[1_u32.into()].impure_snap; + for i in 0 .. upvar_tys.len() { + let field = fields[i].read.apply(vcx, [gen_snap]); + let ghost_field = fields[n_upvars + i].read.apply(vcx, [gen_snap]); + pres.push(vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, field, ghost_field)); + } + + for inv in spec_async_invs { + pres.push(inv); + } + } + let mut posts = Vec::with_capacity(spec_posts.len() + 1); posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); + // in the case of a future constructor, we also ensure that the generator's upvar + // fields, ghost fields, and state field are set correctly + // NOTE: this detection mechanism is not always correct, specifically, it does not + // correctly mark the future constructors of async fn's without specifications as such, + // but this should not matter, as the caller cannot obtain guarantees about this + // future type anyways + if matches!(proc_kind, ProcedureKind::AsyncConstructor) { + let gen_domain = local_defs.locals[0_u32.into()].ty; + let fields = gen_domain.expect_structlike().snap_data.field_access; + let n_upvars = local_defs.arg_count; + assert_eq!(fields.len(), 2 * n_upvars + 1); + let gen_snap = local_defs.locals[0_u32.into()].impure_snap; + for i in 0 .. 
n_upvars { + let arg = vcx.mk_old_expr(local_defs.locals[(i + 1).into()].impure_snap); + let upvar_field = fields[i].read.apply(vcx, [gen_snap]); + let ghost_field = fields[n_upvars + i].read.apply(vcx, [gen_snap]); + posts.push(vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, upvar_field, arg)); + posts.push(vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, ghost_field, arg)); + } + let state_field = fields[2 * n_upvars].read.apply(vcx, [gen_snap]); + let zero = deps + .require_ref::( + vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::U32)) + )? + .generic_predicate + .expect_prim() + .prim_to_snap + .apply(vcx, [vcx.mk_uint::<0>()]); + posts.push(vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, state_field, zero)); + } + Ok(ImpureFunctionEncOutput { method: vcx.mk_method( method_ref, diff --git a/prusti-encoder/src/encoder_traits/pure_function_enc.rs b/prusti-encoder/src/encoder_traits/pure_function_enc.rs index 6a7269653a7..07efad43a9a 100644 --- a/prusti-encoder/src/encoder_traits/pure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/pure_function_enc.rs @@ -109,7 +109,7 @@ where deps.emit_output_ref(task_key, MirFunctionEncOutputRef { function_ref }); let spec = deps - .require_local::((def_id, substs, None, true)) + .require_local::((def_id, substs, None, true, false)) .unwrap(); let mut func_args = ty_arg_decls diff --git a/prusti-encoder/src/encoders/async/mod.rs b/prusti-encoder/src/encoders/async/mod.rs new file mode 100644 index 00000000000..776e85af3b3 --- /dev/null +++ b/prusti-encoder/src/encoders/async/mod.rs @@ -0,0 +1,5 @@ +pub mod poll_stub; +pub mod suspension_points; + +pub use poll_stub::AsyncPollStubEnc; +pub use suspension_points::{SuspensionPointAnalysis, SuspensionPoint}; diff --git a/prusti-encoder/src/encoders/async/poll_stub.rs b/prusti-encoder/src/encoders/async/poll_stub.rs new file mode 100644 index 00000000000..8d0406dfd0b --- /dev/null +++ b/prusti-encoder/src/encoders/async/poll_stub.rs @@ -0,0 +1,408 @@ +use prusti_rustc_interface::{ + hir, + middle::{ + mir, + ty::{self, GenericArgs}, + }, + span::def_id::DefId, +}; +use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; +use vir::{Method, MethodIdent, UnknownArity}; + +use super::suspension_points::SuspensionPointAnalysis; +use crate::encoders::{ + lifted::{ + casters::CastTypePure, func_def_ty_params::LiftedTyParamsEnc, + rust_ty_cast::RustTyCastersEnc, + }, + r#type::rust_ty_predicates::RustTyPredicatesEnc, + MirLocalDefEnc, MirSpecEnc, +}; + +/// Encodes a poll call stub for an async item +pub struct AsyncPollStubEnc; + +#[derive(Clone, Debug)] +pub struct AsyncPollStubEncOutputRef<'vir> { + pub method_ref: MethodIdent<'vir, UnknownArity<'vir>>, + pub return_ty: ty::Ty<'vir>, + pub arg_tys: Vec>, +} +impl<'vir> task_encoder::OutputRefAny for AsyncPollStubEncOutputRef<'vir> {} + +#[derive(Clone, Debug)] +pub struct AsyncPollStubEncOutput<'vir> { + pub method: Method<'vir>, +} + +#[derive(Clone, Debug)] +pub struct AsyncPollStubEncError; + +impl TaskEncoder for AsyncPollStubEnc { + task_encoder::encoder_cache!(AsyncPollStubEnc); + + type TaskDescription<'vir> = DefId; + + type OutputRef<'vir> = AsyncPollStubEncOutputRef<'vir>; + type OutputFullLocal<'vir> = AsyncPollStubEncOutput<'vir>; + + type EncodingError = AsyncPollStubEncError; + + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { + *task + } + + fn do_encode_full<'vir>( + task: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + // TODO: for 
now this is a generic encoding, check whether and how this needs to be adapted + // for a monomorphic encoding + let def_id = *task; + vir::with_vcx(|vcx| { + // get generator type + let gen_ty = vcx.tcx().type_of(def_id).skip_binder(); + let ty::TyKind::Generator(_def_id, gen_args, _) = gen_ty.kind() else { + panic!("expected type of async fn to be Generator"); + }; + assert_eq!(def_id, *_def_id); + // construct the receiver type (std::pin::Pin<&mut Self>) + let ref_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(vcx.tcx(), ty::RegionKind::ReErased), + gen_ty, + mir::Mutability::Mut, + )); + let recv_ty = { + let pin_def_id = vcx.tcx().require_lang_item(hir::LangItem::Pin, None); + vcx.tcx().mk_ty_from_kind(ty::TyKind::Adt( + vcx.tcx().adt_def(pin_def_id), + vcx.tcx().mk_args(&[ref_ty.into()]), + )) + }; + let enc_recv_ty = deps.require_ref::(recv_ty)?; + // construct the second argument type (std::task::Context) + let cx_ty = { + let cx_def_id = vcx.tcx().require_lang_item(hir::LangItem::Context, None); + let cx_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Adt( + vcx.tcx().adt_def(cx_def_id), + ty::List::empty(), + )); + vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(vcx.tcx(), ty::RegionKind::ReErased), + cx_ty, + mir::Mutability::Mut, + )) + }; + let enc_cx_ty = deps.require_ref::(cx_ty)?; + // construct the return type (std::poll::Poll) + let ret_ty = { + let poll_def_id = vcx.tcx().require_lang_item(hir::LangItem::Poll, None); + vcx.tcx().mk_ty_from_kind(ty::TyKind::Adt( + vcx.tcx().adt_def(poll_def_id), + vcx.tcx() + .mk_args(&[gen_args.as_generator().return_ty().into()]), + )) + }; + let enc_ret_ty = deps.require_ref::(ret_ty)?; + + // construct the stub's signature + let substs = GenericArgs::identity_for_item(vcx.tcx(), def_id); + let local_defs = deps.require_local::((def_id, substs, None))?; + let method_name = + vir::vir_format_identifier!(vcx, "m_poll_{}", vcx.tcx().def_path_str(def_id)); + let arg_count = local_defs.arg_count + 1; + assert_eq!(arg_count, 3); + let param_ty_decls = deps + .require_local::(substs)? 
+ .iter() + .map(|g| g.decl()) + .collect::>(); + let method_ref = { + let mut args = vec![&vir::TypeData::Ref; arg_count]; + args.extend(param_ty_decls.iter().map(|decl| decl.ty)); + let args = UnknownArity::new(vcx.alloc_slice(&args)); + MethodIdent::new(method_name, args) + }; + deps.emit_output_ref( + *task, + AsyncPollStubEncOutputRef { + method_ref, + return_ty: ret_ty, + arg_tys: vec![recv_ty, cx_ty], + }, + )?; + + let upvar_tys = gen_args.as_generator().upvar_tys().to_vec(); + let n_upvars = upvar_tys.len(); + + // read the generator snapshot from the generator argument + let gen_snap = { + let pin_snap = + enc_recv_ty.ref_to_snap(vcx, local_defs.locals[1_u32.into()].local_ex); + let ref_snap = { + let fields = enc_recv_ty + .generic_predicate + .expect_structlike() + .snap_data + .field_access; + assert_eq!(fields.len(), 1, "expected pin domain to have 1 field"); + let ref_snap = fields[0].read.apply(vcx, [pin_snap]); + let caster = deps.require_local::>(ref_ty)?; + caster.cast_to_concrete_if_possible(vcx, ref_snap) + }; + let enc_ref_ty = deps.require_ref::(ref_ty)?; + let fields = enc_ref_ty + .generic_predicate + .expect_ref() + .snap_data + .field_access; + assert_eq!(fields.len(), 1, "expected ref domain to have 1 field"); + let gen_snap = fields[0].read.apply(vcx, [ref_snap]); + let caster = deps.require_local::>(gen_ty)?; + caster.cast_to_concrete_if_possible(vcx, gen_snap) + }; + // and read the generator's fields + let gen_fields = { + let gen_ty = deps.require_ref::(gen_ty)?; + let gen_domain_data = gen_ty.generic_predicate.expect_structlike(); + let fields = gen_domain_data.snap_data.field_access; + fields + .iter() + .map(|field| field.read.apply(vcx, [gen_snap])) + .collect::>() + }; + assert_eq!(gen_fields.len(), 2 * n_upvars + 1); + + // viper function to take snapshot of u32, will be used for state values + let u32_snap_fn = deps + .require_ref::( + vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::U32)), + )? + .generic_predicate + .expect_prim() + .prim_to_snap; + let mk_u32_snap = |x: u32| { + let vir_cnst = vcx.mk_const_expr(vir::ConstData::Int(x.into())); + u32_snap_fn.apply(vcx, [vir_cnst]) + }; + + let suspension_points = deps + .require_local::(def_id) + .unwrap() + .0; + + // encode the stub's specification + let spec = deps.require_local::((def_id, substs, None, false, true))?; + + // encode the method's on_exit/on_entry conditions as post-/preconditions + let (on_exit_posts, on_entry_pres) = { + let state_field = gen_fields[2 * n_upvars]; + let to_bool = deps + .require_ref::(vcx.tcx().types.bool) + .unwrap() + .generic_predicate + .expect_prim() + .snap_to_prim; + // create reference snapshots to the generator's fields + let field_refs = { + let fields = &gen_fields[..n_upvars]; + upvar_tys + .iter() + .zip(fields) + .map(|(ty, field)| { + let caster = deps + .require_local::>(*ty) + .unwrap(); + let ref_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(vcx.tcx(), ty::RegionKind::ReErased), + *ty, + mir::Mutability::Not, + )); + let ref_cons = deps + .require_ref::(ref_ty) + .unwrap() + .generic_predicate + .expect_ref() + .snap_data + .field_snaps_to_snap; + ref_cons.apply(vcx, &[caster.cast_to_generic_if_necessary(vcx, field)]) + }) + .collect::>() + }; + + let mut encode_condition = |cond_def_id: &DefId| { + // Note that all of these conditions are obtained by encoding their closure's body, + // which takes a reference to the closure as the only parameter and captured values + // (i.e. 
the generator fields referred to inside the condition) are accessed as + // fields of that closure. + // Hence, we first need to construct a reference to such a closure in order to + // reify the encoded expression + let closure_ty = vcx.tcx().type_of(*cond_def_id).skip_binder(); + let closure_snap = { + let closure_ty = + deps.require_ref::(closure_ty).unwrap(); + let closure_cons = closure_ty + .generic_predicate + .expect_structlike() + .snap_data + .field_snaps_to_snap; + closure_cons.apply(vcx, &field_refs) + }; + let closure_ref = { + let caster = deps + .require_local::>(closure_ty) + .unwrap(); + let ref_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(vcx.tcx(), ty::RegionKind::ReErased), + closure_ty, + mir::Mutability::Not, + )); + let ref_ty = deps.require_ref::(ref_ty).unwrap(); + let ref_cons = ref_ty + .generic_predicate + .expect_ref() + .snap_data + .field_snaps_to_snap; + ref_cons.apply( + vcx, + &[caster.cast_to_generic_if_necessary(vcx, closure_snap)], + ) + }; + // given that reference, we can now encode the closure body and reify using + // that reference + let expr = deps + .require_local::( + crate::encoders::MirPureEncTask { + encoding_depth: 0, + kind: crate::encoders::PureKind::Closure, + parent_def_id: *cond_def_id, + param_env: vcx.tcx().param_env(cond_def_id), + substs, + // TODO: should this be `def_id` or `cond_def_id` + caller_def_id: Some(*cond_def_id), + }, + ) + .unwrap() + .expr; + use vir::Reify; + let expr = expr.reify(vcx, (*cond_def_id, vcx.alloc_slice(&[closure_ref]))); + to_bool.apply(vcx, [expr]) + }; + + let mut on_exit_posts: Vec<&'vir vir::ExprGenData<'_, !, !>> = Vec::new(); + let mut on_entry_pres: Vec<&'vir vir::ExprGenData<'_, !, !>> = Vec::new(); + + for sp in &suspension_points { + let is_in_state = vcx.mk_bin_op_expr( + vir::BinOpKind::CmpEq, + state_field, + mk_u32_snap(sp.label), + ); + let on_exits = sp.on_exit_closures.iter().map(|cond_def_id| { + vcx.mk_bin_op_expr( + vir::BinOpKind::Implies, + is_in_state, + encode_condition(cond_def_id), + ) + }); + on_exit_posts.extend(on_exits); + let on_entries = sp.on_entry_closures.iter().map(|cond_def_id| { + vcx.mk_bin_op_expr( + vir::BinOpKind::Implies, + is_in_state, + encode_condition(cond_def_id), + ) + }); + on_entry_pres.extend(on_entries); + } + + (on_exit_posts, on_entry_pres) + }; + + // add arguments and preconditions about their types + // note that the signature is (self: std::pin::Pin<&mut Self>, cx: &mut Context) + // and not the signature of the generator + // TODO: fix capacity here + let mut pres = Vec::with_capacity(arg_count + spec.async_invariants.len() - 1); + let mut args = Vec::with_capacity( + arg_count + substs.len() + n_upvars + spec.async_invariants.len() + 1, + ); + for arg_idx in 0..arg_count { + let name = local_defs.locals[arg_idx.into()].local.name; + args.push(vir::vir_local_decl! 
{ vcx; [name] : Ref }); + } + pres.push(enc_recv_ty.ref_to_pred(vcx, local_defs.locals[1_u32.into()].local_ex, None)); + pres.push(enc_cx_ty.ref_to_pred(vcx, local_defs.locals[2_u32.into()].local_ex, None)); + // add type parameters (and their typing preconditions) + args.extend(param_ty_decls.iter()); + pres.extend(spec.pres); + + // constrain possible state values to the suspension-point labels as well as 0 + let state_value_constraint = { + let state_field = gen_fields[2 * n_upvars]; + let state_values = suspension_points + .iter() + .map(|sp| vcx.mk_const_expr(vir::ConstData::Int(sp.label.into()))) + .chain(std::iter::once(vcx.mk_uint::<0>())) + .map(|lbl| u32_snap_fn.apply(vcx, [lbl])); + state_values + .map(|v| vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, state_field, v)) + .reduce(|l, r| vcx.mk_bin_op_expr(vir::BinOpKind::Or, l, r)) + .unwrap() + }; + pres.push(state_value_constraint); + + // add preconditions corresponding to on_entry conditions + pres.extend(on_entry_pres); + + // add invariants as preconditions + for inv in &spec.async_invariants { + pres.push(*inv); + } + + // add postconditions for the return type as well as user-annotated ones + // we also add a postcondition on the generator type after the call and the requirement + // about its state values + // TODO: fix capacities here + let mut posts = Vec::with_capacity(spec.async_stub_posts.len() + 2); + posts.push(enc_ret_ty.ref_to_pred( + vcx, + local_defs.locals[mir::RETURN_PLACE].local_ex, + None, + )); + posts.push(enc_recv_ty.ref_to_pred( + vcx, + local_defs.locals[1_u32.into()].local_ex, + None, + )); + posts.push(state_value_constraint); + posts.extend(spec.async_stub_posts); + + // add postconditions corresponding to on_exit conditions + posts.extend(on_exit_posts); + + // add postconditions that polling the future does not change the ghost fields + // so they still capture the initial state of the async fn's arguments + // read the generator from the pinned mutable ref + for ghost_field in gen_fields[n_upvars..2 * n_upvars].iter() { + let old_ghost_field = vcx.mk_old_expr(ghost_field); + posts.push(vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, ghost_field, old_ghost_field)); + } + + // add invariants as postconditions + for inv in &spec.async_invariants { + posts.push(*inv); + } + + let method = vcx.mk_method( + method_ref, + vcx.alloc_slice(&args), + &[], + vcx.alloc_slice(&pres), + vcx.alloc_slice(&posts), + None, + ); + Ok((AsyncPollStubEncOutput { method }, ())) + }) + } +} diff --git a/prusti-encoder/src/encoders/async/suspension_points.rs b/prusti-encoder/src/encoders/async/suspension_points.rs new file mode 100644 index 00000000000..738f94652be --- /dev/null +++ b/prusti-encoder/src/encoders/async/suspension_points.rs @@ -0,0 +1,582 @@ +use std::collections::{HashMap, HashSet}; + +use prusti_interface::{environment::body::MirBody, specs::typed::ProcedureKind}; +use prusti_rustc_interface::{ + middle::{ + mir::{self, visit::Visitor}, + ty, + }, + span::def_id::DefId, +}; +use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; +use vir::VirCtxt; + +/// Analyzes a method's CFG to determine its suspension point +pub struct SuspensionPointAnalysis; + +#[derive(Clone, Debug)] +pub struct SuspensionPoint { + pub label: u32, + // the first BB of the busy loop, which is where invariants should be put + pub loop_head: mir::BasicBlock, + // the last BB of the busy loop containing the yield-terminator + pub yield_bb: mir::BasicBlock, + // DefId's of the closures containing the on_exit/on_entry 
conditions (if any) + pub on_exit_closures: Vec, + pub on_entry_closures: Vec, + // the local containing the future as well as its pinned reference inside the busy loop + pub future_local: mir::Local, + pub pin_local: mir::Local, + // the original local containing the future at the time of its construction + // Note that this may not be set, e.g. because the future is lacking a specification (in which + // case the constructor call is not being detected as such) or because this analysis was unable + // to track the future to the suspension-point. + pub original_future_local: Option, +} + +#[derive(Clone, Debug)] +pub struct SuspensionPointAnalysisOutput(pub Vec); + +impl task_encoder::OutputRefAny for SuspensionPointAnalysisOutput {} + +#[derive(Debug, Clone)] +pub struct SuspensionPointAnalysisError; + +impl TaskEncoder for SuspensionPointAnalysis { + task_encoder::encoder_cache!(SuspensionPointAnalysis); + + type TaskDescription<'vir> = DefId; + + type OutputFullLocal<'vir> = SuspensionPointAnalysisOutput; + + type EncodingError = SuspensionPointAnalysisError; + + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { + *task + } + + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + let def_id = *task_key; + deps.emit_output_ref(def_id, ())?; + vir::with_vcx(|vcx| { + let local_def_id = def_id + .as_local() + .expect("SuspensionPointAnalysis requires local item"); + let substs = ty::GenericArgs::identity_for_item(vcx.tcx(), def_id); + let body = vcx + .body_mut() + .get_impure_fn_body(local_def_id, substs, None); + + let mut visitor = SuspensionPointVisitor::new(vcx, body.clone()); + visitor.visit_body(&body); + + // create the list of suspension-points by labeling all unannotated ones + let labels: HashSet = visitor.candidates.iter().flat_map(|c| c.label).collect(); + let mut next_label = 1; + let mut get_next_label = || -> u32 { + while labels.contains(&next_label) { + next_label += 1; + } + next_label + }; + + // extract the DefId's of the closures containing on_exit/on_entry conditions + let marker_closure_def_ids = |marker_call_bb: Option| -> Vec { + let Some(marker_call_bb) = marker_call_bb else { + return Vec::new(); + }; + let terminator = body.basic_blocks[marker_call_bb].terminator(); + let mir::TerminatorKind::Call { ref args, .. 
} = terminator.kind else { + panic!("marker function BB should have call terminator") + }; + assert_eq!(args.len(), 2, "on_exit_marker call should have 2 arguments"); + let arg_ty = args[1].ty(&body.local_decls, vcx.tcx()); + let ty::TyKind::Tuple(closure_tys) = arg_ty.kind() else { + panic!("second argument to marker function should be tuple") + }; + closure_tys + .iter() + .map(|ty| { + let ty::TyKind::Closure(def_id, _) = ty.kind() else { + panic!("tuple element of argument to marker function should be closure") + }; + *def_id + }) + .collect() + }; + + // attempt to track the places containing futures from their construction to their + // suspension-points in order to map back from suspension-points to their future's + // constructor calls + let mut fut_place_tracker = FuturePlaceTracker::new(vcx); + fut_place_tracker.visit_body(&body); + + let suspension_points: Vec = visitor + .candidates + .into_iter() + .map(|candidate| SuspensionPoint { + label: candidate.label.unwrap_or_else(&mut get_next_label), + on_exit_closures: marker_closure_def_ids(candidate.on_exit_marker), + loop_head: candidate.loop_head.unwrap(), + yield_bb: candidate.yield_bb.unwrap(), + on_entry_closures: marker_closure_def_ids(candidate.on_entry_marker), + future_local: candidate.future_place.unwrap(), + pin_local: candidate.pin_place.unwrap(), + original_future_local: fut_place_tracker + .original_fut_place + .get(&candidate.into_future_call.unwrap()) + .copied(), + }) + .collect(); + + let res = SuspensionPointAnalysisOutput(suspension_points); + Ok((res, ())) + }) + } +} + +#[derive(Default)] +struct SuspensionPointCandidate { + label: Option, + future_place: Option, + pin_place: Option, + on_exit_marker: Option, + into_future_call: Option, + loop_head: Option, + yield_bb: Option, + on_entry_marker: Option, +} + +struct SuspensionPointVisitor<'vir> { + vcx: &'vir VirCtxt<'vir>, + body: MirBody<'vir>, + candidates: Vec, +} + +impl<'vir> SuspensionPointVisitor<'vir> { + fn new(vcx: &'vir VirCtxt<'vir>, body: MirBody<'vir>) -> Self { + Self { + vcx, + body, + candidates: Vec::new(), + } + } + + fn check_suspension_point( + &self, + block: mir::BasicBlock, + data: &mir::BasicBlockData<'vir>, + ) -> Option { + const INVALID_MARKER_MSG: &str = + "detected use of marker function outside of suspension-point"; + let mut candidate = SuspensionPointCandidate::default(); + + // the first BB must be a call to the on_exit marker or into_future + let (fn_def_id, ret_place, _, next_bb) = check_function_call(data.terminator())?; + let def_path = self.vcx.tcx().def_path_str(fn_def_id); + + // if the call is to into_future, we can continue + let (pre_loop_fut_place, next_bb) = if def_path == "std::future::IntoFuture::into_future" { + candidate.into_future_call = Some(block); + (ret_place, next_bb) + // otherwise, we first check for the into_future call + } else if def_path == "prusti_contracts::suspension_point_on_exit_marker" { + candidate.on_exit_marker = Some(block); + let next_bb_data = &self.body.basic_blocks[next_bb]; + let Some((fn_def_id, ret_place, _, next_next_bb)) = + check_function_call(next_bb_data.terminator()) + else { + panic!("{INVALID_MARKER_MSG}"); + }; + if self.vcx.tcx().def_path_str(fn_def_id) != "std::future::IntoFuture::into_future" { + panic!("{INVALID_MARKER_MSG}"); + } + candidate.into_future_call = Some(next_bb); + (ret_place, next_next_bb) + } else { + return None; + }; + if !pre_loop_fut_place.projection.is_empty() { + return None; + } + + // generally, the next BB just moves the future to a different 
place for the busy loop. + // it may contain some statements for analysis purposes, but should otherwise + // contain only a single assignment + let next_bb = { + let data = &self.body.basic_blocks[next_bb]; + let fut_place = { + let stmts: Result, _> = data + .statements + .iter() + .filter_map(|stmt| match stmt.kind { + mir::StatementKind::Assign(..) => Some(Ok(stmt)), + mir::StatementKind::StorageLive(..) + | mir::StatementKind::StorageDead(..) + | mir::StatementKind::FakeRead(..) => None, + _ => Some(Err(format!("{stmt:?}"))), + }) + .collect(); + let stmt = match stmts.as_deref() { + Ok([stmt]) => stmt, + _ => return None, + }; + match stmt.kind { + mir::StatementKind::Assign(box ( + new_place, + mir::Rvalue::Use(mir::Operand::Move(old_place)), + )) if old_place == pre_loop_fut_place => new_place, + _ => return None, + } + }; + let next_bb = match data.terminator().kind { + mir::TerminatorKind::Goto { target } => target, + _ => return None, + }; + assert!( + fut_place.projection.is_empty(), + "expected no projections on place containing future" + ); + candidate.future_place = Some(fut_place.local); + next_bb + }; + + // the following BB should be the busy loop's head, + // which solely consists of a `FalseUnwind` terminator + let next_bb = { + let data = &self.body.basic_blocks[next_bb]; + if !data.statements.is_empty() { + return None; + } + match data.terminator().kind { + mir::TerminatorKind::FalseUnwind { + real_target, + unwind: _, + } => { + candidate.loop_head = Some(next_bb); + real_target + } + _ => return None, + } + }; + + // inside the busy loop, a reference to the future is taken and pinned + // TODO: verify what types of other statements can appear here (e.g. StorageDead) + let next_bb = { + let data = &self.body.basic_blocks[next_bb]; + let stmts: Result, _> = data + .statements + .iter() + .filter_map(|stmt| match stmt.kind { + mir::StatementKind::Assign(..) => Some(Ok(stmt)), + mir::StatementKind::StorageLive(..) => None, + _ => Some(Err(format!("{:?}", stmt))), + }) + .collect(); + let (ref_stmt, reborrow_stmt) = match stmts.as_deref() { + Ok([s1, s2]) => (s1, s2), + _ => return None, + }; + // the first statement should just take a mutable reference to the future + let ref_place = match ref_stmt.kind { + mir::StatementKind::Assign(box ( + ref_place, + mir::Rvalue::Ref( + _, + mir::BorrowKind::Mut { + kind: mir::MutBorrowKind::Default, + }, + src_place, + ), + )) if src_place.local == candidate.future_place.unwrap() + && ref_place.projection.is_empty() => + { + ref_place + } + _ => return None, + }; + // and the second should reborrow that reference to another place + let reborrow_place = match reborrow_stmt.kind { + mir::StatementKind::Assign(box ( + reborrow_place, + mir::Rvalue::Ref( + _, + mir::BorrowKind::Mut { + kind: mir::MutBorrowKind::TwoPhaseBorrow, + }, + mir::Place { + local: src_local, + projection, + }, + ), + )) if src_local == ref_place.local + && projection.len() == 1 + && projection[0] == mir::ProjectionElem::Deref + && reborrow_place.projection.is_empty() => + { + reborrow_place + } + _ => return None, + }; + // finally, the reborrowd reference is pinned + let (fn_def_id, ret_place, args, next_bb) = check_function_call(data.terminator())?; + match args[..] { + [mir::Operand::Move(arg_place)] if arg_place == reborrow_place => {} + _ => return None, + } + if self.vcx.tcx().def_path_str(fn_def_id) != "std::pin::Pin::

::new_unchecked" + || !ret_place.projection.is_empty() + { + return None; + } + candidate.pin_place = Some(ret_place.local); + next_bb + }; + + // the following should reassign the `ResumeTy` arg and call `get_context` on it + // NOTE: from here on, we don't check the statements inside the BBs, as we don't track the + // places they use or assign to and rely on just checking terminators + let next_bb = { + let data = &self.body.basic_blocks[next_bb]; + let (fn_def_id, _, _, next_bb) = check_function_call(data.terminator())?; + if self.vcx.tcx().def_path_str(fn_def_id) != "std::future::get_context" { + return None; + } + next_bb + }; + + // then, the future is polled + let next_bb = { + let data = &self.body.basic_blocks[next_bb]; + let (_, _, args, next_bb) = check_function_call(data.terminator())?; + match args[..] { + [mir::Operand::Move(arg_place), _] + if arg_place.local == candidate.pin_place.unwrap() + && arg_place.projection.is_empty() => {} + _ => return None, + }; + next_bb + }; + + // and control flow switches on the result's discriminant + let (ready_bb, pending_bb) = { + let terminator = self.body.basic_blocks[next_bb].terminator(); + let mir::TerminatorKind::SwitchInt { + discr: _, + ref targets, + } = terminator.kind + else { + return None; + }; + let targets = targets.iter().collect::>(); + match targets[..] { + [(0, ready_bb), (1, pending_bb)] => (ready_bb, pending_bb), + _ => return None, + } + }; + + // the pending branch should first yield, and then goto back to the loop target + let next_bb = { + let terminator = self.body.basic_blocks[pending_bb].terminator(); + match terminator.kind { + mir::TerminatorKind::Yield { resume, .. } => { + candidate.yield_bb = Some(pending_bb); + resume + }, + _ => return None, + } + }; + match self.body.basic_blocks[next_bb].terminator().kind { + mir::TerminatorKind::Goto { target } if target == candidate.loop_head.unwrap() => {} + _ => return None, + }; + + // the ready branch should follow a false edge, drop the places containing the future before + // and during the busy loop and then possibly call the on_entry marker + let mir::TerminatorKind::FalseEdge { + real_target: next_bb, + .. + } = self.body.basic_blocks[ready_bb].terminator().kind + else { + return None; + }; + let next_bb = match self.body.basic_blocks[next_bb].terminator().kind { + mir::TerminatorKind::Drop { place, target, .. } + if place.local == candidate.future_place.unwrap() + && place.projection.is_empty() => + { + target + } + _ => return None, + }; + let next_bb = match self.body.basic_blocks[next_bb].terminator().kind { + mir::TerminatorKind::Drop { place, target, .. } if place == pre_loop_fut_place => { + target + } + _ => return None, + }; + let terminator = self.body.basic_blocks[next_bb].terminator(); + if let Some((fn_def_id, _, args, _)) = check_function_call(terminator) { + if self.vcx.tcx().def_path_str(fn_def_id) + == "prusti_contracts::suspension_point_on_entry_marker" + { + let [mir::Operand::Constant(box lbl_const), _] = args[..] 
else { + panic!("invalid arguments to on_entry marker") + }; + let mir::ConstantKind::Val(lbl_val, _) = lbl_const.literal else { + panic!("invalid arguments to on_entry marker") + }; + let lbl = lbl_val + .try_to_scalar_int() + .expect("could not convert label value to u32") + .try_to_u32() + .expect("could not convert label scalar to u32"); + candidate.label = Some(lbl); + candidate.on_entry_marker = Some(next_bb); + } + } + + // make sure markers can only appear as pairs + assert_eq!( + candidate.on_exit_marker.is_some(), + candidate.on_entry_marker.is_some(), + "found unpaired call to suspension-point marker function" + ); + + assert!(candidate.future_place.is_some()); + assert!(candidate.pin_place.is_some()); + assert!(candidate.into_future_call.is_some()); + assert!(candidate.loop_head.is_some()); + Some(candidate) + } +} + +/// Check if the terminator corresponds to a function call and if so, return that function's +/// DefId, the place for the return value, and the BB to continue with. +/// Function calls that necessarily diverge are *not* considered here. +fn check_function_call<'vir, 'a>( + terminator: &'a mir::Terminator<'vir>, +) -> Option<( + DefId, + mir::Place<'vir>, + &'a Vec>, + mir::BasicBlock, +)> { + let mir::TerminatorKind::Call { + ref func, + destination, + ref target, + ref args, + .. + } = terminator.kind + else { + return None; + }; + let mir::Operand::Constant(box ref cnst) = func else { + return None; + }; + let mir::ConstantKind::Val(_, ty) = cnst.literal else { + return None; + }; + let ty::TyKind::FnDef(fn_def_id, _) = ty.kind() else { + return None; + }; + Some((*fn_def_id, destination, args, *target.as_ref()?)) +} + +impl<'vir> Visitor<'vir> for SuspensionPointVisitor<'vir> { + fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { + // if this BB is the into_future call of an already detected, annotated suspension-point, + // we need to avoid detecting it as an unannotated one again + if self + .candidates + .iter() + .any(|candidate| candidate.into_future_call.unwrap() == block) + { + return; + } + + // otherwise, just check if there is a suspension-point starting at this BB + if let Some(candidate) = self.check_suspension_point(block, data) { + self.candidates.push(candidate); + } + } +} + +struct FuturePlaceTracker<'vir> { + vcx: &'vir VirCtxt<'vir>, + // tracking list mapping places to the original local their future was assigned to + current_fut_place: HashMap, mir::Local>, + // finalized list mapping BBs of `into_future` calls to the original locals + // containing their futures upon construction + original_fut_place: HashMap, +} + +impl<'vir> FuturePlaceTracker<'vir> { + fn new(vcx: &'vir VirCtxt<'vir>) -> Self { + Self { + vcx, + current_fut_place: HashMap::new(), + original_fut_place: HashMap::new(), + } + } +} + +impl<'vir> Visitor<'vir> for FuturePlaceTracker<'vir> { + fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { + // visit all statements in the BB + self.super_basic_block_data(block, data); + + // if this block calls a future constructor, add the local the future is being assigned to + // to the tracking list + let Some((fn_def_id, dest, args, _)) = check_function_call(data.terminator()) else { + return; + }; + let proc_kind = crate::encoders::with_proc_spec(fn_def_id, |proc_spec| proc_spec.proc_kind); + if matches!(proc_kind, Some(ProcedureKind::AsyncConstructor)) { + assert!( + dest.projection.is_empty(), + "expected future constructor to assign to 
local without projections" + ); + self.current_fut_place.insert(dest, dest.local); + return; + } + + // if this block ends with an `IntoFuture::into_future` call, we can stop tracking that + // future + if self.vcx.tcx().def_path_str(fn_def_id) == "std::future::IntoFuture::into_future" { + assert!( + dest.projection.is_empty(), + "expected `IntoFuture::into_future` to assign to unprojected local" + ); + let [mir::Operand::Move(place)] = args[..] else { + panic!("`IntoFuture::into_future` call should have exactly one moved argument"); + }; + if let Some(original_local) = self.current_fut_place.remove(&place) { + self.original_fut_place.insert(block, original_local); + }; + } + } + + fn visit_statement(&mut self, statement: &mir::Statement<'vir>, _location: mir::Location) { + // if this statement is an assignment, check if any of the tracked futures has moved + // NOTE: we only track futures being moved to different places, and not e.g. being + // referenced and then moved. + let mir::StatementKind::Assign(box ( + new_place, + mir::Rvalue::Use(mir::Operand::Move(old_place)), + )) = statement.kind + else { + return; + }; + + // if the old place being moved out of was in the tracking list, change it so the new place + // points to the original local + if let Some(original_local) = self.current_fut_place.remove(&old_place) { + self.current_fut_place.insert(new_place, original_local); + } + } +} diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index dcb9fa8b796..3567418db7b 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -2,11 +2,12 @@ use mir_state_analysis::{ free_pcs::{CapabilityKind, FreePcsAnalysis, FreePcsBasicBlock, FreePcsLocation, RepackOp}, utils::Place, }; +use prusti_interface::{environment::EnvQuery, specs::typed::ProcedureKind}; use prusti_rustc_interface::{ abi, middle::{ mir, - ty::{GenericArgs, TyKind}, + ty::{self, GenericArgs, TyKind}, }, span::def_id::DefId, }; @@ -14,7 +15,7 @@ use prusti_rustc_interface::{ // SsaAnalysis, //}; use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; -use vir::{MethodIdent, UnknownArity}; +use vir::{MethodIdent, UnknownArity, Reify}; pub struct MirImpureEnc; @@ -42,19 +43,22 @@ use crate::{ self, lifted::{ aggregate_cast::{ + self, AggregateSnapArgsCastEnc, AggregateSnapArgsCastEncTask }, func_app_ty_params::LiftedFuncAppTyParamsEnc }, - FunctionCallTaskDescription, MirBuiltinEnc + FunctionCallTaskDescription, MirBuiltinEnc, + r#async::{AsyncPollStubEnc, SuspensionPointAnalysis}, + MirSpecEnc, } }; use super::{ lifted::{ cast::{CastArgs, CastToEnc}, - casters::CastTypeImpure, + casters::{CastTypeImpure, CastTypePure}, rust_ty_cast::RustTyCastersEnc }, rust_ty_predicates::RustTyPredicatesEnc, @@ -113,6 +117,7 @@ pub struct ImpureEncVisitor<'vir, 'enc, E: TaskEncoder> pub monomorphize: bool, pub deps: &'enc mut TaskEncoderDependencies<'vir, E>, pub def_id: DefId, + pub substs: ty::GenericArgsRef<'vir>, pub local_decls: &'enc mir::LocalDecls<'vir>, //ssa_analysis: SsaAnalysis, pub fpcs_analysis: FreePcsAnalysis<'enc, 'vir>, @@ -484,6 +489,81 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { let tmp = self.vcx.mk_local(name, ty); (tmp, self.vcx.mk_local_ex_local(tmp)) } + + /// helper function to create a call to an async generator's poll stub + fn mk_poll_call( + &mut self, + gen_def_id: DefId, + gen_args: ty::GenericArgsRef<'vir>, + args: &Vec>, + destination: & mir::Place<'vir>, + target: &Option, 
+ pin_gen_ty: ty::Ty<'vir>, + poll_ret_ty: ty::Ty<'vir>, + ) { + let poll_stub = self.deps.require_ref::(gen_def_id).unwrap(); + let dest = self.encode_place(Place::from(*destination)).expr; + + let method_in = args.iter().map(|arg| self.encode_operand(arg)).collect::>(); + + for ((fn_arg_ty, arg), arg_ex) in poll_stub.arg_tys.iter().zip(args.iter()).zip(method_in.iter()) { + let local_decls = self.local_decls_src(); + let tcx = self.vcx().tcx(); + let arg_ty = arg.ty(local_decls, tcx); + let caster = self.deps() + .require_ref::>(CastArgs { + expected: *fn_arg_ty, + actual: arg_ty + }) + .unwrap(); + // In this context, `apply_cast_if_necessary` returns + // the impure operation to perform the cast + if let Some(stmt) = caster.apply_cast_if_necessary(self.vcx(), arg_ex) { + self.stmt(stmt); + } + } + + let mut method_args = + std::iter::once(dest).chain(method_in).collect::>(); + let mono = self.monomorphize; + let encoded_ty_args = self + .deps() + .require_local::( + (mono, gen_args) + ) + .unwrap() + .iter() + .map(|ty| ty.expr(self.vcx())); + + method_args.extend(encoded_ty_args); + + self.stmt( + self.vcx + .alloc(poll_stub.method_ref.apply(self.vcx, &method_args)), + ); + let expected_ty = destination.ty(self.local_decls_src(), self.vcx.tcx()).ty; + let result_cast = self + .deps() + .require_ref::>(CastArgs { + expected: expected_ty, + actual: poll_stub.return_ty, + }) + .unwrap(); + if let Some(stmt) = result_cast.apply_cast_if_necessary(self.vcx, dest) { + self.stmt(stmt); + } + + let terminator = target + .map(|target| { + self.vcx.mk_goto_stmt( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + ) + }) + .unwrap_or_else(|| self.unreachable()); + + assert!(self.current_terminator.replace(terminator).is_none()); + } } impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor<'vir, 'enc, E> { @@ -539,6 +619,66 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< } */ + // if this is the head of an await's busy loop, we need to add some invariants + let invariants = (|| { + let suspension_points = self.deps.require_local::(self.def_id).unwrap().0; + let suspension_point = suspension_points + .into_iter() + .find(|sp| sp.loop_head == block)?; + let fut_ty = self.local_decls[suspension_point.future_local].ty; + let fut_ty = self.vcx.tcx().expand_opaque_types(fut_ty); + let ty::TyKind::Generator(fut_def_id, ..) 
= fut_ty.kind() else { + panic!("suspension-point future place does not contain generator type"); + }; + // make sure to get generator type without any substitutions already applied + let fut_ty = self.vcx.tcx().type_of(fut_def_id).skip_binder(); + let ty::TyKind::Generator(fut_def_id, params, _) = fut_ty.kind() else { + unreachable!() + }; + // FIXME: this encodes the invariants for the wrong generator, namely the self-arg, + // as it encodes from the polled future's perspective rather than from the polling + // future's + // we might be able to fix this by manually encoding using `MirPure` and refiying + // the result here + let fut_invs = self.deps.require_local::( + (*fut_def_id, self.substs, None, false, true) + ).unwrap().async_invariants; + + let upvar_tys = params.as_generator().upvar_tys(); + let ghost_field_invs = if let Some(original_fut_local) = suspension_point.original_future_local { + let gen_snap = self.local_defs.locals[suspension_point.future_local].impure_snap; + let fut_ty = self.deps.require_ref::(fut_ty).unwrap(); + let fields = fut_ty.generic_predicate.expect_structlike().snap_data.field_access; + assert_eq!(fields.len(), 2 * upvar_tys.len() + 1); + let ghost_fields = fields[upvar_tys.len() .. 2 * upvar_tys.len()] + .iter() + .map(|f| f.read.apply(self.vcx, [gen_snap])); + ghost_fields + .zip(upvar_tys.iter()) + .enumerate() + .map(|(idx, (f, ty))| { + let ty = self.deps.require_ref::(ty).unwrap().snapshot(); + let name = vir::vir_format!(self.vcx, "_fut_arg_snap${original_fut_local:?}p${idx}"); + let snap_local = self.vcx.mk_local_ex(name, ty); + self.vcx.mk_bin_op_expr(vir::BinOpKind::CmpEq, f, snap_local) + }) + .collect() + } else { + Vec::new() + }; + + // we now add the necessary invariant about the typing of the generator and `ResumeTy` + // argument, the generator ghost fields being equal to the snapshots taken at the + // constructor call, as well as the polled future's invariants + let mut invs = Vec::with_capacity(2 + upvar_tys.len() + fut_invs.len()); + invs.push(self.local_defs.locals[suspension_point.future_local].impure_pred); + invs.push(self.local_defs.locals[2_u32.into()].impure_pred); + invs.extend(ghost_field_invs); + invs.extend(fut_invs); + + Some(self.vcx.alloc_slice(&invs)) + })(); + assert!(self.current_terminator.is_none()); self.super_basic_block_data(block, data); let stmts = self.current_stmts.take().unwrap(); @@ -547,7 +687,8 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< self.vcx.mk_cfg_block( self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), self.vcx.alloc_slice(&stmts), - terminator + terminator, + invariants.unwrap_or(&[]), ) ); } @@ -694,6 +835,105 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< } } + // future constructors + mir::Rvalue::Aggregate( + box kind@mir::AggregateKind::Generator(def_id, _params, _movbl), + operands + ) if self.vcx.tcx().generator_is_async(*def_id) => { + let generator_ty = self.deps.require_ref::(rvalue_ty).unwrap(); + let snap_cons = generator_ty + .generic_predicate + .expect_structlike() + .snap_data + .field_snaps_to_snap; + let operand_tys = operands + .iter() + .map(|oper| oper.ty(self.local_decls, self.vcx.tcx())) + .collect::>(); + // cast given arguments to field types + let ty_caster = self.deps.require_local::( + AggregateSnapArgsCastEncTask { + tys: operand_tys, + aggregate_type: kind.into() + } + ).unwrap(); + let operand_snaps = operands + .iter() + .map(|oper| self.encode_operand_snap(oper)) + 
.collect::>(); + let casted_args = ty_caster.apply_casts(self.vcx, operand_snaps.into_iter()); + // duplicate them to also initialize the ghost fields + // and initialize the state to 0 + let zero = self + .deps + .require_ref::( + self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::U32)) + ) + .unwrap() + .generic_predicate + .expect_prim() + .prim_to_snap + .apply(self.vcx, [self.vcx.mk_uint::<0>()]); + let n_args = casted_args.len(); + let args = casted_args + .into_iter() + .cycle() + .take(2 * n_args) + .chain(std::iter::once(zero)) + .collect::>(); + snap_cons.apply(self.vcx, self.vcx.alloc_slice(&args)) + } + + // FIXME: this is only a dummy to inspect generated async code + mir::Rvalue::Ref(region, borrow_kind, place) => { + // get the type of the place the ref points to + let place_ty = place.ty(self.local_decls, self.vcx.tcx()); + // and construct the type of the reference pointing to it + let ref_domain = { + // either by manually creating the domain + // (with hardcoded name for Ref-domains) + // let place_ty = self.deps.require_ref::(place_ty.ty).unwrap(); + // let dom_name = match borrow_kind { + // mir::BorrowKind::Mut { .. } => "s_Ref_Mut", + // _ => "s_Ref", + // }; + // let dom_args = self.vcx.alloc([place_ty.snapshot]); + // self.vcx.alloc(vir::TypeData::Domain(dom_name, dom_args)) + // or by wrapping it in a ref and using the encoder + let mutability = match borrow_kind { + mir::BorrowKind::Shared | mir::BorrowKind::Shallow => mir::Mutability::Not, + mir::BorrowKind::Mut { .. } => mir::Mutability::Mut, + }; + let ref_ty = self.vcx.tcx().mk_ty_from_kind(TyKind::Ref(*region, place_ty.ty, mutability)); + let (ref_ty, _) = crate::encoders::most_generic_ty::extract_type_params(self.vcx.tcx(), ref_ty); + let enc_ref_ty = self.deps.require_ref::(ref_ty).unwrap(); + enc_ref_ty.snapshot + }; + let name = match borrow_kind { + mir::BorrowKind::Shared | mir::BorrowKind::Shallow => "DummyRefLocal", + mir::BorrowKind::Mut { .. 
} => "DummyRefMutLocal", + }; + self.vcx.mk_local_ex(name, ref_domain) + } + // FIXME: this is only a dummy to inspect generated async code + mir::Rvalue::Aggregate( + box mir::AggregateKind::Closure(def_id, args), + fields + ) => { + let closure_ty = self.vcx.tcx().type_of(def_id).skip_binder(); + let closure_ty = self.deps.require_ref::(closure_ty).unwrap(); + let snap_cons = closure_ty + .generic_predicate + .expect_structlike() + .snap_data + .field_snaps_to_snap; + let fields = fields + .iter() + .map(|f| self.encode_operand_snap(f)) + .collect::>(); + snap_cons.apply(self.vcx, &fields) + } + //mir::Rvalue::Discriminant(Place<'vir>) => {} //mir::Rvalue::ShallowInitBox(Operand<'vir>, Ty<'vir>) => {} //mir::Rvalue::CopyForDeref(Place<'vir>) => {} @@ -806,8 +1046,21 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< self.vcx.alloc_slice(&otherwise_stmts), ) } - mir::TerminatorKind::Return => - self.vcx.mk_goto_stmt(self.vcx.alloc(vir::CfgBlockLabelData::End)), + mir::TerminatorKind::Return => { + // make sure all async-invariants still hold + // FIXME: see comment regarding async-invariants at yield-terminators + if self.vcx.tcx().generator_is_async(self.def_id) { + let invs = self + .deps + .require_local::((self.def_id, self.substs, None, false, false)) + .unwrap() + .async_invariants; + for inv in invs { + self.stmt(self.vcx.mk_exhale_stmt(inv)); + } + } + self.vcx.mk_goto_stmt(self.vcx.alloc(vir::CfgBlockLabelData::End)) + } mir::TerminatorKind::Call { func, args, @@ -815,10 +1068,163 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< target, .. } => { - let (func_def_id, caller_substs) = self.get_def_id_and_caller_substs(func); - let is_pure = crate::encoders::with_proc_spec(func_def_id, |spec| - spec.kind.is_pure().unwrap_or_default() - ).unwrap_or_default(); + let (mut func_def_id, mut caller_substs) = self.get_def_id_and_caller_substs(func); + let (is_pure, proc_kind) = crate::encoders::with_proc_spec(func_def_id, |spec| + (spec.kind.is_pure().unwrap_or_default(), spec.proc_kind) + ).unwrap_or((false, ProcedureKind::Method)); + + // encode suspension-point on_exit/on_entry markers as assert/assume instead of method calls + let def_path = self.vcx.tcx().def_path_str(func_def_id); + let suspension_point_marker = match def_path.as_str() { + "prusti_contracts::suspension_point_on_exit_marker" => Some(true), + "prusti_contracts::suspension_point_on_entry_marker" => Some(false), + _ => None, + }; + if let Some(is_on_exit) = suspension_point_marker { + // generate assert/assume for each closure of the tuple passed as second arg + let closures_local = { + assert_eq!(args.len(), 2, "expected 2 arguments to suspension-point marker"); + let mir::Operand::Move(place) = &args[1] else { + panic!("expected closure tuple argument to suspension-point to be moved") + }; + assert!( + place.projection.is_empty(), + "expected no projections on closure tuple argument to suspension-point marker" + ); + place.local + }; + let closure_def_ids = { + let ty::TyKind::Tuple(closures) = self.local_decls[closures_local].ty.kind() else { + panic!("expected second argument to suspension-point marker to be tuple") + }; + closures + .into_iter() + .map(|closure_ty| match closure_ty.kind() { + ty::TyKind::Closure(def_id, _) => def_id, + _ => panic!("expected second argument to suspension-point marker to be tuple of closures") + }) + }; + + let closure_fields = { + let tuple_ty = self.local_defs.locals[closures_local].ty; + 
tuple_ty.expect_structlike().snap_data.field_access + }; + let tuple_expr = self.local_defs.locals[closures_local].impure_snap; + for (i, def_id) in closure_def_ids.enumerate() { + // construct a (pure) reference to the closure to provide the exit/entry + // predicate spec method body with + let closure_ref = { + let closure_ty = self.vcx.tcx().type_of(def_id).skip_binder(); + let caster = self.deps.require_local::>(closure_ty).unwrap(); + let ref_ty = self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(self.vcx.tcx(), ty::RegionKind::ReErased), + closure_ty, + mir::Mutability::Not, + )); + let ref_ty = self.deps.require_ref::(ref_ty).unwrap(); + let ref_cons = ref_ty.generic_predicate.expect_ref().snap_data.field_snaps_to_snap; + let closure = closure_fields[i].read.apply(self.vcx, [tuple_expr]); + let closure_ref = ref_cons.apply(self.vcx, &[closure]); + closure_ref + }; + let expr = self.deps + .require_local::( + crate::encoders::MirPureEncTask { + encoding_depth: 0, + kind: crate::encoders::PureKind::Closure, + parent_def_id: *def_id, + param_env: self.vcx.tcx().param_env(def_id), + substs: self.substs, + // TODO: should this be `def_id` or `caller_def_id` + caller_def_id: Some(self.def_id), + }, + ) + .unwrap() + .expr; + let expr = expr.reify(self.vcx, (*def_id, self.vcx.alloc_slice(&[closure_ref]))); + let to_bool = self.deps + .require_ref::(self.vcx.tcx().types.bool).unwrap() + .generic_predicate + .expect_prim() + .snap_to_prim; + let expr = to_bool.apply(self.vcx, [expr]); + let stmt_kind = if is_on_exit { + vir::StmtGenData::Exhale + } else { + vir::StmtGenData::Inhale + }; + self.stmt(self.vcx.alloc(stmt_kind(expr))); + } + + // and then just proceed with the next BB + let terminator = target + .map(|target| { + self.vcx.mk_goto_stmt( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + ) + }) + .unwrap_or_else(|| self.unreachable()); + assert!(self.current_terminator.replace(terminator).is_none()); + return; + } + + // intercept some calls that are necessary for async code, + // but need to be encoded differently + if self.vcx.tcx().trait_of_item(func_def_id).is_some() { + let env_query = EnvQuery::new(self.vcx.tcx()); + let (resolved_def_id, resolved_params) = env_query.resolve_method_call(self.def_id, func_def_id, caller_substs); + match self.vcx.tcx().def_path_str(func_def_id).as_ref() { + // we can resolve calls to into_future + "std::future::IntoFuture::into_future" => { + func_def_id = resolved_def_id; + caller_substs = resolved_params; + }, + // and replace calls to poll with the annotated poll stub + "std::future::Future::poll" => { + // the generator is passed (by move) as the first argument + let mir::Operand::Move(ref fut_place) = args[0] else { + panic!("expected first argument to poll to be moved"); + }; + let fut_ty = { + // TODO: verify that this works with manual poll calls + assert!(fut_place.projection.is_empty(), "expected no projections on poll-argument"); + let fut_ty = &self.local_decls[fut_place.local].ty; + let fut_ty =match fut_ty.kind() { + ty::Adt(adt_def, args) => { + assert_eq!(self.vcx.tcx().def_path_str(adt_def.did()), "std::pin::Pin"); + assert_eq!(args.len(), 1); + args[0].as_type().expect("expected Pin's generic arg to be type") + }, + _ => panic!("expected poll argument to be pinned"), + }; + let ty::TyKind::Ref(_, ref fut_ty, mir::Mutability::Mut) = fut_ty.kind() else { + panic!("expected poll argument to be pinned mutable reference"); + }; + // generally, the future type is hidden behind an 
OTA + self.vcx.tcx().expand_opaque_types(*fut_ty) + }; + + // if the future is now known to be a specific generator, + // we can use its `DefId` to redirect the call to its poll-stub + if let ty::TyKind::Generator(gen_def_id, gen_args, _) = fut_ty.kind() { + self.mk_poll_call( + *gen_def_id, + gen_args, + args, + destination, + target, + self.local_decls[fut_place.local].ty, + destination.ty(self.local_decls, self.vcx.tcx()).ty, + ); + return; + } else { + println!("unable to resolve poll for {fut_ty:?}"); + } + }, + _ => {}, + } + } let dest = self.encode_place(Place::from(*destination)).expr; @@ -873,7 +1279,12 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< let method_in = args.iter().map(|arg| self.encode_operand(arg)).collect::>(); - for ((fn_arg_ty, arg), arg_ex) in fn_arg_tys.iter().zip(args.iter()).zip(method_in.iter()) { + for (idx, ((fn_arg_ty, arg), arg_ex)) in fn_arg_tys + .iter() + .zip(args.iter()) + .zip(method_in.iter()) + .enumerate() + { let local_decls = self.local_decls_src(); let tcx = self.vcx().tcx(); let arg_ty = arg.ty(local_decls, tcx); @@ -888,6 +1299,19 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< if let Some(stmt) = caster.apply_cast_if_necessary(self.vcx(), arg_ex) { self.stmt(stmt); } + + // if this is a call to a future constructor, create a new variable + // capturing a snapshot of the argument + if matches!(proc_kind, ProcedureKind::AsyncConstructor) { + let arg_ty = self.deps.require_ref::(*fn_arg_ty).unwrap(); + let name = vir::vir_format!(self.vcx, "_fut_arg_snap${dest:?}${idx}"); + let snap = arg_ty.generic_predicate.ref_to_snap.apply(self.vcx, &[arg_ex]); + let decl = self.vcx.mk_local_decl_stmt( + vir::vir_local_decl! { self.vcx; [name] : [arg_ty.generic_predicate.snapshot]}, + Some(snap), + ); + self.stmt(decl); + } } let mut method_args = @@ -979,6 +1403,52 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< ) } + kind@mir::TerminatorKind::Yield { + value, + resume, + resume_arg, + drop: _, + } => { + // when yielding, we exhale permissions to the yielded value, + let mir::Operand::Move(place) = value else { + panic!("expected yielded value to be moved") + }; + let yield_val_permission = self.local_defs.locals[place.local].impure_pred; + self.stmt(self.vcx.mk_exhale_stmt(yield_val_permission)); + // make sure all async-invariants still hold + // FIXME: currently, the async-invariants are exhaled as if the generator is still + // in the first argument place _1p, but generally it will be be moved to different + // places during the body's execution. + // Furthermore, it is behind a pinned mutable reference when polled, so we also + // need to make sure that accessing it via the original place has the correct + // behavior once Pin and mutable references are supported. 
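+ // Conceptually, the encoding at a yield therefore has the following shape
+ // (sketch only; the concrete predicate and label names depend on the encoded types):
+ //     exhale <permission to the yielded value>
+ //     exhale <async invariants of this generator>
+ //     inhale <permission to the resume argument>
+ //     goto <resume block>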
+ // TODO: figure out the place in which the generator is during busy-loop analysis + if self.vcx.tcx().generator_is_async(self.def_id) { + let invs = self + .deps + .require_local::((self.def_id, self.substs, None, false, false)) + .unwrap() + .async_invariants; + for inv in invs { + self.stmt(self.vcx.mk_exhale_stmt(inv)); + } + } + // TODO: once shared state is supported, we also need to havoc all shared state, + // and inhale all invariants again + // inhale permissions to the obtained resume-values, + let resume_permission = self.local_defs.locals[resume_arg.local].impure_pred; + self.stmt(self.vcx.mk_inhale_stmt(resume_permission)); + // and continue with the resume BB + // Ensure that the terminator succ that we use for the repacks is the correct one + const REAL_TARGET_SUCC_IDX: usize = 0; + assert_eq!(&self.current_fpcs.as_ref().unwrap().terminator.succs[REAL_TARGET_SUCC_IDX].location.block, resume); + self.fpcs_repacks_terminator(REAL_TARGET_SUCC_IDX, |rep| &rep.repacks_start); + self.vcx.mk_goto_stmt( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(resume.as_usize())), + ) + } + unsupported_kind => self.vcx.mk_dummy_stmt( vir::vir_format!(self.vcx, "terminator {unsupported_kind:?}") ), diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index cebf1c40f04..722dbced937 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -783,13 +783,17 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { } mir::ProjectionElem::Field(field_idx, ty) => { match place_ty.ty.kind() { - TyKind::Closure(_def_id, args) => { - let upvars = args.as_closure().upvar_tys().iter().collect::>().len(); - let tuple_ref = self.deps.require_local::( - upvars, - ).unwrap(); - (tuple_ref.mk_elem(self.vcx, expr, field_idx.as_usize()), place_ref) - } + // NOTE: this special treatment of closure fields belongs to a different + // (currently unimplemented) encoding of closures and does not work + // with the current encoding of closures as wrapper struct-likes for their + // upvars + // TyKind::Closure(_def_id, args) => { + // let upvars = args.as_closure().upvar_tys().iter().collect::>().len(); + // let tuple_ref = self.deps.require_local::( + // upvars, + // ).unwrap(); + // (tuple_ref.mk_elem(self.vcx, expr, field_idx.as_usize()), place_ref) + // } tykind => { let e_ty = self.deps.require_ref::(place_ty.ty).unwrap(); let struct_like = e_ty diff --git a/prusti-encoder/src/encoders/mod.rs b/prusti-encoder/src/encoders/mod.rs index a4e80bc1a42..241e14439d7 100644 --- a/prusti-encoder/src/encoders/mod.rs +++ b/prusti-encoder/src/encoders/mod.rs @@ -10,6 +10,7 @@ mod local_def; mod r#type; mod r#const; mod mono; +mod r#async; cfg_if::cfg_if! 
{ if #[cfg(feature = "mono_function_encoding")] { @@ -56,3 +57,4 @@ pub use viper_tuple::{ ViperTupleEncOutput, }; pub use r#const::ConstEnc; +pub use r#async::poll_stub::AsyncPollStubEnc; diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 77c4af7ba67..0130ae7514a 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -1,18 +1,28 @@ use prusti_rustc_interface::{ + hir, middle::{mir, ty}, - span::def_id::DefId, + span::{def_id::DefId, Symbol}, }; use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::Reify; -use crate::encoders::{mir_pure::PureKind, rust_ty_predicates::RustTyPredicatesEnc, MirPureEnc}; +use crate::encoders::{ + predicate, + mir_pure::PureKind, + rust_ty_predicates::RustTyPredicatesEnc, + MirPureEnc, + most_generic_ty, + lifted::{rust_ty_cast::RustTyCastersEnc, casters::CastTypePure}, +}; pub struct MirSpecEnc; #[derive(Clone)] pub struct MirSpecEncOutput<'vir> { pub pres: Vec>, pub posts: Vec>, + pub async_stub_posts: Vec>, + pub async_invariants: Vec>, pub pre_args: &'vir [vir::Expr<'vir>], pub post_args: &'vir [vir::Expr<'vir>], } @@ -25,6 +35,7 @@ impl TaskEncoder for MirSpecEnc { ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs Option, // ID of the caller function, if any bool, // If to encode as pure or not + bool, // whether to encode for an async poll stub or not ); type OutputFullLocal<'vir> = MirSpecEncOutput<'vir>; @@ -39,7 +50,7 @@ impl TaskEncoder for MirSpecEnc { task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { - let (def_id, substs, caller_def_id, pure) = *task_key; + let (def_id, substs, caller_def_id, pure, is_poll_stub) = *task_key; deps.emit_output_ref(*task_key, ())?; let local_defs = deps @@ -105,41 +116,210 @@ impl TaskEncoder for MirSpecEnc { }) .collect::>>(); + let is_async = vcx.tcx().generator_is_async(def_id); + if !is_async && is_poll_stub { + panic!("cannot set is_poll_stub for non-async bodies"); + } + + // on async functions, there is a mismatch between the signature of the declared + // async fn (and thus the spec function) and its body on the MIR level, whose + // parameters are the return place, the future, and the `ResumeTy` + // hence, instead of directly accessing function arguments, we need to read them + // from the generator's fields. + // The first half of these fields contain the generator's upvars, which may be mutated + // by poll calls / during the body's execution and the second half contains ghost + // fields, which are not changed and capture the upvars' initial value + let async_generator_fields = if !is_async { + None + } else { + let generator_snap = if !is_poll_stub { + // the async body simply takes the generator as argument and returns the result + local_defs.locals[1_u32.into()].impure_snap + } else { + // the poll stub, however, (which does not have a `DefId` and thus a signature + // we can see) takes a pinned mutable reference to the generator and returns + // the result wrapped in a `Poll`. 
+ // hence, we first need to read the generator from the pinned reference and fix + // the typing of the return value (note that the postcondition itself is + // already wrapped in `Poll` by the `prusti_contracts` macros) + let gen_ty = vcx.tcx().type_of(def_id).skip_binder(); + let gen_snap = { + // construct the receiver type + let ref_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Ref( + ty::Region::new_from_kind(vcx.tcx(), ty::RegionKind::ReErased), + gen_ty, + mir::Mutability::Mut, + )); + let pin_ty = { + let pin_def_id = vcx.tcx().require_lang_item(hir::LangItem::Pin, None); + vcx.tcx().mk_ty_from_kind(ty::TyKind::Adt( + vcx.tcx().adt_def(pin_def_id), + vcx.tcx().mk_args(&[ref_ty.into()]), + )) + }; + // and gradually read the generator snapshot from the argument + let ref_snap = { + let pin_ty = deps.require_ref::(pin_ty)?; + // note that we cannot use the `LocalDef`'s `impure_snap`, as its type is the + // generator itself due to the `DefId` belonging to the generator body + let pin_snap = pin_ty.ref_to_snap(vcx, local_defs.locals[1_u32.into()].local_ex); + let fields = pin_ty.generic_predicate.expect_structlike().snap_data.field_access; + assert_eq!(fields.len(), 1, "expected pin domain to have 1 field"); + let ref_snap = fields[0].read.apply(vcx, [pin_snap]); + let caster = deps.require_local::>(ref_ty).unwrap(); + caster.cast_to_concrete_if_possible(vcx, ref_snap) + }; + let ref_ty = deps.require_ref::(ref_ty)?; + let fields = ref_ty.generic_predicate.expect_ref().snap_data.field_access; + assert_eq!(fields.len(), 1, "expected ref domain to have 1 field"); + let gen_snap = fields[0].read.apply(vcx, [ref_snap]); + let caster = deps.require_local::>(gen_ty).unwrap(); + caster.cast_to_concrete_if_possible(vcx, gen_snap) + }; + gen_snap + }; + let ghost_fields = { + let generator_ty = local_defs.locals[mir::Local::from(1_u32)].ty; + let predicate::PredicateEncData::StructLike(gen_domain_data) = generator_ty.specifics else { + panic!("expected generator domain to be struct-like"); + }; + let fields = gen_domain_data.snap_data.field_access; + let n_fields = fields.len(); + assert!(n_fields % 2 == 1); + fields.iter().take(n_fields - 1) + }; + let ghost_field_reads = ghost_fields + .map( + |field| field.read.apply(vcx, [generator_snap]) + ) + .collect::>(); + Some(ghost_field_reads) + }; + + // TODO: check what happens for a pure async fn let post_args = if pure { all_args - } else { + } else if !is_async { let post_args: Vec<_> = pre_args .iter() .map(|arg| vcx.mk_old_expr(arg)) .chain([local_defs.locals[mir::RETURN_PLACE].impure_snap]) .collect(); vcx.alloc_slice(&post_args) + } else { + // set the arguments available to the postcondition to be old-expressions (as the + // generator is consumed by the function) of the ghost fields as well as the return value + let return_expr = if !is_poll_stub { + local_defs.locals[mir::RETURN_PLACE].impure_snap + } else { + // construct the return type + let ret_ty = { + let gen_ty = vcx.tcx().type_of(def_id).skip_binder(); + let ty::TyKind::Generator(_, args, _) = gen_ty.kind() else { + panic!("expected async fn type to be Generator"); + }; + let ret_ty = args.as_generator().return_ty(); + let poll_def_id = vcx.tcx().require_lang_item(hir::LangItem::Poll, None); + vcx.tcx().mk_ty_from_kind(ty::TyKind::Adt( + vcx.tcx().adt_def(poll_def_id), + vcx.tcx().mk_args(&[ret_ty.into()]), + )) + }; + // and build an expression for the return value with that type + let ret_ty = deps.require_ref::(ret_ty)?; + ret_ty.ref_to_snap(vcx, 
local_defs.locals[0_u32.into()].local_ex) + }; + let gen_fields = async_generator_fields.as_ref().unwrap(); + let post_args = gen_fields + .iter() + .skip(gen_fields.len() / 2) + .map(|ghost_field| vcx.mk_old_expr(ghost_field)) + .chain(std::iter::once(return_expr)) + .collect::>(); + vcx.alloc_slice(&post_args) }; + + let mut mk_post_spec_expr = |spec_def_id: &DefId| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncTask { + encoding_depth: 0, + kind: PureKind::Spec, + parent_def_id: *spec_def_id, + param_env: vcx.tcx().param_env(spec_def_id), + substs, + // TODO: should this be `def_id` or `caller_def_id` + caller_def_id: Some(def_id), + }, + ) + .unwrap() + .expr; + let expr = expr.reify(vcx, (*spec_def_id, post_args)); + to_bool.apply(vcx, [expr]) + }; + let posts = specs .posts .iter() - .map(|spec_def_id| { - let expr = deps - .require_local::( - crate::encoders::MirPureEncTask { - encoding_depth: 0, - kind: PureKind::Spec, - parent_def_id: *spec_def_id, - param_env: vcx.tcx().param_env(spec_def_id), - substs, - // TODO: should this be `def_id` or `caller_def_id` - caller_def_id: Some(def_id), - }, - ) - .unwrap() - .expr; - let expr = expr.reify(vcx, (*spec_def_id, post_args)); - to_bool.apply(vcx, [expr]) - }) + .map(&mut mk_post_spec_expr) .collect::>>(); + + // we also encode the wrapped postconditions for async poll stubs + // using the same arguments available to standard postconditions + // for non-async items, this will just be empty + let async_stub_posts = specs + .async_stub_posts + .iter() + .map(mk_post_spec_expr) + .collect::>>(); + + // async invariants are encoded using the same arguments as postconditions + // except for `result`, which cannot be used in async invariants + // async invariants also need to be encoded using the generator's ghost fields + // instead of the function arguments + // Note that they are *not* encoded using old-expressions of the ghsot fields, + // so they cannot be used as a postcondition on the generator's body (as the body + // consumes the generator). For the poll stub, there is no such restriction. 
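+ // For example (hypothetical annotation), given
+ //     #[async_invariant(n >= 0)]
+ //     async fn step(n: i32) -> i32 { ... }
+ // the expression `n >= 0` is evaluated over the generator's fields here, since the
+ // parameter `n` is not directly available in the generator body's MIR signature.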
+ let async_invariants = if !is_async { + Vec::new() + } else { + let gen_fields = async_generator_fields.unwrap(); + let n_fields = gen_fields.len(); + let inv_args = vcx.alloc_slice( + &gen_fields + .into_iter() + .take(n_fields / 2) + .collect::>() + ); + specs + .async_invariants + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncTask { + encoding_depth: 0, + kind: PureKind::Spec, + parent_def_id: *spec_def_id, + param_env: vcx.tcx().param_env(spec_def_id), + substs, + // TODO: should this be `def_id` or `caller_def_id` + caller_def_id: Some(def_id), + } + ) + .unwrap() + .expr; + let expr = expr.reify(vcx, (*spec_def_id, inv_args)); + to_bool.apply(vcx, [expr]) + }) + .collect::>>() + }; + let data = MirSpecEncOutput { pres, posts, + async_stub_posts, + async_invariants, pre_args, post_args, }; diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index 6d2431da65b..49e572296fb 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -18,6 +18,8 @@ pub struct SpecEncOutput<'vir> { //pub expr: vir::Expr<'vir>, pub pres: &'vir [DefId], pub posts: &'vir [DefId], + pub async_stub_posts: &'vir [DefId], + pub async_invariants: &'vir [DefId], } use std::cell::RefCell; @@ -93,7 +95,16 @@ impl TaskEncoder for SpecEnc { .and_then(|specs| specs.base_spec.posts.expect_empty_or_inherent()) .map(|specs| vcx.alloc_slice(specs)) .unwrap_or_default(); - Ok((SpecEncOutput { pres, posts, }, () )) + let async_stub_posts = specs + .and_then(|specs| specs.base_spec.async_stub_posts.expect_empty_or_inherent()) + .map(|specs| vcx.alloc_slice(specs)) + .unwrap_or_default(); + let async_invariants = specs + .and_then(|specs| specs.base_spec.async_invariants.expect_empty_or_inherent()) + .map(|specs| vcx.alloc_slice(specs)) + .unwrap_or_default(); + + Ok((SpecEncOutput { pres, posts, async_stub_posts, async_invariants, }, () )) }) }) } diff --git a/prusti-encoder/src/encoders/type/domain.rs b/prusti-encoder/src/encoders/type/domain.rs index 56d166a0531..a2cf3e74916 100644 --- a/prusti-encoder/src/encoders/type/domain.rs +++ b/prusti-encoder/src/encoders/type/domain.rs @@ -1,3 +1,5 @@ +use std::fmt::Write; + use prusti_rustc_interface::{ middle::ty::{self, TyKind, util::IntTypeExt, IntTy, UintTy}, abi, @@ -261,6 +263,62 @@ impl TaskEncoder for DomainEnc { let specifics = enc.mk_struct_specifics(vec![]); Ok((Some(enc.finalize(task_key)), specifics)) } + &TyKind::Generator(def_id, params, _movability) if vcx.tcx().generator_is_async(def_id) => { + // generators are encoded like a struct with one field per upvar, + // and additional ghost field per upvar (capturing their initial state), + // as well as a field for the generator's state + // the generics of that struct are given by the parent arguments + // (i.e. 
the async fn's and its parent's) + let gen_args = params.as_generator(); + let generics = gen_args + .parent_args() + .into_iter() + .filter_map(|arg| arg.as_type()) + .map(|ty| deps.require_local::>(ty).unwrap().expect_generic()) + .collect::>(); + let mut enc = DomainEncData::new(vcx, task_key, generics, deps); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; + let u32_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(UintTy::U32)); + let fields: Result, _> = gen_args + .upvar_tys() + .iter() + .chain(gen_args.upvar_tys().iter()) + .chain(std::iter::once(u32_ty)) + .map(|ty| FieldTy::from_ty(vcx, enc.deps, ty)) + .collect(); + let specifics = enc.mk_struct_specifics(fields?); + Ok((Some(enc.finalize(task_key)), specifics)) + }, + // FIXME: for now, we encode closures as wrapper struct-likes for their upvars + // in order to use them for async specifications + TyKind::Closure(_def_id, args) => { + // closures are encoded like a struct with one field per upvar + let args = args.as_closure(); + let generics = args + .parent_args() + .iter() + .filter_map(|arg| arg.as_type()) + .map(|ty| deps.require_local::>(ty).unwrap().expect_generic()) + .collect::>(); + let mut enc = DomainEncData::new(vcx, task_key, generics, deps); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; + let fields: Result, _> = args + .upvar_tys() + .iter() + .map(|ty| FieldTy::from_ty(vcx, enc.deps, ty)) + .collect(); + let specifics = enc.mk_struct_specifics(fields?); + Ok((Some(enc.finalize(task_key)), specifics)) + } + // FIXME: these are empty dummy domains to permit encoding async code + TyKind::FnPtr(_) + | TyKind::GeneratorWitness(_) + | TyKind::RawPtr(_) => { + let mut enc = DomainEncData::new(vcx, task_key, Vec::new(), deps); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; + let specifics = enc.mk_struct_specifics(Vec::new()); + Ok((Some(enc.finalize(task_key)), specifics)) + } kind => todo!("{kind:?}"), } }) diff --git a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs index 710b09d8be5..e79f5a3a5d4 100644 --- a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs @@ -1,6 +1,6 @@ use prusti_rustc_interface::{ abi::VariantIdx, - middle::{mir, ty::{GenericArgs, Ty}}, + middle::{mir, ty::{self, GenericArgs, Ty}}, span::def_id::DefId, }; use task_encoder::{TaskEncoder, EncodeFullResult}; @@ -27,6 +27,9 @@ pub enum AggregateType { def_id: DefId, variant_index: VariantIdx, }, + Generator { + def_id: DefId, + } } impl From<&mir::AggregateKind<'_>> for AggregateType { @@ -39,6 +42,11 @@ impl From<&mir::AggregateKind<'_>> for AggregateType { variant_index: *variant_index, } } + mir::AggregateKind::Generator(def_id, gen_args, ..) => { + Self::Generator { + def_id: *def_id, + } + }, _ => unimplemented!(), } } @@ -119,6 +127,30 @@ impl TaskEncoder for AggregateSnapArgsCastEnc { }) .collect::>() } + AggregateType::Generator { def_id } => { + // TODO: is skipping the binder always correct? 
+ let gen_ty = vcx.tcx().type_of(def_id).skip_binder(); + let ty::TyKind::Generator(_, gen_args, _) = gen_ty.kind() else { + panic!("expected generator type for AggregateType::Generator"); + }; + let gen_args = gen_args.as_generator(); + let n_fields = gen_args.upvar_tys().len(); + assert_eq!(n_fields, task_key.tys.len()); + gen_args + .upvar_tys() + .iter() + .zip(task_key.tys.iter()) + .map(|(field_ty, actual_ty)| { + let cast = deps + .require_ref::>(CastArgs { + expected: field_ty, + actual: *actual_ty, + }) + .unwrap(); + cast.cast_function() + }) + .collect::>() + } }; Ok(( AggregateSnapArgsCastEncOutput(vcx.alloc(cast_functions)), diff --git a/prusti-encoder/src/encoders/type/lifted/generic.rs b/prusti-encoder/src/encoders/type/lifted/generic.rs index 41c3af077e3..b2e03b06031 100644 --- a/prusti-encoder/src/encoders/type/lifted/generic.rs +++ b/prusti-encoder/src/encoders/type/lifted/generic.rs @@ -49,8 +49,11 @@ impl TaskEncoder for LiftedGenericEnc { deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { with_vcx(|vcx| { + // NOTE: some generic parameters might have names that need to be mangled first, + // such as generators which have parameters "", "", "", ... let output_ref = vcx.mk_local_decl( - vcx.alloc_str(task_key.name.as_str()), + // vcx.alloc_str(task_key.name.as_str()), + vir::vir_format_identifier!(vcx, "{}", task_key.name).to_str(), deps.require_ref::(())?.type_snapshot, ); deps.emit_output_ref(*task_key, LiftedGeneric(output_ref))?; diff --git a/prusti-encoder/src/encoders/type/most_generic_ty.rs b/prusti-encoder/src/encoders/type/most_generic_ty.rs index 498b23e2267..4af07e7560b 100644 --- a/prusti-encoder/src/encoders/type/most_generic_ty.rs +++ b/prusti-encoder/src/encoders/type/most_generic_ty.rs @@ -1,4 +1,5 @@ use prusti_rustc_interface::{ + hir, middle::ty::{self, TyKind}, span::symbol, }; @@ -42,6 +43,21 @@ impl<'tcx> MostGenericTy<'tcx> { } }, TyKind::Param(_) => String::from("Param"), + // TODO: this is to avoid name clashes between the identical generator domains + // but we should find a better way to do this + TyKind::Generator(def_id, _, _) => + vir::vir_format_identifier!(vcx, "Generator_{}", vcx.tcx().def_path_str(def_id)).to_str().to_string(), + TyKind::GeneratorWitness(_) => String::from("GeneratorWitness"), + TyKind::RawPtr(ty::TypeAndMut { ty: _, mutbl }) => { + if mutbl.is_mut() { + String::from("RawPtr_mutable") + } else { + String::from("RawPtr_immutable") + } + } + TyKind::FnPtr(_) => String::from("FnPtr"), + TyKind::Closure(def_id, _) => + vir::vir_format_identifier!(vcx, "Closure_{}", vcx.tcx().def_path_str(def_id)).to_str().to_string(), other => unimplemented!("get_domain_base_name for {:?}", other), } } @@ -91,6 +107,20 @@ impl<'tcx> MostGenericTy<'tcx> { | TyKind::Never | TyKind::Param(_) | TyKind::Uint(_) => Vec::new(), + // NOTE: for now this is only for generators originating from async items + TyKind::Generator(_, args, _) => args + .as_generator() + .parent_args() + .into_iter() + .flat_map(|arg| arg.as_type()) + .map(as_param_ty) + .collect(), + // FIXME: these are only in here to permit encoding async code + TyKind::RawPtr(_) + | TyKind::Str + | TyKind::FnPtr(_) + | TyKind::GeneratorWitness(_) + | TyKind::Closure(..)=> Vec::new(), other => todo!("generics for {:?}", other), } } @@ -152,6 +182,145 @@ pub fn extract_type_params<'tcx>( TyKind::Bool | TyKind::Char | TyKind::Int(_) | TyKind::Uint(_) | TyKind::Float(_) | TyKind::Never | TyKind::Str => { (MostGenericTy(ty), Vec::new()) } + // 
for now, we replace OTAs by their underlying type in order to ensure that OTAs to async + // generators are correctly encoded to the same viper type as the generator + TyKind::Alias(ty::AliasKind::Opaque, _alias_ty) => { + let underlying = tcx.expand_opaque_types(ty); + extract_type_params(tcx, underlying) + } + TyKind::Generator(def_id, args, movability) if tcx.generator_is_async(def_id) => { + let args = args.as_generator(); + // the only generic arguments we need to include are the parent arguments, + // i.e. those present in the async fn (or its outer scope) + // the other generic arguments arise due to the future's representation as a + // generator with , , , , + // and the tupled upvar-ty, all of which are fixed for a given future + // (or even all futures) + let parent_args = args + .parent_args() + .into_iter() + .flat_map(|arg| ty::GenericArg::as_type(*arg)) + .collect::>(); + // we use `List::identity_for_item` to get generic parameters with correct names, + // as creating placeholders ourselves might result in misnaming parameters + // that are used in the upvars (attempting to recreate the names from the parent-args + // also doesn't always work, since they might already be substituted in the parent-args) + // the parent generic parameters are now contained in the front of the result + // obtained from `List::identity_for_item` + // TODO: verify that this is stable + { + // sanity-check: number of generic arguments (and type params among them) matches + let id = ty::List::identity_for_item(tcx, def_id); + assert_eq!(id.len(), args.parent_args().len() + 5); + let id = id.into_iter().flat_map(ty::GenericArg::as_type).collect::>(); + assert_eq!(id.len(), parent_args.len() + 5); + // TODO: should we check that those afterwards are actually called and so on? 
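+ // (the five trailing generic arguments are the generator's synthetic resume,
+ // yield and return types, its witness, and the tupled upvar type, matching the
+ // `GeneratorArgsParts` constructed below)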
+ } + let parent_id: Vec = ty::List::identity_for_item(tcx, def_id) + .into_iter() + .flat_map(ty::GenericArg::as_type) + .take(parent_args.len()) + .map(|arg| arg.into()) + .collect(); + // we also use a dummy witness type to avoid encoding the same generator twice, + // as the witness type appears once with the generator and once with the OTA to it + let dummy_witness = tcx.mk_ty_from_kind( + ty::TyKind::GeneratorWitness(ty::Binder::dummy(ty::List::empty())) + ); + // note that the upvar types given in the generic arguments might already contain + // substitutions for some of the async item's type parameters, so we use the `TyCtxt` + // to obtain the generator's type without any substitutions + let generic_upvars_ty = { + let gen_ty = tcx.type_of(def_id).skip_binder(); + let TyKind::Generator(_, args, _) = gen_ty.kind() else { + panic!("TyCtxt::type_of returned non-generator type for generator DefId"); + }; + args.as_generator().tupled_upvars_ty() + }; + let id_parts = ty::GeneratorArgsParts { + parent_args: tcx.mk_args(&parent_id), + resume_ty: args.resume_ty(), + yield_ty: args.yield_ty(), + return_ty: args.return_ty(), + witness: dummy_witness, + tupled_upvars_ty: generic_upvars_ty, + }; + let id_args = ty::GeneratorArgs::new(tcx, id_parts); + let ty = tcx.mk_ty_from_kind( + TyKind::Generator(def_id, id_args.args, movability) + ); + (MostGenericTy(ty), parent_args) + } + // FIXME: these are only dummies to permit encoding async code + TyKind::GeneratorWitness(_) => { + // we erase generic args inside the witness to avoid encoding + // the same dummy domain twice + let dummy_witness_ty = TyKind::GeneratorWitness(ty::Binder::dummy(ty::List::empty())); + (MostGenericTy(tcx.mk_ty_from_kind(dummy_witness_ty)), Vec::new()) + }, + // FIXME: these are only dummies to permit encoding async code + TyKind::FnPtr(binder) => { + // we erase the signature to avoid encoding + // the same dummy domain twice + let sig = binder.skip_binder(); + let unit_ty = tcx.mk_ty_from_kind(TyKind::Tuple(ty::List::empty())); + let dummy_sig = ty::Binder::dummy(ty::FnSig { + inputs_and_output: tcx.mk_type_list(&[unit_ty]), + unsafety: hir::Unsafety::Normal, + ..sig + }); + let ty = tcx.mk_ty_from_kind(TyKind::FnPtr(dummy_sig)); + (MostGenericTy(ty), Vec::new()) + } + // FIXME: these are only dummies to permit encoding async code + TyKind::RawPtr(ty::TypeAndMut { ty: orig, mutbl }) => { + let ty = to_placeholder(tcx, None); + let ty = tcx.mk_ty_from_kind(TyKind::RawPtr(ty::TypeAndMut { ty, mutbl })); + (MostGenericTy(ty), Vec::new()) + // (MostGenericTy(ty), vec![orig]) + } + // FIXME: for now, we encode closures like simple wrappers around their upvars in order to + // use them for on-exit/on-entry conditions in async code + TyKind::Closure(def_id, args) => { + // analogous to generator + let args = args.as_closure(); + let parent_args = args + .parent_args() + .into_iter() + .flat_map(|arg| arg.as_type()) + .collect::>(); + // sanity-checks + let id_args = ty::List::identity_for_item(tcx, def_id); + { + assert_eq!(id_args.len(), args.parent_args().len() + 3); + let id = id_args.into_iter().flat_map(ty::GenericArg::as_type).collect::>(); + assert_eq!(id.len(), parent_args.len() + 3); + } + let parent_id: Vec = id_args + .into_iter() + .flat_map(ty::GenericArg::as_type) + .take(parent_args.len()) + .map(|arg| arg.into()) + .collect(); + let most_generic_args = { + let closure_ty = tcx.type_of(def_id).skip_binder(); + let TyKind::Closure(_, args) = closure_ty.kind() else { + panic!("TyCtxt::type_of returned 
non-closure type for closure DefId") + }; + args.as_closure() + }; + let id_parts = ty::ClosureArgsParts { + parent_args: tcx.mk_args(&parent_id), + closure_kind_ty: most_generic_args.kind_ty(), + closure_sig_as_fn_ptr_ty: most_generic_args.sig_as_fn_ptr_ty(), + tupled_upvars_ty: most_generic_args.tupled_upvars_ty(), + }; + let id_args = ty::ClosureArgs::new(tcx, id_parts); + let ty = tcx.mk_ty_from_kind( + TyKind::Closure(def_id, id_args.args) + ); + (MostGenericTy(ty), parent_args) + } _ => todo!("extract_type_params for {:?}", ty), } } diff --git a/prusti-encoder/src/encoders/type/predicate.rs b/prusti-encoder/src/encoders/type/predicate.rs index ac0a526a0f1..dd04529b2dd 100644 --- a/prusti-encoder/src/encoders/type/predicate.rs +++ b/prusti-encoder/src/encoders/type/predicate.rs @@ -369,6 +369,64 @@ impl TaskEncoder for PredicateEnc { ); Ok((enc.mk_prim(&snap.base_name), ())) } + TyKind::Generator(def_id, args, _m) if enc.vcx.tcx().generator_is_async(*def_id) => { + // generators are encoded like a struct with one field per upvar, + // ghost field per upvar capturing that upvar's initial state, + // as well as a field for the generator's state + let snap_data = snap.specifics.expect_structlike(); + let specifics = enc.mk_struct_ref(None, snap_data); + deps.emit_output_ref( + *task_key, + enc.output_ref(PredicateEncData::StructLike(specifics)) + )?; + let upvar_tys = args.as_generator().upvar_tys(); + let u32_ty = enc.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::U32)); + let fields: Result, _> = upvar_tys + .into_iter() + .chain(upvar_tys) + .chain(std::iter::once(u32_ty)) + .map(|ty| deps.require_ref::(ty)) + .collect(); + let fields = enc.mk_field_apps(specifics.ref_to_field_refs, fields?); + let fn_snap_body = + enc.mk_struct_ref_to_snap_body(None, fields, snap_data.field_snaps_to_snap); + Ok((enc.mk_struct(fn_snap_body), ())) + } + // FIXME: for now, we encode closures as wrapper struct-likes for their upvars + // in order to use them for async specifications + TyKind::Closure(_def_id, args) => { + // closures are encoded like a struct with one field per upvar + let snap_data = snap.specifics.expect_structlike(); + let specifics = enc.mk_struct_ref(None, snap_data); + deps.emit_output_ref( + *task_key, + enc.output_ref(PredicateEncData::StructLike(specifics)) + )?; + let fields: Result, _> = args + .as_closure() + .upvar_tys() + .into_iter() + .map(|ty| deps.require_ref::(ty)) + .collect(); + let fields = enc.mk_field_apps(specifics.ref_to_field_refs, fields?); + let fn_snap_body = + enc.mk_struct_ref_to_snap_body(None, fields, snap_data.field_snaps_to_snap); + Ok((enc.mk_struct(fn_snap_body), ())) + } + // FIXME: these are empty dummy domains to permit encoding async code + TyKind::FnPtr(_) + | TyKind::GeneratorWitness(_) + | TyKind::RawPtr(_) => { + let snap_data = snap.specifics.expect_structlike(); + let specifics = enc.mk_struct_ref(None, snap_data); + deps.emit_output_ref( + *task_key, + enc.output_ref(PredicateEncData::StructLike(specifics)) + ); + let fn_snap_body = + enc.mk_struct_ref_to_snap_body(None, Vec::new(), snap_data.field_snaps_to_snap); + Ok((enc.mk_struct(fn_snap_body), ())) + } unsupported_type => todo!("type not supported: {unsupported_type:?}"), } } @@ -687,16 +745,18 @@ impl<'vir, 'tcx> PredicateEncValues<'vir, 'tcx> { )); let inner_snap = inner.ref_to_snap.apply(self.vcx, inner_ref_to_args); - let snap = if data.perm.is_none() { - // `Ref` is only part of snapshots for mutable references. 
- data.snap_data - .field_snaps_to_snap - .apply(self.vcx, &[inner_snap, self_ref]) - } else { - data.snap_data - .field_snaps_to_snap - .apply(self.vcx, &[inner_snap]) - }; + // FIXME: this does not work as of now, so we never use the self-ref + let snap = data.snap_data.field_snaps_to_snap.apply(self.vcx, &[inner_snap]); + // let snap = if data.perm.is_none() { + // // `Ref` is only part of snapshots for mutable references. + // data.snap_data + // .field_snaps_to_snap + // .apply(self.vcx, &[inner_snap, self_ref]) + // } else { + // data.snap_data + // .field_snaps_to_snap + // .apply(self.vcx, &[inner_snap]) + // }; let fn_snap_body = self.vcx.mk_unfolding_expr(self.self_pred_read, snap); self.finalize(Some(fn_snap_body)) } diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index bb01c9cf99b..528d98bae74 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -56,6 +56,25 @@ pub fn test_entrypoint<'tcx>( assert!(res.is_ok()); } } + // FIXME: for now this is only experimental + hir::def::DefKind::Generator if tcx.generator_is_async(def_id.into()) => { + let def_id = def_id.to_def_id(); + if prusti_interface::specs::is_spec_fn(tcx, def_id) { + continue; + } + + let (is_pure, is_trusted) = crate::encoders::with_proc_spec(def_id, |proc_spec| { + let is_pure = proc_spec.kind.is_pure().unwrap_or_default(); + let is_trusted = proc_spec.trusted.extract_inherit().unwrap_or_default(); + (is_pure, is_trusted) + }).unwrap_or_default(); + + if !(is_trusted && is_pure) { + let substs = ty::GenericArgs::identity_for_item(tcx, def_id); + let res = crate::encoders::MirPolyImpureEnc::encode(def_id, false); + assert!(res.is_ok()); + } + }, unsupported_item_kind => { tracing::debug!("unsupported item: {unsupported_item_kind:?}"); } @@ -75,6 +94,12 @@ pub fn test_entrypoint<'tcx>( let mut program_functions = vec![]; let mut program_methods = vec![]; + header(&mut viper_code, "async stubs"); + for output in crate::encoders::AsyncPollStubEnc::all_outputs() { + viper_code.push_str(&format!("{:?}\n", output.method)); + program_methods.push(output.method); + } + // We output results from both monomorphic and polymorphic encoding of // functions, because even when Prusti is configured to use the monomorphic // it will still use `MirPolyImpureEnc` directly sometimes (see usages diff --git a/prusti-interface/src/specs/mod.rs b/prusti-interface/src/specs/mod.rs index cf5414e1d6f..b953c3ab104 100644 --- a/prusti-interface/src/specs/mod.rs +++ b/prusti-interface/src/specs/mod.rs @@ -17,7 +17,7 @@ use prusti_rustc_interface::{ def_id::{DefId, LocalDefId}, intravisit, FnRetTy, }, - middle::{hir::map::Map, ty}, + middle::{hir::map::{self as hir_map, Map}, ty}, span::Span, }; use std::{convert::TryInto, fmt::Debug}; @@ -33,10 +33,11 @@ use typed::SpecIdRef; use crate::specs::{ external::ExternSpecResolver, - typed::{ProcedureSpecification, ProcedureSpecificationKind, SpecGraph, SpecificationItem}, + typed::{ProcedureSpecification, ProcedureSpecificationKind, SpecGraph, SpecificationItem, ProcedureKind}, }; use prusti_specs::specifications::common::SpecificationId; + #[derive(Debug)] struct ProcedureSpecRefs { spec_id_refs: Vec, @@ -86,6 +87,9 @@ pub struct SpecCollector<'a, 'tcx> { prusti_refutations: Vec, ghost_begin: Vec, ghost_end: Vec, + + /// Map from future constructors to their poll methods + async_parent: FxHashMap, } impl<'a, 'tcx> SpecCollector<'a, 'tcx> { @@ -103,6 +107,7 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { prusti_refutations: vec![], ghost_begin: vec![], ghost_end: 
vec![], + async_parent: FxHashMap::default(), } } @@ -150,6 +155,12 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { self.env, ); } + // both async postconditions and async invariants are not added to the method + // they were attached to (which ends up being the future constructor) but to + // the poll method, which is determined elsewhere + SpecIdRef::AsyncPostcondition(_) + | SpecIdRef::AsyncStubPostcondition(_) + | SpecIdRef::AsyncInvariant(_) => {}, SpecIdRef::Purity(spec_id) => { spec.add_purity(*self.spec_functions.get(spec_id).unwrap(), self.env); } @@ -191,6 +202,50 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { def_spec.proc_specs.insert(local_id.to_def_id(), spec); } } + + // attach async specifications to the future's poll-methods instead of the future + // constructor + for (local_id, parent_id) in self.async_parent.iter() { + // look up parent's spec and skip this method if the parent doesn't have one + let Some(parent_spec) = self.procedure_specs.get(parent_id) else { + continue; + }; + + // mark parent as an async constructor + def_spec + .proc_specs + .get_mut(&parent_id.to_def_id()) + .expect("async parent must have entry in DefSpecMap") + .set_proc_kind(ProcedureKind::AsyncConstructor); + + // the spec is then essentially inherited from the parent, + // but for now, only postconditions and trusted are allowed + let mut spec = SpecGraph::new(ProcedureSpecification::empty(local_id.to_def_id())); + spec.set_proc_kind(ProcedureKind::AsyncPoll); + spec.set_kind(parent_spec.into()); + spec.set_trusted(parent_spec.trusted); + for spec_id_ref in &parent_spec.spec_id_refs { + match spec_id_ref { + SpecIdRef::AsyncPostcondition(spec_id) => { + spec.add_postcondition(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + SpecIdRef::AsyncStubPostcondition(spec_id) => { + spec.add_async_stub_postcondition(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + SpecIdRef::AsyncInvariant(spec_id) => { + spec.add_async_invariant(*self.spec_functions.get(spec_id).unwrap(), self.env); + } + // all other spec items should stay attached to the original method + _ => {}, + } + } + // poll methods should not already be covered by the existing code, + // as they cannot be annotated by the user + let old = def_spec.proc_specs.insert(local_id.to_def_id(), spec); + if old.is_some() { + panic!("async fn poll method {local_id:?} already had a spec"); + } + } } fn determine_extern_specs(&self, def_spec: &mut typed::DefSpecificationMap) { @@ -381,6 +436,21 @@ fn get_procedure_spec_ids(def_id: DefId, attrs: &[ast::Attribute]) -> Option intravisit::Visitor<'tcx> for SpecCollector<'a, 'tcx> { let def_id = local_id.to_def_id(); let attrs = self.env.query.get_local_attributes(local_id); + // if this is a closure representing an async fn's poll, locate and store its parent + if self.env.tcx().generator_is_async(def_id) { + let hir_id = self.env.tcx().hir().local_def_id_to_hir_id(local_id); + let parent = self.env.tcx().hir() + .find_parent(hir_id) + .expect("expected async-fn-generator to have a parent"); + // we can get the parent's LocalDefId via its associated body + let (parent_def_id, _) = hir_map::associated_body(parent) + .expect("async-fn-generator parent should have a body"); + assert!(self.env.tcx().asyncness(parent_def_id.to_def_id()).is_async(), "found async generator with non-async parent"); + assert!(!self.env.tcx().generator_is_async(parent_def_id.to_def_id()), "found async generator whose parent is also async generator"); + let old = self.async_parent.insert(local_id, parent_def_id); + 
if old.is_some() { + panic!("parent {:?} of async-generator {:?} already has parent", old.unwrap(), local_id); + } + } + // Collect spec functions if let Some(raw_spec_id) = read_prusti_attr("spec_id", attrs) { let spec_id: SpecificationId = parse_spec_id(raw_spec_id, def_id); diff --git a/prusti-interface/src/specs/typed.rs b/prusti-interface/src/specs/typed.rs index 2041266c555..dfcd5369dcb 100644 --- a/prusti-interface/src/specs/typed.rs +++ b/prusti-interface/src/specs/typed.rs @@ -82,6 +82,12 @@ impl DefSpecificationMap { if let Some(posts) = spec.posts.extract_with_selective_replacement() { specs.extend(posts); } + if let Some(posts) = spec.async_stub_posts.extract_with_selective_replacement() { + specs.extend(posts); + } + if let Some(invs) = spec.async_invariants.extract_with_selective_replacement() { + specs.extend(invs); + } if let Some(Some(term)) = spec.terminates.extract_with_selective_replacement() { specs.push(term.to_def_id()); } @@ -202,6 +208,13 @@ impl DefSpecificationMap { } } +#[derive(Debug, Copy, Clone, TyEncodable, TyDecodable)] +pub enum ProcedureKind { + Method, + AsyncConstructor, + AsyncPoll, +} + #[derive(Debug, Clone, TyEncodable, TyDecodable)] pub struct ProcedureSpecification { // DefId of fn signature to which the spec was attached. @@ -210,10 +223,13 @@ pub struct ProcedureSpecification { pub kind: SpecificationItem, pub pres: SpecificationItem>, pub posts: SpecificationItem>, + pub async_stub_posts: SpecificationItem>, + pub async_invariants: SpecificationItem>, pub pledges: SpecificationItem>, pub trusted: SpecificationItem, pub terminates: SpecificationItem>, pub purity: SpecificationItem>, // for type-conditional spec refinements + pub proc_kind: ProcedureKind, } impl ProcedureSpecification { @@ -225,10 +241,13 @@ impl ProcedureSpecification { kind: SpecificationItem::Inherent(ProcedureSpecificationKind::Impure), pres: SpecificationItem::Empty, posts: SpecificationItem::Empty, + async_stub_posts: SpecificationItem::Empty, + async_invariants: SpecificationItem::Empty, pledges: SpecificationItem::Empty, trusted: SpecificationItem::Inherent(false), terminates: SpecificationItem::Inherent(None), purity: SpecificationItem::Inherent(None), + proc_kind: ProcedureKind::Method, } } } @@ -453,6 +472,46 @@ impl SpecGraph { } } + /// Attaches the async stub postcondition `post` to this [SpecGraph]. + /// + /// If this postcondition has a constraint it will be attached to the corresponding + /// constrained spec **and** the base spec, otherwise just to the base spec. + pub fn add_async_stub_postcondition<'tcx>(&mut self, post: LocalDefId, env: &Environment<'tcx>) { + match self.get_constraint(post, env) { + None => { + self.base_spec.async_stub_posts.push(post.to_def_id()); + self.specs_with_constraints + .values_mut() + .for_each(|s| s.async_stub_posts.push(post.to_def_id())); + } + Some(obligation) => { + self.get_constrained_spec_mut(obligation) + .async_stub_posts + .push(post.to_def_id()); + } + } + } + + /// Attaches the async invariant `inv` to this [SpecGraph]. + /// + /// If this postcondition has a constraint it will be attached to the corresponding + /// constrained spec **and** the base spec, otherwise just to the base spec. 
+ pub fn add_async_invariant<'tcx>(&mut self, inv: LocalDefId, env: &Environment<'tcx>) { + match self.get_constraint(inv, env) { + None => { + self.base_spec.async_invariants.push(inv.to_def_id()); + self.specs_with_constraints + .values_mut() + .for_each(|s| s.async_invariants.push(inv.to_def_id())); + } + Some(obligation) => { + self.get_constrained_spec_mut(obligation) + .async_invariants + .push(inv.to_def_id()); + } + } + } + pub fn add_purity<'tcx>(&mut self, purity: LocalDefId, env: &Environment<'tcx>) { match self.get_constraint(purity, env) { None => { @@ -502,6 +561,14 @@ impl SpecGraph { .for_each(|s| s.kind.set(kind)); } + /// Sets the [ProcedureKind] for the base spec and all constrained specs. + pub fn set_proc_kind(&mut self, proc_kind: ProcedureKind) { + self.base_spec.proc_kind = proc_kind; + self.specs_with_constraints + .values_mut() + .for_each(|s| s.proc_kind = proc_kind); + } + /// Lazily gets/creates a constrained spec. /// If the constrained spec does not yet exist, the base spec serves as a template for /// the newly created constrained spec. @@ -777,11 +844,14 @@ impl Refinable for ProcedureSpecification { source: self.source, pres: self.pres.refine(replace_empty(&EMPTYL, &other.pres)), posts: self.posts.refine(replace_empty(&EMPTYL, &other.posts)), + async_stub_posts: self.async_stub_posts.refine(replace_empty(&EMPTYL, &other.async_stub_posts)), + async_invariants: self.async_invariants.refine(replace_empty(&EMPTYL, &other.async_invariants)), pledges: self.pledges.refine(replace_empty(&EMPTYP, &other.pledges)), kind: self.kind.refine(&other.kind), trusted: self.trusted.refine(&other.trusted), terminates: self.terminates.refine(&other.terminates), purity: self.purity.refine(&other.purity), + proc_kind: self.proc_kind, } } } diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 3ebe9322a2a..b66716545dc 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -258,6 +258,9 @@ impl<'vir, Curr, Next> Debug for MethodGenData<'vir, Curr, Next> { writeln!(f, "{{")?; for block in body.blocks.iter() { writeln!(f, "label {:?}", block.label)?; + for inv in block.invariants { + writeln!(f, " invariant {:?}", inv); + } for stmt in block.stmts { writeln!(f, " {:?}", stmt)?; } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index a1470cf67b7..ef560105a9c 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -318,6 +318,7 @@ pub struct CfgBlockGenData<'vir, Curr, Next> { #[vir(reify_pass, is_ref)] pub label: CfgBlockLabel<'vir>, pub stmts: &'vir [StmtGen<'vir, Curr, Next>], pub terminator: TerminatorStmtGen<'vir, Curr, Next>, + pub invariants: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(VirHash, VirReify, VirSerde)] diff --git a/vir/src/make.rs b/vir/src/make.rs index 698ae99376f..cd4a3d91d27 100644 --- a/vir/src/make.rs +++ b/vir/src/make.rs @@ -522,6 +522,13 @@ impl<'tcx> VirCtxt<'tcx> { self.alloc(StmtGenData::Exhale(expr)) } + pub fn mk_inhale_stmt<'vir, Curr, Next>( + &'vir self, + expr: ExprGen<'vir, Curr, Next>, + ) -> StmtGen<'vir, Curr, Next> { + self.alloc(StmtGenData::Inhale(expr)) + } + pub fn mk_unfold_stmt<'vir, Curr, Next>( &'vir self, pred_app: PredicateAppGen<'vir, Curr, Next> @@ -632,11 +639,13 @@ impl<'tcx> VirCtxt<'tcx> { label: CfgBlockLabel<'vir>, stmts: &'vir [StmtGen<'vir, Curr, Next>], terminator: TerminatorStmtGen<'vir, Curr, Next>, + invariants: &'vir [ExprGen<'vir, Curr, Next>], ) -> CfgBlockGen<'vir, Curr, Next> { self.alloc(CfgBlockGenData { label, stmts, - terminator + terminator, + invariants, }) } diff --git a/vir/src/viper_ident.rs 
b/vir/src/viper_ident.rs
index ec5557f0393..14f2e01393f 100644
--- a/vir/src/viper_ident.rs
+++ b/vir/src/viper_ident.rs
@@ -38,6 +38,9 @@ fn sanitize_char(c: char) -> Option<String> {
         ' ' => Some("$space$".to_string()),
         ',' => Some("$comma$".to_string()),
         ':' => Some("$colon$".to_string()),
+        '{' => Some("$opencur$".to_string()),
+        '}' => Some("$closecur$".to_string()),
+        '#' => Some("$hash$".to_string()),
         _ => None,
     }
 }
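
A minimal usage sketch for the new annotations (hypothetical example; the exact
attribute grammar is whatever the prusti_contracts macros accept, here assumed to be
a boolean expression over the function's parameters, as for requires/ensures). The
async invariant is exhaled at every yield and at the return of the generator body,
while the postcondition of an async fn is attached to the generated poll stub rather
than to the future constructor:

    use prusti_contracts::*;

    #[requires(n >= 0)]
    #[async_invariant(n >= 0)]
    #[ensures(result >= n)]
    async fn accumulate(n: i32) -> i32 {
        // no awaits yet: the invariant is still checked when the body returns
        n + 1
    }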