Skip to content

Commit 0b8091e

Browse files
Adds Hindley-Milner type resolution
1 parent 5c83d14 commit 0b8091e

19 files changed

+361
-34
lines changed

src/ast.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type literal =
1919
| FloatLit of float
2020
| StringLit of string
2121

22-
type typ = Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null
22+
type typ = Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null | TVar of int
2323

2424
type bind = Bind of string * typ
2525

@@ -91,6 +91,7 @@ let rec string_of_typ = function
9191
| FuncType -> "func"
9292
| Object -> "object"
9393
| Null -> "null"
94+
| TVar n -> "'" ^ string_of_int n
9495

9596
let rec string_of_bind = function
9697
| Bind(s, t) -> s ^ ": " ^ string_of_typ t

src/coral.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
open Ast
22
open Sast
33
open Utilities
4+
open Infer
45

56
(* coral.ml: the main compiler file for the Coral Programming Language. coral.ml handles command line
67
parsing, generating and interpreting the compiler and interpreter, handling tab-based indentation,
@@ -344,7 +345,7 @@ let rec from_console map past run =
344345
let imported_program = parse_imports program in
345346
let after_program = strip_after [] imported_program in
346347

347-
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; locals = map; globals = map; } after_program) in (* temporarily here to check validity of SAST *)
348+
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; locals = map; globals = map; subst = Infer.empty_subst; } after_program) in (* temporarily here to check validity of SAST *)
348349
let (sast, globals) = sast in
349350
let sast = (strip_return [] sast, globals) in
350351
let _ = if !debug then print_endline ("Parser: \n\n" ^ (string_of_sprogram sast)) in (* print debug messages *)
@@ -375,7 +376,7 @@ let rec from_file map fname run = (* todo combine with loop *)
375376
let imported_program = parse_imports program in
376377
let after_program = strip_after [] imported_program in
377378

378-
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; globals = map; locals = map; } after_program) in (* temporarily here to check validity of SAST *)
379+
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; globals = map; locals = map; subst = Infer.empty_subst; } after_program) in (* temporarily here to check validity of SAST *)
379380
let (sast, globals) = sast in
380381
let sast = (strip_return [] sast, globals) in
381382
let () = if !debug then print_endline ("Parser: \n\n" ^ (string_of_sprogram sast)); flush stdout; in (* print debug messages *)

