Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type literal =
| FloatLit of float
| StringLit of string

type typ = Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null
type typ = Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null | TVar of int

type bind = Bind of string * typ

Expand Down Expand Up @@ -91,6 +91,7 @@ let rec string_of_typ = function
| FuncType -> "func"
| Object -> "object"
| Null -> "null"
| TVar n -> "'" ^ string_of_int n

(* Render a binding as its name, a colon and a space, then its type as
   printed by string_of_typ. *)
let rec string_of_bind = function
  | Bind (name, ty) -> Printf.sprintf "%s: %s" name (string_of_typ ty)
Expand Down
5 changes: 3 additions & 2 deletions src/coral.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
open Ast
open Sast
open Utilities
open Infer

(* coral.ml: the main compiler file for the Coral Programming Language. coral.ml handles command line
parsing, generating and interpretering the compiler and interpreter, handling tab-based indentation,
Expand Down Expand Up @@ -344,7 +345,7 @@ let rec from_console map past run =
let imported_program = parse_imports program in
let after_program = strip_after [] imported_program in

let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; locals = map; globals = map; } after_program) in (* temporarily here to check validity of SAST *)
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; locals = map; globals = map; subst = Infer.empty_subst; } after_program) in (* temporarily here to check validity of SAST *)
let (sast, globals) = sast in
let sast = (strip_return [] sast, globals) in
let _ = if !debug then print_endline ("Parser: \n\n" ^ (string_of_sprogram sast)) in (* print debug messages *)
Expand Down Expand Up @@ -375,7 +376,7 @@ let rec from_file map fname run = (* todo combine with loop *)
let imported_program = parse_imports program in
let after_program = strip_after [] imported_program in

