Skip to content

Commit f84cfb5

Browse files
committed
fix in module type-checking
1 parent 1c90a79 commit f84cfb5

File tree

3 files changed

+97
-30
lines changed

3 files changed

+97
-30
lines changed

CHANGELOG

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ All notable changes to this project will be docmented in this file.
44

55
## [Unreleased]
66

7+
### Fixed
78

9+
- Fixed some loop holes when type-checking modules against interfaces.
10+
811
## [1.1.0] - 2026-01-27
912

1013
### Added

lib/frontend/typing.ml

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ let type_mismatch_error loc exp_ty fnd_ty =
1212
\ %{Type}"
1313
exp_ty fnd_ty)
1414

15-
let arguments_to_string d =
16-
if d = 1 then "one argument" else Printf.sprintf "%d arguments" d
15+
let number_to_string kind d =
16+
if d = 1 then Printf.sprintf "one %s" kind else Printf.sprintf "%d %ss" d kind
1717

1818
let tuple_arg_mismatch_error loc expected =
1919
Error.type_error loc
20-
(Printf.sprintf "Expected tuple with %d components" expected)
20+
(Printf.sprintf "Expected tuple with %s" (number_to_string "component" expected))
2121

22-
let module_arg_mismatch_error loc typ_constr expected =
22+
let arg_mismatch_error kind loc typ_constr expected =
2323
Error.type_error loc
24-
(Printf.sprintf "Module %s expects %s" (Type.to_name typ_constr)
25-
(arguments_to_string expected))
24+
(Printf.sprintf "%s %s expects %s" kind (Type.to_name typ_constr)
25+
(number_to_string "argument" expected))
26+
27+
let param_mismatch_error kind loc id expected =
28+
Error.type_error loc
29+
(Printf.sprintf "%s %s expects %s" kind id (number_to_string "parameter" expected))
2630

2731
let unexpected_functor_error loc =
2832
Error.type_error loc "A functor cannot be instantiated in this context"
@@ -60,13 +64,13 @@ module ProcessTypeExpr = struct
6064
| [ tp_arg ] ->
6165
let+ tp_arg' = process_type_expr tp_arg in
6266
App (constr, [ tp_arg' ], tp_attr)
63-
| _ -> module_arg_mismatch_error (Type.to_loc tp_expr) constr 1)
67+
| _ -> arg_mismatch_error "Constructor" (Type.to_loc tp_expr) constr 1)
6468
| App (Map, tp_list, tp_attr) -> (
6569
match tp_list with
6670
| [ tp1; tp2 ] ->
6771
let+ tp1 = process_type_expr tp1 and+ tp2 = process_type_expr tp2 in
6872
App (Map, [ tp1; tp2 ], tp_attr)
69-
| _ -> module_arg_mismatch_error (Type.to_loc tp_expr) Map 2)
73+
| _ -> arg_mismatch_error "Type" (Type.to_loc tp_expr) Map 2)
7074
| App (Data _, _tp_list, _tp_attr) ->
7175
(* The parser should prevent this from happening. *)
7276
Error.internal_error (Type.to_loc tp_expr)
@@ -2306,20 +2310,32 @@ module ProcessModule = struct
23062310
(Symbol.kind orig_symbol) ident interface_ident
23072311
(Symbol.kind symbol))
23082312