src/infer.ml

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
open Ast
2+
3+
(* Hindley-Milner type inference support for Coral.
4+
This module provides type variables, substitutions, and unification. *)
5+
6+
(* Type variable map: maps type variable IDs to their resolved types *)
7+
module TVarMap = Map.Make(struct type t = int let compare = compare end)
8+
9+
type substitution = typ TVarMap.t
10+
11+
(* Fresh type variable generation *)
12+
let tvar_counter = ref 0
13+
14+
let fresh_var () =
15+
incr tvar_counter;
16+
TVar(!tvar_counter)
17+
18+
let reset_counter () =
19+
tvar_counter := 0
20+
21+
let empty_subst : substitution = TVarMap.empty
22+
23+
(* Check if a type variable occurs in a type (prevents infinite types) *)
24+
let rec occurs (var : int) (t : typ) : bool =
25+
match t with
26+
| TVar v -> v = var
27+
| Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> false
28+
29+
(* Apply a substitution to a type, recursively resolving type variables *)
30+
let rec apply_subst (subst : substitution) (t : typ) : typ =
31+
match t with
32+
| TVar id ->
33+
(match TVarMap.find_opt id subst with
34+
| None -> TVar id (* Unresolved, keep as-is *)
35+
| Some t' -> apply_subst subst t') (* Recursively apply in case t' has vars *)
36+
| Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> t
37+
38+
(* Extend a substitution by mapping var to t *)
39+
let extend_subst (subst : substitution) (var : int) (t : typ) : substitution =
40+
TVarMap.add var t subst
41+
42+
(* Compose two substitutions: s1 is applied to the types bound in s2, and s1's bindings win on shared keys *)
43+
let compose_subst (s1 : substitution) (s2 : substitution) : substitution =
44+
(* Apply s1 to all types in s2 *)
45+
let s2_after_s1 = TVarMap.map (apply_subst s1) s2 in
46+
(* Merge: s1 takes precedence for common keys *)
47+
TVarMap.merge (fun _ v1 v2 ->
48+
match v1, v2 with
49+
| Some x, _ -> Some x
50+
| None, Some x -> Some x
51+
| None, None -> None
52+
) s1 s2_after_s1
53+
54+
exception UnificationError of string
55+
56+
(* Unify two types, returning an updated substitution *)
57+
let rec unify (t1 : typ) (t2 : typ) (subst : substitution) : substitution =
58+
(* Apply current substitution first *)
59+
let t1' = apply_subst subst t1 in
60+
let t2' = apply_subst subst t2 in
61+
62+
match t1', t2' with
63+
(* Same types: no change needed *)
64+
| t1, t2 when t1 = t2 -> subst
65+
66+
(* Type variable on one side *)
67+
| TVar v, t | t, TVar v ->
68+
(* Occurs check: prevent infinite types like 'a = list['a] *)
69+
if occurs v t then
70+
raise (UnificationError
71+
(Printf.sprintf "Cannot unify: occurs check failed for TVar(%d)" v))
72+
else
73+
extend_subst subst v t
74+
75+
(* Dyn (dynamic type) unifies with anything - gradual typing *)
76+
| Dyn, _ | _, Dyn -> subst
77+
78+
(* All other cases: types must match exactly *)
79+
| t1, t2 ->
80+
raise (UnificationError
81+
(Printf.sprintf "Cannot unify %s with %s"
82+
(string_of_typ t1) (string_of_typ t2)))
83+
84+
(* Replace all unresolved type variables with Dyn *)
85+
let finalize_type (subst : substitution) (t : typ) : typ =
86+
let t' = apply_subst subst t in
87+
match t' with
88+
| TVar _ -> Dyn (* Unresolved type variable becomes Dyn *)
89+
| _ -> t'
90+
91+
(* Helper to check if a type is a type variable *)
92+
let is_tvar = function
93+
| TVar _ -> true
94+
| _ -> false
95+
96+
(* Get the ID from a type variable, or None *)
97+
let get_tvar_id = function
98+
| TVar id -> Some id
99+
| _ -> None
100+
101+
(* Apply substitution to a bind *)
102+
let apply_subst_bind (subst : substitution) (Bind(name, t) : Ast.bind) : Ast.bind =
103+
Bind(name, finalize_type subst t)
104+
105+
(* SAST transformation functions - apply substitution to resolve TVars *)
106+
open Sast
107+
108+
(* Check if an sstmt is an SStage (dynamic call) *)
109+
let is_sstage = function
110+
| SStage(_, _, _) -> true
111+
| _ -> false
112+
113+
let rec apply_subst_sexpr (subst : substitution) ((sexp, t) : sexpr) : sexpr =
114+
(* For SCall with SStage (dynamic calling), always use Dyn as the result type *)
115+
let resolved_t = match sexp with
116+
| SCall(_, _, s) when is_sstage s -> Dyn (* Dynamic calls return boxed Dyn *)
117+
| _ -> finalize_type subst t
118+
in
119+
(apply_subst_sexp subst sexp, resolved_t)
120+
121+
and apply_subst_sexp (subst : substitution) (sexp : sexp) : sexp =
122+
match sexp with
123+
| SBinop(e1, op, e2) -> SBinop(apply_subst_sexpr subst e1, op, apply_subst_sexpr subst e2)
124+
| SLit(l) -> SLit(l)
125+
| SVar(s) -> SVar(s)
126+
| SUnop(op, e) -> SUnop(op, apply_subst_sexpr subst e)
127+
| SCall(e, args, s) -> SCall(apply_subst_sexpr subst e, List.map (apply_subst_sexpr subst) args, apply_subst_sstmt subst s)
128+
| SMethod(e, name, args) -> SMethod(apply_subst_sexpr subst e, name, List.map (apply_subst_sexpr subst) args)
129+
| SField(e, s) -> SField(apply_subst_sexpr subst e, s)
130+
| SList(exprs, t) -> SList(List.map (apply_subst_sexpr subst) exprs, finalize_type subst t)
131+
| SNoexpr -> SNoexpr
132+
| SListAccess(e1, e2) -> SListAccess(apply_subst_sexpr subst e1, apply_subst_sexpr subst e2)
133+
| SListSlice(e1, e2, e3) -> SListSlice(apply_subst_sexpr subst e1, apply_subst_sexpr subst e2, apply_subst_sexpr subst e3)
134+
| SCast(t1, t2, e) -> SCast(finalize_type subst t1, finalize_type subst t2, apply_subst_sexpr subst e)
135+
136+
and apply_subst_sstmt (subst : substitution) (stmt : sstmt) : sstmt =
137+
match stmt with
138+
| SFunc(fdecl) -> SFunc({
139+
styp = finalize_type subst fdecl.styp;
140+
sfname = fdecl.sfname;
141+
sformals = List.map (apply_subst_bind subst) fdecl.sformals;
142+
slocals = List.map (apply_subst_bind subst) fdecl.slocals;
143+
sbody = apply_subst_sstmt subst fdecl.sbody
144+
})
145+
| SBlock(stmts) -> SBlock(List.map (apply_subst_sstmt subst) stmts)
146+
| SExpr(e) -> SExpr(apply_subst_sexpr subst e)
147+
| SIf(e, s1, s2) -> SIf(apply_subst_sexpr subst e, apply_subst_sstmt subst s1, apply_subst_sstmt subst s2)
148+
| SFor(b, e, s) -> SFor(apply_subst_bind subst b, apply_subst_sexpr subst e, apply_subst_sstmt subst s)
149+
| SWhile(e, s) -> SWhile(apply_subst_sexpr subst e, apply_subst_sstmt subst s)
150+
| SRange(b, e, s) -> SRange(apply_subst_bind subst b, apply_subst_sexpr subst e, apply_subst_sstmt subst s)
151+
| SReturn(e) -> SReturn(apply_subst_sexpr subst e)
152+
| SClass(name, s) -> SClass(name, apply_subst_sstmt subst s)
153+
| SAsn(lvalues, e) -> SAsn(List.map (apply_subst_lvalue subst) lvalues, apply_subst_sexpr subst e)
154+
| STransform(s, t1, t2) -> STransform(s, finalize_type subst t1, finalize_type subst t2)
155+
| SStage(s1, s2, s3) -> SStage(apply_subst_sstmt subst s1, apply_subst_sstmt subst s2, apply_subst_sstmt subst s3)
156+
| SPrint(e) -> SPrint(apply_subst_sexpr subst e)
157+
| SType(e) -> SType(apply_subst_sexpr subst e)
158+
| SContinue -> SContinue
159+
| SBreak -> SBreak
160+
| SNop -> SNop
161+
162+
and apply_subst_lvalue (subst : substitution) (lv : lvalue) : lvalue =
163+
match lv with
164+
| SLVar(b) -> SLVar(apply_subst_bind subst b)
165+
| SLListAccess(e1, e2) -> SLListAccess(apply_subst_sexpr subst e1, apply_subst_sexpr subst e2)
166+
| SLListSlice(e1, e2, e3) -> SLListSlice(apply_subst_sexpr subst e1, apply_subst_sexpr subst e2, apply_subst_sexpr subst e3)

src/semant.ml

Lines changed: 94 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
open Ast
22
open Sast
33
open Utilities
4+
open Infer
45

56
(* Semant takes an Abstract Syntax Tree and returns a Syntactically Checked AST with partial type inference,
67
syntax checking, and other features. expr objects are converted to sexpr, and stmt objects are converted
@@ -24,10 +25,34 @@ let needs_cast t1 t2 =
2425
This currently is quite restrictive and does not permit automatic type casting like in Python.
2526
This may be changed in the future. The commented-out line would allow that feature *)
2627

27-
let binop t1 t2 op =
28+
(* Helper to check if a type is a type variable *)
29+
let is_tvar = function
30+
| TVar _ -> true
31+
| _ -> false
32+
33+
(* Check if an operation is a comparison that returns Bool *)
34+
let is_comparison = function
35+
| Eq | Neq | Less | Leq | Greater | Geq -> true
36+
| _ -> false
37+
38+
(* Check if an operation is a logical operation that returns Bool *)
39+
let is_logical = function
40+
| And | Or -> true
41+
| _ -> false
42+
43+
let binop t1 t2 op =
2844
let except = (Failure ("STypeError: unsupported operand type(s) for binary " ^ binop_to_string op ^ ": '" ^ type_to_string t1 ^ "' and '" ^ type_to_string t2 ^ "'")) in
2945
match (t1, t2) with
30-
| (Dyn, Dyn) | (Dyn, _) | (_, Dyn) -> Dyn
46+
(* If either operand is Dyn, result is Dyn (or Bool for comparisons and logical operations) *)
47+
| (Dyn, Dyn) | (Dyn, _) | (_, Dyn) ->
48+
if is_comparison op || is_logical op then Bool else Dyn
49+
(* Handle type variables - infer from the known operand *)
50+
| (TVar a, TVar b) when a = b -> (* Same TVar *)
51+
if is_comparison op || is_logical op then Bool else TVar a
52+
| (TVar _, TVar _) -> Dyn (* Different TVars -> Dyn *)
53+
| (TVar _, t) | (t, TVar _) ->
54+
if is_comparison op || is_logical op then Bool
55+
else t (* Infer result type from the known operand *)
3156
| _ -> let same = t1 = t2 in (match op with
3257
| Add | Sub | Mul | Exp when same && t1 = Int -> Int
3358
| Add | Sub | Mul | Div | Exp when same && t1 = Float -> Float
@@ -165,27 +190,50 @@ and exp the_state = function
165190
let (map'', _, _, _) = assign map' (Dyn, (SCall (e, (List.rev exprout), transforms), Dyn), data) name in (* add the function itself to the namespace *)
166191

167192
let (_, types) = split_sbind bindout in
168-
169-
if the_state.func && TypeMap.mem (x, types) the_state.stack then let () = debug "recursive callstack return" in (Dyn, SCall(e, (List.rev exprout), transforms), None)
170-
else let stack' = TypeMap.add (x, types) true the_state.stack in (* check recursive stack *)
171-
172-
let (map2, block, data, locals) = (stmt {the_state with stack = stack'; func = true; locals = map''; } body) in
173-
174-
(match data with (* match return type with *)
175-
| Some (typ2, e', d) -> (* it did return something *)
176-
let Bind(n1, btype) = name in
177-
if btype <> Dyn && btype <> typ2 then if typ2 <> Dyn
178-
then raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found %s)" (string_of_typ btype) (string_of_typ typ2)))
179-
else let func = { styp = btype; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
180-
(btype, (SCall(e, (List.rev exprout), SFunc(func))), d)
181-
else let func = { styp = typ2; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in (* case where definite return type and Dynamic inferrence still has bind*)
182-
(typ2, (SCall(e, (List.rev exprout), SFunc(func))), d)
183-
184-
| None -> (* function didn't return anything, null function *)
185-
let Bind(n1, btype) = name in if btype <> Dyn then
186-
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found None)" (string_of_typ btype))) else
187-
let func = { styp = Null; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
188-
(Null, (SCall(e, (List.rev exprout), SFunc(func))), None))
193+
let Bind(n1, btype) = name in
194+
195+
(* For HM inference: create a return type variable if no explicit type *)
196+
let ret_tvar = if btype = Dyn then fresh_var () else btype in
197+
198+
(* Check for recursive call - if so, return the type variable for inference *)
199+
if the_state.func && TypeMap.mem (x, types) the_state.stack then
200+
let ret_tvar = TypeMap.find (x, types) the_state.stack in
201+
let () = debug "recursive callstack return with type var" in
202+
(* Use transforms (SStage) for recursive calls - dynamic calling convention *)
203+
(ret_tvar, SCall(e, (List.rev exprout), transforms), None)
204+
else
205+
(* Add function with its return type variable to the stack *)
206+
let stack' = TypeMap.add (x, types) ret_tvar the_state.stack in
207+
208+
let (map2, block, data, locals) = (stmt {the_state with stack = stack'; func = true; locals = map''; } body) in
209+
210+
(match data with (* match return type with *)
211+
| Some (typ2, e', d) -> (* it did return something *)
212+
(* Unify the inferred return type with our type variable *)
213+
let (resolved_type, final_subst) =
214+
if btype <> Dyn then
215+
(* Explicit return type - check it matches *)
216+
if btype <> typ2 && typ2 <> Dyn then
217+
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found %s)" (string_of_typ btype) (string_of_typ typ2)))
218+
else (btype, the_state.subst)
219+
else
220+
(* No explicit type - use unification to resolve TVar *)
221+
try
222+
let subst' = unify ret_tvar typ2 the_state.subst in
223+
(finalize_type subst' ret_tvar, subst')
224+
with UnificationError _ -> (typ2, the_state.subst)
225+
in
226+
(* Apply substitution to resolve all TVars in the function body *)
227+
let resolved_block = apply_subst_sstmt final_subst block in
228+
let func = { styp = resolved_type; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = resolved_block } in
229+
(resolved_type, (SCall(e, (List.rev exprout), SFunc(func))), d)
230+
231+
| None -> (* function didn't return anything, null function *)
232+
if btype <> Dyn then
233+
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found None)" (string_of_typ btype)))
234+
else
235+
let func = { styp = Null; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
236+
(Null, (SCall(e, (List.rev exprout), SFunc(func))), None))
189237

190238
| _ -> raise (Failure ("SCriticalFailure: unexpected type encountered internally in Call evaluation"))) (* can be expanded to allow classes in the future *)
191239

@@ -260,7 +308,7 @@ stack is a TypeMap containing the function call stack.
260308
TODO distinguish between outer and inner scope return statements to stop evaluating when definitely
261309
returned. *)
262310

263-
and check_func out data local_vars the_state = (function
311+
and check_func out data local_vars the_state = (function
264312
| [] -> ((List.rev out), data, the_state.locals , List.sort_uniq compare (List.rev local_vars))
265313
| a :: t -> let (m', value, d, loc) = stmt the_state a in
266314
let the_state = (change_state the_state (S_setmaps (m', the_state.globals))) in
@@ -269,6 +317,17 @@ and check_func out data local_vars the_state = (function
269317
| (None, _) -> check_func (value :: out) d (loc @ local_vars) the_state t
270318
| (_, None) -> check_func (value :: out) data (loc @ local_vars) the_state t
271319
| (_, _) when d = data -> check_func (value :: out) data (loc @ local_vars) the_state t
320+
| (Some x, Some y) ->
321+
(* Use unification to reconcile return types *)
322+
let (t1, _, _) = x and (t2, _, _) = y in
323+
let unified_type =
324+
if t1 = t2 then t1
325+
else try
326+
let subst = unify t1 t2 empty_subst in
327+
finalize_type subst t1
328+
with UnificationError _ -> Dyn
329+
in
330+
check_func (value :: out) (Some (unified_type, (SNoexpr, Dyn), None)) (loc @ local_vars) the_state t
272331
| _ -> check_func (value :: out) (Some (Dyn, (SNoexpr, Dyn), None)) (loc @ local_vars) the_state t))
273332

274333
(* match_data: when reconciling branches in a conditional branch, this function
@@ -281,10 +340,18 @@ and check_func out data local_vars the_state = (function
281340
and match_data d1 d2 = match d1, d2 with
282341
| (None, None) -> None
283342
| (None, _) | (_, None) -> (Some (Dyn, (SNoexpr, Dyn), None))
284-
| (Some x, Some y) ->
343+
| (Some x, Some y) ->
285344
if x = y then d1
286-
else let (t1, _, _) = x and (t2, _, _) = y in
287-
(Some ((if t1 = t2 then t1 else Dyn), (SNoexpr, Dyn), None))
345+
else let (t1, _, _) = x and (t2, _, _) = y in
346+
(* Use unification to reconcile types - e.g., TVar(1) and Int can unify to Int *)
347+
let unified_type =
348+
if t1 = t2 then t1
349+
else try
350+
let subst = unify t1 t2 empty_subst in
351+
finalize_type subst t1
352+
with UnificationError _ -> Dyn
353+
in
354+
(Some (unified_type, (SNoexpr, Dyn), None))
288355

289356
(* func_stmt: syntactically checks statements inside functions. Exists mostly to handle
290357
function calls which recurse and to redirect calls to expr to expr. We may be able

0 commit comments

Comments
 (0)