let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; globals = map; locals = map; } after_program) in (* temporarily here to check validity of SAST *)
let (sast, map') = (Semant.check [] [] { forloop = false; inclass = false; cond = false; noeval = false; stack = TypeMap.empty; func = false; globals = map; locals = map; subst = Infer.empty_subst; } after_program) in (* temporarily here to check validity of SAST *)
let (sast, globals) = sast in
let sast = (strip_return [] sast, globals) in
let () = if !debug then print_endline ("Parser: \n\n" ^ (string_of_sprogram sast)); flush stdout; in (* print debug messages *)
Expand Down
166 changes: 166 additions & 0 deletions src/infer.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
open Ast

(* Hindley-Milner type inference support for Coral.
This module provides type variables, substitutions, and unification. *)

(* Finite map keyed by type-variable id. Ids are plain integers and are
   ordered with the standard comparison, specialised here to int. *)
module TVarMap = Map.Make (struct
  type t = int

  let compare : int -> int -> int = compare
end)

(* A substitution binds type-variable ids to the types they stand for. *)
type substitution = typ TVarMap.t

(* Supply of fresh type variables. The counter stores the id most recently
   handed out; 0 means none has been issued since the last reset. *)
let tvar_counter = ref 0

(* fresh_var () advances the counter and returns a TVar carrying the new id,
   so the first variable produced after a reset is TVar 1. *)
let fresh_var () =
  let id = !tvar_counter + 1 in
  tvar_counter := id;
  TVar id

(* reset_counter () rewinds the counter to 0. Ids issued afterwards repeat
   ids that were issued before the reset. *)
let reset_counter () = tvar_counter := 0

(* The substitution with no bindings: applying it leaves every type as-is. *)
let empty_subst : substitution = TVarMap.empty

(* occurs var t is true exactly when t is the type variable numbered var.
   Every other constructor of typ is atomic, so there is nothing to descend
   into. unify calls this as its occurs check before binding a variable. *)
let occurs (var : int) : typ -> bool = function
  | TVar v -> v = var
  | Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> false

(* apply_subst subst t resolves t through subst.

   A type variable is looked up and the lookup is repeated on the result, so
   chains such as 1 -> TVar 2, 2 -> Int resolve all the way to Int. A variable
   with no binding is returned unchanged, as is every non-variable type.

   Nothing stops a substitution from containing a cycle (extend_subst adds
   bindings without any check), and following a cycle blindly would never
   terminate. The ids already followed are therefore remembered; as soon as
   one comes up again the walk stops and that variable is returned as
   unresolved. Acyclic substitutions resolve exactly as before. *)
let apply_subst (subst : substitution) (t : typ) : typ =
  let rec resolve seen current =
    match current with
    | TVar id ->
      if List.mem id seen then current (* cycle: give up, leave unresolved *)
      else begin
        match TVarMap.find_opt id subst with
        | None -> current (* unbound variable stays as it is *)
        | Some bound -> resolve (id :: seen) bound
      end
    | Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null ->
      current
  in
  resolve [] t

(* extend_subst subst var t returns subst with var bound to t. An existing
   binding for var is replaced. No occurs or cycle check is done here; the
   caller is responsible for that. *)
let extend_subst (subst : substitution) (var : int) (t : typ) : substitution =
  subst |> TVarMap.add var t

(* compose_subst s1 s2 combines two substitutions into one.

   Every type bound in s2 is first rewritten through s1; the rewritten
   bindings are then merged with s1, and s1 wins whenever both bind the same
   variable.

   Rewriting can turn a binding into an identity binding: with
   s1 = {1 -> TVar 2} and s2 = {2 -> TVar 1}, the entry for 2 becomes
   2 -> TVar 2. Such an entry carries no information and sends a resolver
   that keeps following bindings round in circles, so identity bindings are
   dropped before merging. *)
let compose_subst (s1 : substitution) (s2 : substitution) : substitution =
  let s2_after_s1 =
    s2
    |> TVarMap.map (apply_subst s1)
    |> TVarMap.filter (fun var t -> t <> TVar var)
  in
  (* On a shared key keep the binding that came from s1. *)
  TVarMap.union (fun _ from_s1 _ -> Some from_s1) s1 s2_after_s1

(* Raised by unify when two types cannot be reconciled. The payload is a
   human-readable description of the failure. *)
exception UnificationError of string

(* unify t1 t2 subst makes t1 and t2 agree under subst and returns the
   (possibly extended) substitution.

   Both types are resolved through subst first. Then, in this order:
   - identical types need no new binding;
   - a type variable on either side is bound to the other side, after the
     occurs check. This also applies when the other side is Dyn, because the
     variable cases are tried before the Dyn case;
   - Dyn against any remaining type succeeds without a binding (gradual
     typing);
   - anything else is a mismatch.

   Raises UnificationError on a failed occurs check or on a mismatch. *)
let unify (t1 : typ) (t2 : typ) (subst : substitution) : substitution =
  let left = apply_subst subst t1 in
  let right = apply_subst subst t2 in
  (* Bind variable v to target unless v occurs in target. *)
  let bind v target =
    if occurs v target then
      raise
        (UnificationError
           (Printf.sprintf "Cannot unify: occurs check failed for TVar(%d)" v))
    else extend_subst subst v target
  in
  if left = right then subst
  else
    match left, right with
    | TVar v, other -> bind v other
    | other, TVar v -> bind v other
    | Dyn, _ | _, Dyn -> subst
    | _, _ ->
      raise
        (UnificationError
           (Printf.sprintf "Cannot unify %s with %s"
              (string_of_typ left) (string_of_typ right)))

(* finalize_type subst t resolves t through subst and, if the result is still
   a type variable, falls back to Dyn. Any other resolved type is returned
   unchanged. *)
let finalize_type (subst : substitution) (t : typ) : typ =
  match apply_subst subst t with
  | TVar _ -> Dyn
  | (Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null) as
    resolved ->
    resolved

(* is_tvar t tells whether t is a type variable. *)
let is_tvar (t : typ) : bool =
  match t with
  | TVar _ -> true
  | Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> false

(* get_tvar_id t is Some id when t is TVar id, and None for every other
   type. *)
let get_tvar_id (t : typ) : int option =
  match t with
  | TVar id -> Some id
  | Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> None

(* apply_subst_bind subst b keeps the name of binding b and replaces its type
   with the finalized type, so a variable that is still unresolved becomes
   Dyn. *)
let apply_subst_bind (subst : substitution) (b : Ast.bind) : Ast.bind =
  match b with
  | Bind (name, t) -> Bind (name, finalize_type subst t)

(* SAST transformation functions - apply substitution to resolve TVars *)
open Sast

(* is_sstage stmt tells whether stmt is an SStage node. apply_subst_sexpr
   uses it to recognise calls made through the dynamic calling convention. *)
let is_sstage stmt =
  match stmt with
  | SStage (_, _, _) -> true
  | _ -> false

(* The four functions below walk a checked tree and finalize every type
   stored in it under subst: bound type variables are replaced by what they
   resolve to and unresolved ones become Dyn. The shape of the tree is left
   untouched. *)

(* Typed expression: rewrite the node and finalize its type annotation. A call
   whose third component is an SStage always gets type Dyn, whatever its
   annotation resolves to. *)
let rec apply_subst_sexpr (subst : substitution) ((sexp, t) : sexpr) : sexpr =
  let rewritten = apply_subst_sexp subst sexp in
  let resolved_t =
    match sexp with
    | SCall (_, _, callee) when is_sstage callee -> Dyn
    | _ -> finalize_type subst t
  in
  (rewritten, resolved_t)

(* Expression node: recurse into sub-expressions and finalize embedded
   types. *)
and apply_subst_sexp (subst : substitution) (sexp : sexp) : sexp =
  let on_expr = apply_subst_sexpr subst in
  let on_exprs = List.map on_expr in
  let on_type = finalize_type subst in
  match sexp with
  | SLit _ | SVar _ | SNoexpr -> sexp (* leaves carry no types *)
  | SBinop (lhs, op, rhs) -> SBinop (on_expr lhs, op, on_expr rhs)
  | SUnop (op, operand) -> SUnop (op, on_expr operand)
  | SCall (callee, args, s) ->
    SCall (on_expr callee, on_exprs args, apply_subst_sstmt subst s)
  | SMethod (receiver, name, args) ->
    SMethod (on_expr receiver, name, on_exprs args)
  | SField (target, field) -> SField (on_expr target, field)
  | SList (items, elem_t) -> SList (on_exprs items, on_type elem_t)
  | SListAccess (e1, e2) -> SListAccess (on_expr e1, on_expr e2)
  | SListSlice (e1, e2, e3) -> SListSlice (on_expr e1, on_expr e2, on_expr e3)
  | SCast (t1, t2, e) -> SCast (on_type t1, on_type t2, on_expr e)

(* Statement: recurse into nested statements, expressions and binds. *)
and apply_subst_sstmt (subst : substitution) (stmt : sstmt) : sstmt =
  let on_expr = apply_subst_sexpr subst in
  let on_stmt = apply_subst_sstmt subst in
  let on_bind = apply_subst_bind subst in
  let on_type = finalize_type subst in
  match stmt with
  | SContinue | SBreak | SNop -> stmt (* nothing to rewrite *)
  | SFunc fdecl ->
    SFunc
      { fdecl with
        styp = on_type fdecl.styp;
        sformals = List.map on_bind fdecl.sformals;
        slocals = List.map on_bind fdecl.slocals;
        sbody = on_stmt fdecl.sbody;
      }
  | SBlock stmts -> SBlock (List.map on_stmt stmts)
  | SExpr e -> SExpr (on_expr e)
  | SIf (cond, then_s, else_s) ->
    SIf (on_expr cond, on_stmt then_s, on_stmt else_s)
  | SFor (b, e, body) -> SFor (on_bind b, on_expr e, on_stmt body)
  | SWhile (cond, body) -> SWhile (on_expr cond, on_stmt body)
  | SRange (b, e, body) -> SRange (on_bind b, on_expr e, on_stmt body)
  | SReturn e -> SReturn (on_expr e)
  | SClass (name, body) -> SClass (name, on_stmt body)
  | SAsn (lvalues, e) ->
    SAsn (List.map (apply_subst_lvalue subst) lvalues, on_expr e)
  | STransform (s, t1, t2) -> STransform (s, on_type t1, on_type t2)
  | SStage (s1, s2, s3) -> SStage (on_stmt s1, on_stmt s2, on_stmt s3)
  | SPrint e -> SPrint (on_expr e)
  | SType e -> SType (on_expr e)

(* Assignment target: finalize the bind of a plain variable, or recurse into
   the index and slice expressions. *)
and apply_subst_lvalue (subst : substitution) (lv : lvalue) : lvalue =
  let on_expr = apply_subst_sexpr subst in
  match lv with
  | SLVar b -> SLVar (apply_subst_bind subst b)
  | SLListAccess (e1, e2) -> SLListAccess (on_expr e1, on_expr e2)
  | SLListSlice (e1, e2, e3) ->
    SLListSlice (on_expr e1, on_expr e2, on_expr e3)
121 changes: 94 additions & 27 deletions src/semant.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
open Ast
open Sast
open Utilities
open Infer

(* Semant takes an Abstract Syntax Tree and returns a Syntactically Checked AST with partial type inference,
syntax checking, and other features. expr objects are converted to sexpr, and stmt objects are converted
Expand All @@ -24,10 +25,34 @@ let needs_cast t1 t2 =
This currently is quite restrictive and does not permit automatic type casting like in Python.
This may be changed in the future. The commented-out line would allow that feature *)

let binop t1 t2 op =
(* is_tvar t tells whether t is a type variable. *)
let is_tvar t =
  match t with
  | TVar _ -> true
  | Int | Float | Bool | String | Dyn | Arr | Object | FuncType | Null -> false

(* is_comparison op is true for the six relational operators (Eq, Neq, Less,
   Leq, Greater, Geq) and false for any other operator. binop gives these the
   result type Bool when an operand is Dyn or a type variable. *)
let is_comparison op =
  match op with
  | Eq | Neq | Less | Leq | Greater | Geq -> true
  | _ -> false

(* is_logical op is true for the boolean connectives And and Or, and false for
   any other operator. binop gives these the result type Bool when an operand
   is Dyn or a type variable. *)
let is_logical op =
  match op with
  | And | Or -> true
  | _ -> false

let binop t1 t2 op =
let except = (Failure ("STypeError: unsupported operand type(s) for binary " ^ binop_to_string op ^ ": '" ^ type_to_string t1 ^ "' and '" ^ type_to_string t2 ^ "'")) in
match (t1, t2) with
| (Dyn, Dyn) | (Dyn, _) | (_, Dyn) -> Dyn
(* If either operand is Dyn, result is Dyn (or Bool for comparisons) *)
| (Dyn, Dyn) | (Dyn, _) | (_, Dyn) ->
if is_comparison op || is_logical op then Bool else Dyn
(* Handle type variables - infer from the known operand *)
| (TVar a, TVar b) when a = b -> (* Same TVar *)
if is_comparison op || is_logical op then Bool else TVar a
| (TVar _, TVar _) -> Dyn (* Different TVars -> Dyn *)
| (TVar _, t) | (t, TVar _) ->
if is_comparison op || is_logical op then Bool
else t (* Infer result type from the known operand *)
| _ -> let same = t1 = t2 in (match op with
| Add | Sub | Mul | Exp when same && t1 = Int -> Int
| Add | Sub | Mul | Div | Exp when same && t1 = Float -> Float
Expand Down Expand Up @@ -165,27 +190,50 @@ and exp the_state = function
let (map'', _, _, _) = assign map' (Dyn, (SCall (e, (List.rev exprout), transforms), Dyn), data) name in (* add the function itself to the namespace *)

let (_, types) = split_sbind bindout in

if the_state.func && TypeMap.mem (x, types) the_state.stack then let () = debug "recursive callstack return" in (Dyn, SCall(e, (List.rev exprout), transforms), None)
else let stack' = TypeMap.add (x, types) true the_state.stack in (* check recursive stack *)

let (map2, block, data, locals) = (stmt {the_state with stack = stack'; func = true; locals = map''; } body) in

(match data with (* match return type with *)
| Some (typ2, e', d) -> (* it did return something *)
let Bind(n1, btype) = name in
if btype <> Dyn && btype <> typ2 then if typ2 <> Dyn
then raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found %s)" (string_of_typ btype) (string_of_typ typ2)))
else let func = { styp = btype; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
(btype, (SCall(e, (List.rev exprout), SFunc(func))), d)
else let func = { styp = typ2; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in (* case where definite return type and Dynamic inferrence still has bind*)
(typ2, (SCall(e, (List.rev exprout), SFunc(func))), d)

| None -> (* function didn't return anything, null function *)
let Bind(n1, btype) = name in if btype <> Dyn then
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found None)" (string_of_typ btype))) else
let func = { styp = Null; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
(Null, (SCall(e, (List.rev exprout), SFunc(func))), None))
let Bind(n1, btype) = name in

(* For HM inference: create a return type variable if no explicit type *)
let ret_tvar = if btype = Dyn then fresh_var () else btype in

(* Check for recursive call - if so, return the type variable for inference *)
if the_state.func && TypeMap.mem (x, types) the_state.stack then
let ret_tvar = TypeMap.find (x, types) the_state.stack in
let () = debug "recursive callstack return with type var" in
(* Use transforms (SStage) for recursive calls - dynamic calling convention *)
(ret_tvar, SCall(e, (List.rev exprout), transforms), None)
else
(* Add function with its return type variable to the stack *)
let stack' = TypeMap.add (x, types) ret_tvar the_state.stack in

let (map2, block, data, locals) = (stmt {the_state with stack = stack'; func = true; locals = map''; } body) in

(match data with (* match return type with *)
| Some (typ2, e', d) -> (* it did return something *)
(* Unify the inferred return type with our type variable *)
let (resolved_type, final_subst) =
if btype <> Dyn then
(* Explicit return type - check it matches *)
if btype <> typ2 && typ2 <> Dyn then
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found %s)" (string_of_typ btype) (string_of_typ typ2)))
else (btype, the_state.subst)
else
(* No explicit type - use unification to resolve TVar *)
try
let subst' = unify ret_tvar typ2 the_state.subst in
(finalize_type subst' ret_tvar, subst')
with UnificationError _ -> (typ2, the_state.subst)
in
(* Apply substitution to resolve all TVars in the function body *)
let resolved_block = apply_subst_sstmt final_subst block in
let func = { styp = resolved_type; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = resolved_block } in
(resolved_type, (SCall(e, (List.rev exprout), SFunc(func))), d)

| None -> (* function didn't return anything, null function *)
if btype <> Dyn then
raise (Failure (Printf.sprintf "STypeError: invalid return type (expected %s but found None)" (string_of_typ btype)))
else
let func = { styp = Null; sfname = n1; sformals = (List.rev bindout); slocals = locals; sbody = block } in
(Null, (SCall(e, (List.rev exprout), SFunc(func))), None))

| _ -> raise (Failure ("SCriticalFailure: unexpected type encountered internally in Call evaluation"))) (* can be expanded to allow classes in the future *)

Expand Down Expand Up @@ -260,7 +308,7 @@ stack is a TypeMap containing the function call stack.
TODO distinguish between outer and inner scope return statements to stop evaluating when definitely
returned. *)

and check_func out data local_vars the_state = (function
and check_func out data local_vars the_state = (function
| [] -> ((List.rev out), data, the_state.locals , List.sort_uniq compare (List.rev local_vars))
| a :: t -> let (m', value, d, loc) = stmt the_state a in
let the_state = (change_state the_state (S_setmaps (m', the_state.globals))) in
Expand All @@ -269,6 +317,17 @@ and check_func out data local_vars the_state = (function
| (None, _) -> check_func (value :: out) d (loc @ local_vars) the_state t
| (_, None) -> check_func (value :: out) data (loc @ local_vars) the_state t
| (_, _) when d = data -> check_func (value :: out) data (loc @ local_vars) the_state t
| (Some x, Some y) ->
(* Use unification to reconcile return types *)
let (t1, _, _) = x and (t2, _, _) = y in
let unified_type =
if t1 = t2 then t1
else try
let subst = unify t1 t2 empty_subst in
finalize_type subst t1
with UnificationError _ -> Dyn
in
check_func (value :: out) (Some (unified_type, (SNoexpr, Dyn), None)) (loc @ local_vars) the_state t
| _ -> check_func (value :: out) (Some (Dyn, (SNoexpr, Dyn), None)) (loc @ local_vars) the_state t))

(* match_data: when reconciling branches in a conditional branch, this function
Expand All @@ -281,10 +340,18 @@ and check_func out data local_vars the_state = (function
and match_data d1 d2 = match d1, d2 with
| (None, None) -> None
| (None, _) | (_, None) -> (Some (Dyn, (SNoexpr, Dyn), None))
| (Some x, Some y) ->
| (Some x, Some y) ->
if x = y then d1
else let (t1, _, _) = x and (t2, _, _) = y in
(Some ((if t1 = t2 then t1 else Dyn), (SNoexpr, Dyn), None))
else let (t1, _, _) = x and (t2, _, _) = y in
(* Use unification to reconcile types - e.g., TVar(1) and Int can unify to Int *)
let unified_type =
if t1 = t2 then t1
else try
let subst = unify t1 t2 empty_subst in
finalize_type subst t1
with UnificationError _ -> Dyn
in
(Some (unified_type, (SNoexpr, Dyn), None))

(* func_stmt: syntactically checks statements inside functions. Exists mostly to handle
function calls which recurse and to redirect calls to expr to expr. We may be able
Expand Down
Loading