diff --git a/pulse/src/ml/pulseparser.mly b/pulse/src/ml/pulseparser.mly index 870df800a92..6e86b3f3f38 100644 --- a/pulse/src/ml/pulseparser.mly +++ b/pulse/src/ml/pulseparser.mly @@ -68,6 +68,48 @@ let add_decorations decors ds = | Inl p -> Inl (PulseSyntaxExtension_Sugar.add_decorations p decors) | Inr d -> Inr (FStarC_Parser_AST.add_decorations d decors)) ds +(* Build an F* AST term for a Pulse fn type expression. + fn (x:t1) (y:t2) requires pre ensures post + becomes: (x:t1) -> (y:t2) -> stt ret_type pre (fun ret_name -> post) +*) +let build_fn_type_term (binders : FStarC_Parser_AST.binder list list) (comp : PulseSyntaxExtension_Sugar.computation_type) r : FStarC_Parser_AST.term = + let open PulseSyntaxExtension_Sugar in + let flat_bs = List.flatten binders in + let annots = List.map fst comp.annots in + let star op t1 t2 = mk_term (Op (FStarC_Ident.mk_ident (op, r), [t1; t2])) r Un in + let star_join terms = + match terms with + | [] -> mk_term (Var (FStarC_Ident.lid_of_ids [FStarC_Ident.mk_ident("emp", r)])) r Un + | [t] -> t + | t :: rest -> List.fold_left (star "**") t rest + in + let requires = List.filter_map (function Requires t | Preserves t -> Some t | _ -> None) annots in + let ensures = List.filter_map (function Ensures t | Preserves t -> Some t | _ -> None) annots in + let ret_info = List.find_map (function Returns (id_opt, ty) -> Some (id_opt, ty) | _ -> None) annots in + let opens = List.find_map (function Opens t -> Some t | _ -> None) annots in + let ret_ty = match ret_info with Some (_, ty) -> ty | None -> mk_term (Var (FStarC_Ident.lid_of_ids [FStarC_Ident.mk_ident("unit", r)])) r Un in + let ret_name = match ret_info with Some (Some id, _) -> id | _ -> FStarC_Ident.mk_ident("_", r) in + let pre = star_join requires in + let post = star_join ensures in + let post_pat = mk_pattern (PatVar (ret_name, None, [])) r in + let post_lam = mk_term (Abs ([post_pat], post)) r Un in + let stt_name = match comp.tag with + | ST -> "stt" | STGhost -> "stt_ghost" + | STAtomic -> "stt_atomic" | STUnobservable -> "stt_unobservable" + in + let stt_var = mk_term (Var (FStarC_Ident.lid_of_ids [FStarC_Ident.mk_ident(stt_name, r)])) r Un in + let stt_app = match comp.tag with + | ST -> + mkApp stt_var [(ret_ty, Nothing); (pre, Nothing); (post_lam, Nothing)] r + | STGhost | STAtomic | STUnobservable -> + let inames = match opens with + | Some t -> t + | None -> mk_term (Var (FStarC_Ident.lid_of_ids [FStarC_Ident.mk_ident("emp_inames", r)])) r Un + in + mkApp stt_var [(ret_ty, Nothing); (inames, Nothing); (pre, Nothing); (post_lam, Nothing)] r + in + mk_term (Product (flat_bs, stt_app)) r Type_level + %} /* pulse specific tokens; rest are inherited from F* */ @@ -124,6 +166,18 @@ declBody: | p=pulseDecl { [Inl p] } | d=decoratableDecl { List.map (fun x -> Inr x) d } +(* Extend F*'s typeDefinition to support fn type abbreviations: + type name = fn comp *) +%public +typeDefinition: + | EQUALS q=qualOptFn fn_params=list(multiBinder) + ascription=pulseComputationType + { + let comp = with_computation_tag ascription q in + let body = build_fn_type_term fn_params comp (rr $loc) in + (fun id binders kopt -> TyconAbbrev(id, binders, kopt, body)) + } + pulseDecl: | q=qualOptFn (* workaround what seems to be a menhir bug *) isRec=maybeRec lid=lidentOrOperator us=univParams bs=pulseBinderList @@ -187,6 +241,27 @@ pulseDeclEOF: p } +(* fn type as a term: fn (x:int) requires p ensures q + We extend typ and simpleArrow to support fn type syntax in all type + positions. typ covers type definitions, val return types, record + fields, etc. simpleArrow is needed separately because binder + annotations (x : simpleArrow) do not go through typ. *) +%public +typ: + | q=qualOptFn bs=list(multiBinder) comp=pulseComputationType + { + let comp = with_computation_tag comp q in + build_fn_type_term bs comp (rr $loc) + } + +%public +simpleArrow: + | q=qualOptFn bs=list(multiBinder) comp=pulseComputationType + { + let comp = with_computation_tag comp q in + build_fn_type_term bs comp (rr $loc) + } + pulseBinderList: | bs=list(multiBinder) { bs } diff --git a/pulse/test/PulseFnTerms.fst b/pulse/test/PulseFnTerms.fst new file mode 100644 index 00000000000..7ffd9676ba7 --- /dev/null +++ b/pulse/test/PulseFnTerms.fst @@ -0,0 +1,278 @@ +module PulseFnTerms +open Pulse.Nolib +#lang-pulse + +(* ================================================================ + Tests for fn-type-as-term syntax. + + The fn type can appear in type positions: binder annotations + and type abbreviation RHS. Syntax: + + fn (x:t) ... requires pre ensures post + ghost fn (x:t) ... requires pre ensures post + atomic fn (x:t) ... requires pre ensures post + + This reuses F*'s existing (x : type) binder rule by extending + simpleArrow with fn-type productions. + ================================================================ *) + +(* -------------------------------------------------------------- *) +(* 1. Type abbreviation — fn type as RHS *) +(* -------------------------------------------------------------- *) + +type unit_action = fn (_u : unit) requires emp ensures emp + +(* -------------------------------------------------------------- *) +(* 2. Type abbreviation with meaningful binders *) +(* -------------------------------------------------------------- *) + +assume val my_res : nat -> slprop + +type my_action = fn (v : nat) requires emp ensures my_res v + +(* -------------------------------------------------------------- *) +(* 3. fn type in binder position — higher-order function *) +(* -------------------------------------------------------------- *) + +fn call_unit_action (f : fn (_u : unit) requires emp ensures emp) + requires emp + ensures emp +{ + f () +} + +fn call_unit_action' (f : unit_action) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 4. fn type binder with parameters *) +(* -------------------------------------------------------------- *) + +fn apply_action (f : fn (v : nat) requires emp ensures my_res v) + requires emp + ensures my_res 42 +{ + f 42 +} + +(* -------------------------------------------------------------- *) +(* 5. Multiple fn-typed binders *) +(* -------------------------------------------------------------- *) + +fn compose_actions + (f : fn (_u : unit) requires emp ensures emp) + (g : fn (_u : unit) requires emp ensures emp) + requires emp + ensures emp +{ + f (); + g () +} + +(* -------------------------------------------------------------- *) +(* 6. ghost fn type in binder and type abbreviation *) +(* -------------------------------------------------------------- *) + +type ghost_unit_action = ghost fn (_u : unit) requires emp ensures emp + +fn call_ghost (f : ghost fn (_u : unit) requires emp ensures emp) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 7. Using a type abbreviation as a binder type *) +(* -------------------------------------------------------------- *) + +fn call_via_abbrev (f : unit_action) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 8. fn type with preserves annotation *) +(* -------------------------------------------------------------- *) + +type preserving_action (p : slprop) = fn (_u : unit) preserves p + +fn call_preserving (p : slprop) (f : fn (_u : unit) preserves p) + requires p + ensures p +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 9. fn type with multiple binders *) +(* -------------------------------------------------------------- *) + +fn apply_two_arg (f : fn (x : nat) (y : nat) requires emp ensures my_res (x + y)) + requires emp + ensures my_res 42 +{ + f 20 22 +} + +(* -------------------------------------------------------------- *) +(* 10. atomic fn type — abbreviation and binder *) +(* -------------------------------------------------------------- *) + +type atomic_unit_action = atomic fn (_u : unit) requires emp ensures emp + +fn call_atomic (f : atomic fn (_u : unit) requires emp ensures emp) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 11. atomic fn type with opens *) +(* -------------------------------------------------------------- *) + +type atomic_action_opens (is : inames) = + atomic fn (_u : unit) opens is requires emp ensures emp + +fn call_atomic_opens + (is : inames) + (f : atomic fn (_u : unit) opens is requires emp ensures emp) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 12. ghost fn type with opens *) +(* -------------------------------------------------------------- *) + +type ghost_action_opens (is : inames) = + ghost fn (_u : unit) opens is requires emp ensures emp + +fn call_ghost_opens + (is : inames) + (f : ghost fn (_u : unit) opens is requires emp ensures emp) + requires emp + ensures emp +{ + f () +} + +(* -------------------------------------------------------------- *) +(* 13. fn type with returns annotation *) +(* -------------------------------------------------------------- *) + +type returning_action = fn (_u : unit) requires emp returns v : nat ensures my_res v + +fn call_returning (f : fn (_u : unit) requires emp returns v : nat ensures my_res v) + requires emp + ensures (exists* v. my_res v) +{ + let v = f (); + () +} + +(* -------------------------------------------------------------- *) +(* 14. atomic fn with opens and returns *) +(* -------------------------------------------------------------- *) + +type atomic_returning (is : inames) = + atomic fn (_u : unit) opens is requires emp returns v : nat ensures my_res v + +fn call_atomic_returning + (is : inames) + (f : atomic fn (_u : unit) opens is requires emp returns v : nat ensures my_res v) + requires emp + ensures (exists* v. my_res v) +{ + let v = f (); + () +} + +atomic +fn call_atomic_returning' + (is : inames) + (f : atomic fn (_u : unit) opens is requires emp returns v : nat ensures my_res v) + requires emp + ensures (exists* v. my_res v) + opens is +{ + let v = f (); + () +} + +(* -------------------------------------------------------------- *) +(* 15. bare fn type — no computation annotations *) +(* -------------------------------------------------------------- *) + +// This is perhaps weird, but it returns unit with emp pre/post +type bare_fn_type = fn (x : nat) (y : nat) + +fn call_bare (f : fn (x : nat)) + requires emp + ensures emp +{ + f 1; + f 2; +} + +// Perhaps even weirder, this is exactly `stt unit emp (fun _ -> emp)` +type x = fn + +(* -------------------------------------------------------------- *) +(* 16. fn type in record field position *) +(* -------------------------------------------------------------- *) + +noeq +type record_with_fn = { + action: fn (_u : unit) requires emp ensures emp; + getter: fn (_u : unit) requires emp returns v : nat ensures my_res v; +} + +fn call_record_action (r : record_with_fn) + requires emp + ensures emp +{ + let f = r.action; + f () +} + +fn call_record_getter (r : record_with_fn) + requires emp + ensures (exists* v. my_res v) +{ + let f = r.getter; + let v = f (); + () +} + +(* -------------------------------------------------------------- *) +(* 17. fn type in val declaration *) +(* -------------------------------------------------------------- *) + +assume val val_unit_action : fn (_u : unit) requires emp ensures emp + +assume val val_returning : fn (_u : unit) requires emp returns v : nat ensures my_res v + +fn use_val_unit_action (_u : unit) + requires emp + ensures emp +{ + val_unit_action () +} + +fn use_val_returning (_u : unit) + requires emp + ensures (exists* v. my_res v) +{ + let v = val_returning (); + () +} diff --git a/src/ml/FStarC_Parser_Parse.mly b/src/ml/FStarC_Parser_Parse.mly index e78fc7b3279..27516250ecb 100644 --- a/src/ml/FStarC_Parser_Parse.mly +++ b/src/ml/FStarC_Parser_Parse.mly @@ -519,6 +519,7 @@ typars: | LBRACE record_field_decls=right_flexible_nonempty_list(SEMICOLON, recordFieldDecl) RBRACE { record_field_decls } +%public typeDefinition: | { (fun id binders kopt -> check_id id; TyconAbstract(id, binders, kopt)) } | EQUALS t=typ @@ -890,6 +891,7 @@ path(Id): | id=Id { [id] } | uid=uident DOT p=path(Id) { uid::p } +%public ident: | x=lident { x } | x=uident { x } @@ -1201,7 +1203,7 @@ calcStep: CalcStep (rel, justif, next) } -%inline +%public typ: | t=simpleTerm { t } @@ -1279,6 +1281,7 @@ tmArrow(Tm): } | e=Tm { e } +%public simpleArrow: | dom=simpleArrowDomain RARROW tgt=simpleArrow {