2313+
(** Check that module `mod_ident` (M) implements interface `int_ident` (I) *)
23092314
let check_module_type mod_ident int_ident =
23102315
let open Rewriter.Syntax in
2316+
(* Get qualified idents and symbols of M and I *)
23112317
let+ qual_mod_ident, mod_symbol =
23122318
Rewriter.resolve_and_find mod_ident
2313-
and+ qual_int_ident, _int_symbol =
2319+
and+ qual_int_ident, int_symbol =
23142320
Rewriter.resolve_and_find int_ident
23152321
in
2316-
let interfaces =
2317-
Rewriter.Symbol.extract mod_symbol ~f:(fun _ _subst -> function
2322+
(* Extract all interfaces implemented by M and check whether it is fully instantiated *)
2323+
let interfaces, mod_is_instance =
2324+
Rewriter.Symbol.extract mod_symbol ~f:(fun is_instance subst -> function
23182325
| Ast.Module.ModDef mod_def ->
23192326
(*Set.map (module QualIdent) mod_def.mod_decl.mod_decl_interfaces ~f:subst*)
2320-
mod_def.mod_decl.mod_decl_interfaces
2321-
| _ -> Set.empty (module QualIdent))
2327+
mod_def.mod_decl.mod_decl_interfaces,
2328+
List.is_empty mod_def.mod_decl.mod_decl_formals || is_instance
2329+
| _ -> Set.empty (module QualIdent), true)
23222330
in
2331+
(* Check whether I is fully instantiated *)
2332+
let int_is_instance =
2333+
Rewriter.Symbol.extract int_symbol ~f:(fun is_instance _subst -> function
2334+
| Ast.Module.ModDef mod_def ->
2335+
List.is_empty mod_def.mod_decl.mod_decl_formals || is_instance
2336+
| _ -> true)
2337+
in
2338+
(* Check if I is one of M's interfaces *)
23232339
if
23242340
not
23252341
(QualIdent.(qual_mod_ident = qual_int_ident)
@@ -2331,6 +2347,18 @@ module ProcessModule = struct
23312347
!"%s %{QualIdent} does not implement interface %{QualIdent}"
23322348
(Symbol.kind (Rewriter.Symbol.orig_symbol mod_symbol) |> String.capitalize)
23332349
mod_ident int_ident)
2350+
else if
2351+
(* Make sure that I is the type of M itself rather than the expected type
2352+
of the module obtained by instantiating *)
2353+
int_is_instance && not mod_is_instance
2354+
then
2355+
Error.type_error
2356+
(QualIdent.to_loc mod_ident)
2357+
(Printf.sprintf
2358+
!"%s %{QualIdent} first needs to be instantiated to obtain a module with interface %{QualIdent}"
2359+
(Symbol.kind (Rewriter.Symbol.orig_symbol mod_symbol) |> String.capitalize)
2360+
mod_ident int_ident)
2361+
23342362

23352363
let rec process_module (m : Module.t) : Module.t Rewriter.t =
23362364
let open Rewriter.Syntax in
@@ -2370,40 +2398,44 @@ module ProcessModule = struct
23702398
let+ mod_def = Rewriter.exit_module mod_def in
23712399
Module.ModDef mod_def
23722400
| ModInst mod_inst ->
2401+
(* A functor application `module M : I = F[args]` *)
2402+
(* Get symbol of I *)
23732403
let* mod_inst_type =
23742404
Rewriter.resolve mod_inst.mod_inst_type
23752405
in
23762406
let symbol = Module.ModInst { mod_inst with mod_inst_type } in
2407+
(* Pair up formal parameters of F with arguments `args` *)
23772408
let* to_check =
23782409
Rewriter.Option.map mod_inst.mod_inst_def
23792410
~f:(fun (mod_inst_func, mod_inst_args) ->
23802411
let* _ = Rewriter.declare_symbol symbol in
2412+
(* Get qualified name of F and its symbol *)
23812413
let+ qual_functor_ident, functor_symbol =
23822414
Rewriter.resolve_and_find
23832415
mod_inst_func
23842416
in
2417+
(* Get formal parameters of F *)
23852418
let formals =
23862419
Rewriter.Symbol.extract functor_symbol ~f:(fun is_instance subst ->
23872420
function
23882421
| Ast.Module.ModDef mod_def when not is_instance ->
2389-
Logs.info (fun m -> m !"%{QualIdent}" mod_inst_func);
23902422
List.map mod_def.mod_decl.mod_decl_formals
23912423
~f:(fun mod_inst ->
23922424
subst mod_inst.mod_inst_type)
23932425
| _ -> [])
23942426
in
2427+
(* Pair up `args` and formals *)
23952428
let args_and_formals =
23962429
match List.zip mod_inst_args formals with
23972430
| Ok res -> res
23982431
| Unequal_lengths ->
2399-
Error.type_error (*mod_inst.mod_inst_loc*) (QualIdent.to_loc mod_inst_func)
2400-
(Printf.sprintf
2401-
!"Module %{QualIdent} expects %d arguments"
2402-
mod_inst_func (List.length formals))
2432+
arg_mismatch_error "Module" (QualIdent.to_loc mod_inst_func) (Type.Var mod_inst_func)
2433+
(List.length formals)
24032434
in
24042435
(qual_functor_ident, mod_inst.mod_inst_type) :: args_and_formals)
24052436
in
24062437
let to_check = Option.value to_check ~default:[] in
2438+
(* Check that `args` satisfy module types of formals *)
24072439
let+ _ =
24082440
Rewriter.List.iter to_check ~f:(fun (m, i) ->
24092441
check_module_type m i)
@@ -2496,11 +2528,11 @@ module ProcessModule = struct
24962528
&& (Set.is_empty seen || List.is_empty mod_def)
24972529
then
24982530
(* case: parent_symbol should be inherited now *)
2499-
let _ = Logs.info (fun m -> m !"Inheriting symbol %{Ident}" parent_symbol_ident) in
2531+
let _ = Logs.debug (fun m -> m !"Inheriting symbol %{Ident}" parent_symbol_ident) in
25002532
let parent_symbol_def =
25012533
match parent_symbol.symbol_def with
25022534
| CallDef call when not @@ Callable.is_abstract call ->
2503-
Logs.info (fun m -> m !"Making %{Ident} free." (Callable.to_ident call));
2535+
Logs.debug (fun m -> m !"Making %{Ident} free." (Callable.to_ident call));
25042536
Module.CallDef (Callable.make_free call)
25052537
| CallDef
25062538
({ call_decl = { call_decl_kind = Lemma; _ }; _ } as call)
@@ -2580,6 +2612,7 @@ module ProcessModule = struct
25802612
let* ( mod_decl_returns,
25812613
mod_decl_interfaces,
25822614
interface_ident,
2615+
interface_formals,
25832616
(merged_symbols, symbols_to_check) ) =
25842617
let+ interface_opt =
25852618
Rewriter.Option.map m.mod_decl.mod_decl_returns ~f:(fun mid ->
@@ -2611,23 +2644,24 @@ module ProcessModule = struct
26112644
in
26122645
match interface_opt with
26132646
| Some (qual_interface_ident, interface_ident, ModDef interface) ->
2614-
( Some qual_interface_ident,
2615-
Set.add interface.mod_decl.mod_decl_interfaces qual_interface_ident,
2616-
interface_ident,
2617-
merge_defs qual_interface_ident interface.mod_def m.mod_def )
2647+
( Some qual_interface_ident,
2648+
Set.add interface.mod_decl.mod_decl_interfaces qual_interface_ident,
2649+
interface_ident,
2650+
Some interface.mod_decl.mod_decl_formals,
2651+
merge_defs qual_interface_ident interface.mod_def m.mod_def )
26182652
| _ ->
26192653
let mod_ident = QualIdent.from_ident m.mod_decl.mod_decl_name in
26202654
let interfaces =
26212655
if is_root then m.mod_decl.mod_decl_interfaces
26222656
else Set.add m.mod_decl.mod_decl_interfaces mod_qual_ident
26232657
in
2624-
(None, interfaces, mod_ident, (m.mod_def, Map.empty (module Ident)))
2658+
(None, interfaces, mod_ident, None, (m.mod_def, Map.empty (module Ident)))
26252659
in
26262660

26272661
(*let inherited_symbols = List.rev inherited_symbols in*)
26282662
let mod_def = mod_def_formals @ merged_symbols in
2629-
let _ = Logs.info (fun mm -> mm !"Merged in %{Ident}" (Symbol.to_name (ModDef m))) in
2630-
let _ = List.iter ~f:(function SymbolDef symbol -> Logs.info (fun m -> m !"%{Ident}" (Symbol.to_name symbol.symbol_def)) | _ -> ()) mod_def in
2663+
let _ = Logs.debug (fun mm -> mm !"Merged in %{Ident}" (Symbol.to_name (ModDef m))) in
2664+
let _ = List.iter ~f:(function SymbolDef symbol -> Logs.debug (fun m -> m !"%{Ident}" (Symbol.to_name symbol.symbol_def)) | _ -> ()) mod_def in
26312665
(* Find rep type and add it to module declaration *)
26322666
let mod_decl_rep =
26332667
List.fold_left mod_def ~init:None ~f:(fun rep_type -> function
@@ -2686,6 +2720,33 @@ module ProcessModule = struct
26862720
}
26872721
in
26882722

2723+
(* Make sure that the module preserves the parameters of its interface *)
2724+
let _ =
2725+
let interface_formals = Option.value interface_formals ~default:[] in
2726+
match interface_formals with
2727+
| [] -> ()
2728+
| _ ->
2729+
let res =
2730+
List.iter2 mod_decl.mod_decl_formals interface_formals
2731+
~f:(fun param oparam ->
2732+
if
2733+
Ident.(param.mod_inst_name <> oparam.mod_inst_name)
2734+
|| QualIdent.(param.mod_inst_type <> oparam.mod_inst_type)
2735+
then
2736+
Error.type_error param.mod_inst_loc
2737+
(Printf.sprintf
2738+
!"Parameter %{Ident} of %s %{Ident} does not match declaration of \
2739+
parameter %{Ident} of interface %{QualIdent}"
2740+
param.mod_inst_name (Symbol.kind (ModDef m)) mod_decl.mod_decl_name
2741+
oparam.mod_inst_name interface_ident))
2742+
in
2743+
match res with
2744+
| Ok list -> list
2745+
| Unequal_lengths ->
2746+
param_mismatch_error "Interface" (Ident.to_loc mod_decl.mod_decl_name)
2747+
(QualIdent.to_string interface_ident) (List.length interface_formals)
2748+
in
2749+
26892750
let* _ =
26902751
Rewriter.List.iter mod_def ~f:(function
26912752
| Module.SymbolDef {symbol_def = ModInst { mod_inst_def = Some _; _ }; _} | Module.Import _ -> Rewriter.return ()
@@ -2695,7 +2756,7 @@ module ProcessModule = struct
26952756
(* Check and rewrite all symbols *)
26962757
let* mod_def = Rewriter.List.map merged_symbols ~f:process_instr in
26972758

2698-
(* Check symbols against interface *)
2759+
(* Check symbols against what is specified in the interface *)
26992760
let* _ =
27002761
Rewriter.List.iter mod_def ~f:(function
27012762
| SymbolDef symbol ->

test/arrays/array_utils.rav

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// This file is just the concatenation of array.rav and ordered_array.rav, for benchmarking purposes
22

3-
interface Array[E: Library.Type] {
3+
interface Array {
4+
module E: Library.Type
45
rep type T
56

67
func loc(a: T, i: Int) returns (r: Ref)
@@ -227,7 +228,9 @@ interface OrderedArray[E: Library.OrderedType] : Array {
227228
i <= j && sorted_map_seg(m, i, j) ==>
228229
k !in set_of_map(m, i, map_find(m, i, j, k)) &&
229230
k !in set_of_map(m, map_find(m, i, j, k) + 1, j)
230-
{ }
231+
{
232+
invertable2();
233+
}
231234

232235
lemma map_insert_content_set(m: Map[Int, E], m1: Map[Int, E], idx: Int, k: E, len: Int, new_len: Int)
233236
requires sorted_map_seg(m, 0, len)

0 commit comments

Comments
 (0)