Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions prusti-contracts/prusti-contracts-proc-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ pub fn ensures(_attr: TokenStream, tokens: TokenStream) -> TokenStream {
tokens
}

#[cfg(not(feature = "prusti"))]
#[proc_macro_attribute]
pub fn async_invariant(_attr: TokenStream, tokens: TokenStream) -> TokenStream {
tokens
}

#[cfg(not(feature = "prusti"))]
#[proc_macro_attribute]
pub fn after_expiry(_attr: TokenStream, tokens: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -130,6 +136,12 @@ pub fn body_variant(_tokens: TokenStream) -> TokenStream {
TokenStream::new()
}

#[cfg(not(feature = "prusti"))]
#[proc_macro]
pub fn suspension_point(tokens: TokenStream) -> TokenStream {
tokens
}

// ----------------------
// --- PRUSTI ENABLED ---

Expand All @@ -148,6 +160,12 @@ pub fn ensures(attr: TokenStream, tokens: TokenStream) -> TokenStream {
rewrite_prusti_attributes(SpecAttributeKind::Ensures, attr.into(), tokens.into()).into()
}

#[cfg(feature = "prusti")]
#[proc_macro_attribute]
pub fn async_invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream {
rewrite_prusti_attributes(SpecAttributeKind::AsyncInvariant, attr.into(), tokens.into()).into()
}

#[cfg(feature = "prusti")]
#[proc_macro_attribute]
pub fn after_expiry(attr: TokenStream, tokens: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -273,5 +291,11 @@ pub fn body_variant(tokens: TokenStream) -> TokenStream {
prusti_specs::body_variant(tokens.into()).into()
}

#[cfg(feature = "prusti")]
#[proc_macro]
pub fn suspension_point(tokens: TokenStream) -> TokenStream {
prusti_specs::suspension_point(tokens.into()).into()
}

// Ensure that you've also crated a transparent `#[cfg(not(feature = "prusti"))]`
// version of your new macro above!
10 changes: 10 additions & 0 deletions prusti-contracts/prusti-contracts/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ pub use prusti_contracts_proc_macros::terminates;
/// A macro to annotate body variant of a loop to prove termination
pub use prusti_contracts_proc_macros::body_variant;

/// A macro to annotate suspension points inside async constructs
pub use prusti_contracts_proc_macros::suspension_point;

/// A macro for writing async invariants on an async function
pub use prusti_contracts_proc_macros::async_invariant;

#[cfg(not(feature = "prusti"))]
mod private {
use core::marker::PhantomData;
Expand Down Expand Up @@ -339,6 +345,10 @@ pub fn old<T>(arg: T) -> T {
arg
}

pub fn suspension_point_on_exit_marker<T>(_label: u32, _closures: T) {}

pub fn suspension_point_on_entry_marker<T>(_label: u32, _closures: T) {}

/// Universal quantifier.
///
/// This is a Prusti-internal representation of the `forall` syntax.
Expand Down
233 changes: 233 additions & 0 deletions prusti-contracts/prusti-specs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ fn extract_prusti_attributes(
let tokens = match attr_kind {
SpecAttributeKind::Requires
| SpecAttributeKind::Ensures
| SpecAttributeKind::AsyncInvariant
| SpecAttributeKind::AfterExpiry
| SpecAttributeKind::AssertOnExpiry
| SpecAttributeKind::RefineSpec => {
Expand All @@ -81,6 +82,10 @@ fn extract_prusti_attributes(
assert!(iter.next().is_none(), "Unexpected shape of an attribute.");
group.stream()
}
// these cannot appear here, as postconditions are only marked as async
// at a later stage
SpecAttributeKind::AsyncEnsures =>
unreachable!("SpecAttributeKind::AsyncEnsures should not appear at this stage"),
// Nothing to do for attributes without arguments.
SpecAttributeKind::Pure
| SpecAttributeKind::Terminates
Expand Down Expand Up @@ -125,6 +130,39 @@ pub fn rewrite_prusti_attributes(
// Collect the remaining Prusti attributes, removing them from `item`.
prusti_attributes.extend(extract_prusti_attributes(&mut item));

// in the case of an async fn, mark the postconditions as such,
// since they are handled differently
if let untyped::AnyFnItem::Fn(ref fn_item) = &item {
if fn_item.sig.asyncness.is_some() {
prusti_attributes = prusti_attributes
.into_iter()
.map(|(attr, tt)| match attr {
SpecAttributeKind::Ensures => (SpecAttributeKind::AsyncEnsures, tt),
_ => (attr, tt)
}
)
.collect();
}
}

// make sure async invariants can only be attached to async fn's
if matches!(outer_attr_kind, SpecAttributeKind::AsyncInvariant) {
// TODO: should the error be reported to the function's or the attribute's span?
let untyped::AnyFnItem::Fn(ref fn_item) = &item else {
return syn::Error::new(
item.span(),
"async_invariant attached to non-function item"
).to_compile_error();
};
if fn_item.sig.asyncness.is_none() {
return syn::Error::new(
item.span(),
"async_invariant attached to non-async function"
)
.to_compile_error();
}
}

// make sure to also update the check in the predicate! handling method
if prusti_attributes
.iter()
Expand Down Expand Up @@ -162,6 +200,8 @@ fn generate_spec_and_assertions(
let rewriting_result = match attr_kind {
SpecAttributeKind::Requires => generate_for_requires(attr_tokens, item),
SpecAttributeKind::Ensures => generate_for_ensures(attr_tokens, item),
SpecAttributeKind::AsyncEnsures => generate_for_async_ensures(attr_tokens, item),
SpecAttributeKind::AsyncInvariant => generate_for_async_invariant(attr_tokens, item),
SpecAttributeKind::AfterExpiry => generate_for_after_expiry(attr_tokens, item),
SpecAttributeKind::AssertOnExpiry => generate_for_assert_on_expiry(attr_tokens, item),
SpecAttributeKind::Pure => generate_for_pure(attr_tokens, item),
Expand Down Expand Up @@ -215,6 +255,55 @@ fn generate_for_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> Generat
))
}

/// Generate spec items and attributes to typecheck and later retrieve "ensures" annotations,
/// but for async fn, as their postconditions will need to be moved to their generator method
/// instead of the future-constructor.
fn generate_for_async_ensures(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult {
let mut rewriter = rewriter::AstRewriter::new();
// parse the postcondition
let expr = parse_prusti(attr)?;
// generate a spec item for the postcondition itself,
// which will be attached to the future's implementation
let spec_id = rewriter.generate_spec_id();
let spec_id_str = spec_id.to_string();
let spec_item =
rewriter.generate_spec_item_fn(rewriter::SpecItemType::Postcondition, spec_id, expr.clone(), item)?;

// and generate one wrapped in a `Poll` for the stub
let stub_spec_id = rewriter.generate_spec_id();
let stub_spec_id_str = stub_spec_id.to_string();
let stub_spec_item = rewriter.generate_async_ensures_item_fn(stub_spec_id, expr, item)?;

Ok((
vec![spec_item, stub_spec_item],
vec![
parse_quote_spanned! {item.span()=>
#[prusti::async_post_spec_id_ref = #spec_id_str]
},
parse_quote_spanned! {item.span()=>
#[prusti::async_stub_post_spec_id_ref = #stub_spec_id_str]
},
]
))
}


/// Generate spec items and attributes to typecheck and later retrieve invariant annotations
/// on async functions.
fn generate_for_async_invariant(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult {
let mut rewriter = rewriter::AstRewriter::new();
let spec_id = rewriter.generate_spec_id();
let spec_id_str = spec_id.to_string();
let spec_item =
rewriter.process_assertion(rewriter::SpecItemType::Precondition, spec_id, attr, item)?;
Ok((
vec![spec_item],
vec![parse_quote_spanned! {item.span()=>
#[prusti::async_inv_spec_id_ref = #spec_id_str]
}]
))
}

/// Generate spec items and attributes to typecheck and later retrieve "after_expiry" annotations.
fn generate_for_after_expiry(attr: TokenStream, item: &untyped::AnyFnItem) -> GeneratedResult {
let mut rewriter = rewriter::AstRewriter::new();
Expand Down Expand Up @@ -433,6 +522,146 @@ pub fn prusti_refutation(tokens: TokenStream) -> TokenStream {
generate_expression_closure(&AstRewriter::process_prusti_refutation, tokens)
}


pub fn suspension_point(tokens: TokenStream) -> TokenStream {
// parse the expression inside the suspension point,
// making sure it is just an await (possibly with attributes)
let expr: syn::Expr = handle_result!(syn::parse2(tokens));
let expr_span = expr.span();
let syn::Expr::Await(mut await_expr) = expr else {
return syn::Error::new(
expr.span(),
"suspension-point must contain a single await-expression"
).to_compile_error();
};
let future = &await_expr.base;

let mut label: Option<u32> = None;
let mut on_exit: Vec<TokenStream> = Vec::new();
let mut on_entry: Vec<TokenStream> = Vec::new();

// extract label, on-exit, and on-entry conditions from the attributes
for attr in await_expr.attrs {
let attr_span = attr.span();
if !matches!(attr.style, syn::AttrStyle::Outer) {
return syn::Error::new(
attr_span,
"only outer attributes allowed in suspension-point",
).to_compile_error();
}
if !(attr.path.segments.len() == 1
|| (attr.path.segments.len() == 2 && attr.path.segments[0].ident == "prusti_contracts"))
{
return syn::Error::new(
attr_span,
"invalid attribute in suspension-point",
).to_compile_error();
}
let name = attr.path.segments[attr.path.segments.len() - 1]
.ident
.to_string();
// labels
if name == "label" {
let [proc_macro2::TokenTree::Group(group)] =
&attr.tokens.into_iter().collect::<Vec<_>>()[..]
else {
return syn::Error::new(
attr_span,
"expected group with a single positive integer as label",
).to_compile_error();
};
let [proc_macro2::TokenTree::Literal(lit)] =
&group.stream().into_iter().collect::<Vec<_>>()[..]
else {
return syn::Error::new(
attr_span,
"expected single positive integer as label",
).to_compile_error();
};
let lbl_num: u32 = lit
.to_string()
.parse()
.expect("expected single positive integer as label");
if lbl_num == 0 {
return syn::Error::new(
attr_span,
"suspension-point label must be a positive integer",
).to_compile_error();
}
if label.replace(lbl_num).is_some() {
return syn::Error::new(
attr_span,
"multiple labels provided for suspension-point",
).to_compile_error();
}
// on-exit conditions
} else if name == "on_exit" {
on_exit.push(attr.tokens);
// on-entry conditions
} else if name == "on_entry" {
on_entry.push(attr.tokens);
// all other attributes are not permitted
} else {
// TODO: more precise error message with span?
panic!("invalid attribute in suspension-point");
}
}

let label = label.unwrap_or_else(|| panic!("suspension-point must have a label"));

// generate spec-items for each condition
let create_spec_item = |tokens: TokenStream| -> syn::Result<syn::ExprClosure> {
let expr = parse_prusti(tokens)?;
Ok(parse_quote_spanned! {expr_span=>
|| -> bool {
let val: bool = #expr;
val
}
})
};

let on_exit: syn::Result<Vec<syn::ExprClosure>> = on_exit
.into_iter()
.map(create_spec_item)
.collect();
let on_exit = handle_result!(on_exit);

let on_entry: syn::Result<Vec<syn::ExprClosure>> = on_entry
.into_iter()
.map(create_spec_item)
.collect();
let on_entry = handle_result!(on_entry);

// return just the await-expression
let on_exit_closures = match on_exit.len() {
1 => {
let on_exit = &on_exit[0];
quote_spanned! { expr_span => (#on_exit,) }
}
_ => quote_spanned! { expr_span => (#(#on_exit),*) },
};
let on_entry_closures = match on_entry.len() {
1 => {
let on_entry = &on_entry[0];
quote_spanned! { expr_span => (#on_entry,) }
}
_ => quote_spanned! { expr_span => (#(#on_entry),*) },
};

await_expr.attrs = Vec::new();
quote_spanned! { expr_span =>
{
let fut = #future;
#[allow(unused_parens)]
::prusti_contracts::suspension_point_on_exit_marker(#label, #on_exit_closures);
let res = fut.await;
#[allow(unused_parens)]
::prusti_contracts::suspension_point_on_entry_marker(#label, #on_entry_closures);
res
}
}
}

/// Generates the TokenStream encoding an expression using prusti syntax
/// Used for body invariants, assertions, and assumptions
fn generate_expression_closure(
Expand Down Expand Up @@ -855,6 +1084,8 @@ fn extract_prusti_attributes_for_types(
let tokens = match attr_kind {
SpecAttributeKind::Requires => unreachable!("requires on type"),
SpecAttributeKind::Ensures => unreachable!("ensures on type"),
SpecAttributeKind::AsyncEnsures => unreachable!("ensures on type"),
SpecAttributeKind::AsyncInvariant => unreachable!("async-invariant on type"),
SpecAttributeKind::AfterExpiry => unreachable!("after_expiry on type"),
SpecAttributeKind::AssertOnExpiry => unreachable!("assert_on_expiry on type"),
SpecAttributeKind::RefineSpec => unreachable!("refine_spec on type"),
Expand Down Expand Up @@ -900,6 +1131,8 @@ fn generate_spec_and_assertions_for_types(
let rewriting_result = match attr_kind {
SpecAttributeKind::Requires => unreachable!(),
SpecAttributeKind::Ensures => unreachable!(),
SpecAttributeKind::AsyncEnsures => unreachable!(),
SpecAttributeKind::AsyncInvariant => unreachable!(),
SpecAttributeKind::AfterExpiry => unreachable!(),
SpecAttributeKind::AssertOnExpiry => unreachable!(),
SpecAttributeKind::Pure => unreachable!(),
Expand Down
Loading