@@ -12,17 +12,21 @@ let type_mismatch_error loc exp_ty fnd_ty =
1212 \ %{Type}"
1313 exp_ty fnd_ty)
1414
15- let arguments_to_string d =
16- if d = 1 then " one argument " else Printf. sprintf " %d arguments " d
15+ let number_to_string kind d =
16+ if d = 1 then Printf. sprintf " one %s " kind else Printf. sprintf " %d %ss " d kind
1717
1818let tuple_arg_mismatch_error loc expected =
1919 Error. type_error loc
20- (Printf. sprintf " Expected tuple with %d components " expected)
20+ (Printf. sprintf " Expected tuple with %s " (number_to_string " component " expected) )
2121
22- let module_arg_mismatch_error loc typ_constr expected =
22+ let arg_mismatch_error kind loc typ_constr expected =
2323 Error. type_error loc
24- (Printf. sprintf " Module %s expects %s" (Type. to_name typ_constr)
25- (arguments_to_string expected))
24+ (Printf. sprintf " %s %s expects %s" kind (Type. to_name typ_constr)
25+ (number_to_string " argument" expected))
26+
27+ let param_mismatch_error kind loc id expected =
28+ Error. type_error loc
29+ (Printf. sprintf " %s %s expects %s" kind id (number_to_string " parameter" expected))
2630
2731let unexpected_functor_error loc =
2832 Error. type_error loc " A functor cannot be instantiated in this context"
@@ -60,13 +64,13 @@ module ProcessTypeExpr = struct
6064 | [ tp_arg ] ->
6165 let + tp_arg' = process_type_expr tp_arg in
6266 App (constr, [ tp_arg' ], tp_attr)
63- | _ -> module_arg_mismatch_error (Type. to_loc tp_expr) constr 1 )
67+ | _ -> arg_mismatch_error " Constructor " (Type. to_loc tp_expr) constr 1 )
6468 | App (Map, tp_list , tp_attr ) -> (
6569 match tp_list with
6670 | [ tp1; tp2 ] ->
6771 let + tp1 = process_type_expr tp1 and + tp2 = process_type_expr tp2 in
6872 App (Map , [ tp1; tp2 ], tp_attr)
69- | _ -> module_arg_mismatch_error (Type. to_loc tp_expr) Map 2 )
73+ | _ -> arg_mismatch_error " Type " (Type. to_loc tp_expr) Map 2 )
7074 | App (Data _ , _tp_list , _tp_attr ) ->
7175 (* The parser should prevent this from happening. *)
7276 Error. internal_error (Type. to_loc tp_expr)
@@ -2306,20 +2310,32 @@ module ProcessModule = struct
23062310 (Symbol. kind orig_symbol) ident interface_ident
23072311 (Symbol. kind symbol))
23082312
2313+ (* * Check that module `mod_ident` (M) implements interface `int_ident` (I) *)
23092314 let check_module_type mod_ident int_ident =
23102315 let open Rewriter.Syntax in
2316+ (* Get qualified idents and symbols of M and I *)
23112317 let + qual_mod_ident, mod_symbol =
23122318 Rewriter. resolve_and_find mod_ident
2313- and + qual_int_ident, _int_symbol =
2319+ and + qual_int_ident, int_symbol =
23142320 Rewriter. resolve_and_find int_ident
23152321 in
2316- let interfaces =
2317- Rewriter.Symbol. extract mod_symbol ~f: (fun _ _subst -> function
2322+ (* Extract all interfaces implemented by M and check whether it is fully instantiated *)
2323+ let interfaces, mod_is_instance =
2324+ Rewriter.Symbol. extract mod_symbol ~f: (fun is_instance subst -> function
23182325 | Ast.Module. ModDef mod_def ->
23192326 (* Set.map (module QualIdent) mod_def.mod_decl.mod_decl_interfaces ~f:subst*)
2320- mod_def.mod_decl.mod_decl_interfaces
2321- | _ -> Set. empty (module QualIdent ))
2327+ mod_def.mod_decl.mod_decl_interfaces,
2328+ List. is_empty mod_def.mod_decl.mod_decl_formals || is_instance
2329+ | _ -> Set. empty (module QualIdent ), true )
23222330 in
2331+ (* Check whether I is fully instantiated *)
2332+ let int_is_instance =
2333+ Rewriter.Symbol. extract int_symbol ~f: (fun is_instance _subst -> function
2334+ | Ast.Module. ModDef mod_def ->
2335+ List. is_empty mod_def.mod_decl.mod_decl_formals || is_instance
2336+ | _ -> true )
2337+ in
2338+ (* Check if I is one of M's interfaces *)
23232339 if
23242340 not
23252341 (QualIdent. (qual_mod_ident = qual_int_ident)
@@ -2331,6 +2347,18 @@ module ProcessModule = struct
23312347 ! " %s %{QualIdent} does not implement interface %{QualIdent}"
23322348 (Symbol. kind (Rewriter.Symbol. orig_symbol mod_symbol) |> String. capitalize)
23332349 mod_ident int_ident)
2350+ else if
2351+ (* Make sure that I is the type of M itself rather than the expected type
2352+ of the module obtained by instantiating *)
2353+ int_is_instance && not mod_is_instance
2354+ then
2355+ Error. type_error
2356+ (QualIdent. to_loc mod_ident)
2357+ (Printf. sprintf
2358+ ! " %s %{QualIdent} first needs to be instantiated to obtain a module with interface %{QualIdent}"
2359+ (Symbol. kind (Rewriter.Symbol. orig_symbol mod_symbol) |> String. capitalize)
2360+ mod_ident int_ident)
2361+
23342362
23352363 let rec process_module (m : Module.t ) : Module.t Rewriter.t =
23362364 let open Rewriter.Syntax in
@@ -2370,40 +2398,44 @@ module ProcessModule = struct
23702398 let + mod_def = Rewriter. exit_module mod_def in
23712399 Module. ModDef mod_def
23722400 | ModInst mod_inst ->
2401+ (* A functor application `module M : I = F[args]` *)
2402+ (* Get symbol of I *)
23732403 let * mod_inst_type =
23742404 Rewriter. resolve mod_inst.mod_inst_type
23752405 in
23762406 let symbol = Module. ModInst { mod_inst with mod_inst_type } in
2407+ (* Pair up formal parameters of F with arguments `args` *)
23772408 let * to_check =
23782409 Rewriter.Option. map mod_inst.mod_inst_def
23792410 ~f: (fun (mod_inst_func , mod_inst_args ) ->
23802411 let * _ = Rewriter. declare_symbol symbol in
2412+ (* Get qualified name of F and its symbol *)
23812413 let + qual_functor_ident, functor_symbol =
23822414 Rewriter. resolve_and_find
23832415 mod_inst_func
23842416 in
2417+ (* Get formal parameters of F *)
23852418 let formals =
23862419 Rewriter.Symbol. extract functor_symbol ~f: (fun is_instance subst ->
23872420 function
23882421 | Ast.Module. ModDef mod_def when not is_instance ->
2389- Logs. info (fun m -> m ! " %{QualIdent}" mod_inst_func);
23902422 List. map mod_def.mod_decl.mod_decl_formals
23912423 ~f: (fun mod_inst ->
23922424 subst mod_inst.mod_inst_type)
23932425 | _ -> [] )
23942426 in
2427+ (* Pair up `args` and formals *)
23952428 let args_and_formals =
23962429 match List. zip mod_inst_args formals with
23972430 | Ok res -> res
23982431 | Unequal_lengths ->
2399- Error. type_error (* mod_inst.mod_inst_loc*) (QualIdent. to_loc mod_inst_func)
2400- (Printf. sprintf
2401- ! " Module %{QualIdent} expects %d arguments"
2402- mod_inst_func (List. length formals))
2432+ arg_mismatch_error " Module" (QualIdent. to_loc mod_inst_func) (Type. Var mod_inst_func)
2433+ (List. length formals)
24032434 in
24042435 (qual_functor_ident, mod_inst.mod_inst_type) :: args_and_formals)
24052436 in
24062437 let to_check = Option. value to_check ~default: [] in
2438+ (* Check that `args` satisfy module types of formals *)
24072439 let + _ =
24082440 Rewriter.List. iter to_check ~f: (fun (m , i ) ->
24092441 check_module_type m i)
@@ -2496,11 +2528,11 @@ module ProcessModule = struct
24962528 && (Set. is_empty seen || List. is_empty mod_def)
24972529 then
24982530 (* case: parent_symbol should be inherited now *)
2499- let _ = Logs. info (fun m -> m ! " Inheriting symbol %{Ident}" parent_symbol_ident) in
2531+ let _ = Logs. debug (fun m -> m ! " Inheriting symbol %{Ident}" parent_symbol_ident) in
25002532 let parent_symbol_def =
25012533 match parent_symbol.symbol_def with
25022534 | CallDef call when not @@ Callable. is_abstract call ->
2503- Logs. info (fun m -> m ! " Making %{Ident} free." (Callable. to_ident call));
2535+ Logs. debug (fun m -> m ! " Making %{Ident} free." (Callable. to_ident call));
25042536 Module. CallDef (Callable. make_free call)
25052537 | CallDef
25062538 ({ call_decl = { call_decl_kind = Lemma ; _ }; _ } as call)
@@ -2580,6 +2612,7 @@ module ProcessModule = struct
25802612 let * ( mod_decl_returns,
25812613 mod_decl_interfaces,
25822614 interface_ident,
2615+ interface_formals,
25832616 (merged_symbols, symbols_to_check) ) =
25842617 let + interface_opt =
25852618 Rewriter.Option. map m.mod_decl.mod_decl_returns ~f: (fun mid ->
@@ -2611,23 +2644,24 @@ module ProcessModule = struct
26112644 in
26122645 match interface_opt with
26132646 | Some (qual_interface_ident , interface_ident , ModDef interface ) ->
2614- ( Some qual_interface_ident,
2615- Set. add interface.mod_decl.mod_decl_interfaces qual_interface_ident,
2616- interface_ident,
2617- merge_defs qual_interface_ident interface.mod_def m.mod_def )
2647+ ( Some qual_interface_ident,
2648+ Set. add interface.mod_decl.mod_decl_interfaces qual_interface_ident,
2649+ interface_ident,
2650+ Some interface.mod_decl.mod_decl_formals,
2651+ merge_defs qual_interface_ident interface.mod_def m.mod_def )
26182652 | _ ->
26192653 let mod_ident = QualIdent. from_ident m.mod_decl.mod_decl_name in
26202654 let interfaces =
26212655 if is_root then m.mod_decl.mod_decl_interfaces
26222656 else Set. add m.mod_decl.mod_decl_interfaces mod_qual_ident
26232657 in
2624- (None , interfaces, mod_ident, (m.mod_def, Map. empty (module Ident )))
2658+ (None , interfaces, mod_ident, None , (m.mod_def, Map. empty (module Ident )))
26252659 in
26262660
26272661 (* let inherited_symbols = List.rev inherited_symbols in*)
26282662 let mod_def = mod_def_formals @ merged_symbols in
2629- let _ = Logs. info (fun mm -> mm ! " Merged in %{Ident}" (Symbol. to_name (ModDef m))) in
2630- let _ = List. iter ~f: (function SymbolDef symbol -> Logs. info (fun m -> m ! " %{Ident}" (Symbol. to_name symbol.symbol_def)) | _ -> () ) mod_def in
2663+ let _ = Logs. debug (fun mm -> mm ! " Merged in %{Ident}" (Symbol. to_name (ModDef m))) in
2664+ let _ = List. iter ~f: (function SymbolDef symbol -> Logs. debug (fun m -> m ! " %{Ident}" (Symbol. to_name symbol.symbol_def)) | _ -> () ) mod_def in
26312665 (* Find rep type and add it to module declaration *)
26322666 let mod_decl_rep =
26332667 List. fold_left mod_def ~init: None ~f: (fun rep_type -> function
@@ -2686,6 +2720,33 @@ module ProcessModule = struct
26862720 }
26872721 in
26882722
2723+ (* Make sure that the module preserves the parameters of its interface *)
2724+ let _ =
2725+ let interface_formals = Option. value interface_formals ~default: [] in
2726+ match interface_formals with
2727+ | [] -> ()
2728+ | _ ->
2729+ let res =
2730+ List. iter2 mod_decl.mod_decl_formals interface_formals
2731+ ~f: (fun param oparam ->
2732+ if
2733+ Ident. (param.mod_inst_name <> oparam.mod_inst_name)
2734+ || QualIdent. (param.mod_inst_type <> oparam.mod_inst_type)
2735+ then
2736+ Error. type_error param.mod_inst_loc
2737+ (Printf. sprintf
2738+ ! " Parameter %{Ident} of %s %{Ident} does not match declaration of \
2739+ parameter %{Ident} of interface %{QualIdent}"
2740+ param.mod_inst_name (Symbol. kind (ModDef m)) mod_decl.mod_decl_name
2741+ oparam.mod_inst_name interface_ident))
2742+ in
2743+ match res with
2744+ | Ok list -> list
2745+ | Unequal_lengths ->
2746+ param_mismatch_error " Interface" (Ident. to_loc mod_decl.mod_decl_name)
2747+ (QualIdent. to_string interface_ident) (List. length interface_formals)
2748+ in
2749+
26892750 let * _ =
26902751 Rewriter.List. iter mod_def ~f: (function
26912752 | Module. SymbolDef {symbol_def = ModInst { mod_inst_def = Some _ ; _ } ; _} | Module. Import _ -> Rewriter. return ()
@@ -2695,7 +2756,7 @@ module ProcessModule = struct
26952756 (* Check and rewrite all symbols *)
26962757 let * mod_def = Rewriter.List. map merged_symbols ~f: process_instr in
26972758
2698- (* Check symbols against interface *)
2759+ (* Check symbols against what is specified in the interface *)
26992760 let * _ =
27002761 Rewriter.List. iter mod_def ~f: (function
27012762 | SymbolDef symbol ->
0 commit comments