Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9707a3f
disjunctive subtyping
auht Feb 25, 2025
cc87059
Changes from meeting
LPTK Feb 26, 2025
0298463
disjoint upperbound
auht Mar 5, 2025
bf4c763
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 5, 2025
ec5c55e
no disjoint upperbound
auht Mar 5, 2025
aa956a7
multiple disjointness
auht Mar 5, 2025
8bad1a5
Changes from meeting
LPTK Mar 6, 2025
86d300d
Add test case and move tests to logicsub folder
LPTK Mar 6, 2025
3b34199
constraints solving nested disjsub
auht Mar 7, 2025
ae28b42
ues linkedhashset
auht Mar 10, 2025
2f1f5f3
traverse disjsub
auht Mar 12, 2025
3d383a2
Changes from meeting
LPTK Mar 13, 2025
00e397b
fix pretty printer type traverse and tv subst
auht Mar 13, 2025
bdd4b7a
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 13, 2025
0bbfeca
wip rcdtype implementation
auht Mar 20, 2025
3a67c1f
wip rcdtype implementation and fun args disjointness
auht Mar 22, 2025
bd53f6f
intersections wf check
auht Mar 24, 2025
92dc44a
fix nested record wf check
auht Mar 26, 2025
69bcc5a
wip refined if
auht Mar 28, 2025
081de12
Update hkmc2/shared/src/test/mlscript/logicsub/Records.mls
auht Mar 28, 2025
52c6941
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Mar 28, 2025
55b9066
wip fix
auht Apr 2, 2025
d99f3af
wip else branch disjointness
auht Apr 4, 2025
93d7ff6
fix if disjsub
auht Apr 6, 2025
3cd0f96
wip if
auht Apr 8, 2025
4996431
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 8, 2025
3928da0
test
auht Apr 8, 2025
88e3518
Pretty printer changes from meeting
LPTK Apr 11, 2025
f77ddbb
dnf disjointness
auht Apr 14, 2025
f009711
fix if
auht Apr 16, 2025
fd21398
fix if
auht Apr 16, 2025
4dc8cd7
fix missing constraints
auht Apr 18, 2025
0ed746e
test
auht Apr 18, 2025
836f3bb
elim branches
auht Apr 20, 2025
a9c4221
negtype and wip test
auht Apr 22, 2025
3bfe1c4
wip test explanation
auht Apr 22, 2025
91ac0d4
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 22, 2025
99b465b
rcd union neg disjointess
auht Apr 28, 2025
9e13a34
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht Apr 28, 2025
afb08c5
wf check and typing function intersection
auht May 4, 2025
2cfa234
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
auht May 4, 2025
69feb99
test
auht May 4, 2025
4a335ea
modify disjointness signature and impl
auht May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 71 additions & 24 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package hkmc2
package bbml


import scala.collection.mutable.{HashSet, HashMap, ListBuffer, LinkedHashSet}
import scala.collection.mutable.{HashSet, HashMap, ListBuffer, LinkedHashSet, LinkedHashMap}
import scala.annotation.tailrec

import mlscript.utils.*, shorthands.*
Expand Down Expand Up @@ -226,6 +226,15 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx): Unit
def commit(ds: DisjSub): Unit

private def constraintCollector =
val cs = ListBuffer.empty[Type -> Type]
val dss = ListBuffer.empty[DisjSub]
val nc = new ConstraintHandler:
def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx) =
cs += lhs -> rhs
def commit(ds: DisjSub) = dss += ds
(nc, dss, cs)

private def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx, c: ConstraintHandler): Unit =
c.constrain(lhs, rhs)

Expand Down Expand Up @@ -314,8 +323,32 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
pctx += sym -> PolyType.generalize(funTy, S(outer), 1)
case _ => error(msg"Function definition shape not yet supported for ${sym.nme}" -> lam.toLoc :: Nil)

private def typeSplitBr(split: Split, sign: GeneralType, isElse: Bool)(using ctx: BbCtx, c: ConstraintHandler)(using CCtx, Scope): (Type, Ls[Ls[Type -> ClassLikeType]]) = split match
case Split.Cons(Branch(Ref(sym), c: Pattern.ClassLike, cons), alts) =>
val sty = tryMkMono(ctx.get(sym).get, sym)
val cls = ClassLikeType(c.sym, Nil)
val ctx1 = ctx.nest
ctx1 += sym -> (cls & sty)
val (ce, p0) = typeSplitBr(cons, sign, false)(using ctx1)
val (ae, p1) = typeSplitBr(alts, sign, true)
val p = if p1.isEmpty then
if isElse then Nil else Ls(sty -> cls) :: p0
else
(Ls(sty -> cls) :: p0).flatMap(u => p1.map(u ++ _))
(ce | ae, p)
case Split.Let(name, term, tail) =>
val nestCtx = ctx.nest
given BbCtx = nestCtx
val (termTy, termEff) = typeCheck(term)
val sk = freshSkolem(name)
nestCtx += name -> termTy
val (tailEff, p) = typeSplitBr(tail, sign, isElse)(using nestCtx)
(termEff | tailEff, p)
case Split.Else(alts) => (ascribe(alts, sign)._2, Nil)
case Split.End => (Bot, Nil)

private def typeSplit
(split: Split, sign: Opt[GeneralType])(using ctx: BbCtx, c: ConstraintHandler)(using CCtx, Scope)
(split: Split, sign: Opt[GeneralType], path: Ls[Ls[Type -> ClassLikeType]] = Ls(Nil))(using ctx: BbCtx, c: ConstraintHandler)(using CCtx, Scope)
: (GeneralType, Type) =
split match
case Split.Cons(Branch(scrutinee, pattern, cons), alts) =>
Expand All @@ -329,34 +362,23 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
val clsTy = ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty))
val sty = tryMkMono(scrutineeTy, scrutinee)
val res = sign.orElse(S(freshVar(new TempSymbol(N, "res"))))
// val ctv = freshVar(new TempSymbol(S(scrutinee), "scrut"))
// val atv = freshVar(new TempSymbol(S(scrutinee), "scrut"))
scrutinee match // * refine
case Ref(sym: LocalSymbol) =>
nestCtx1 += sym -> (clsTy & sty)
nestCtx2 += sym -> sty
// nestCtx1 += sym -> ctv
// nestCtx2 += sym -> atv
case _ => () // TODO: refine all variables holding this value?
// constrain(sty, (clsTy & ctv) | (clsTy.! & atv))
val (consTy, consEff) = Type.disjoint(clsTy, sty) match
case N => typeSplit(cons, res)(using nestCtx1)
val (consEff, p) = Type.disjoint(clsTy, sty) match
case N => typeSplitBr(cons, res.get, false)(using nestCtx1)
case S(k) =>
if k.isEmpty then (Bot, Bot)
if k.isEmpty then (Bot, Ls(Nil))
else
val cs = ListBuffer.empty[Type -> Type]
val dss = ListBuffer.empty[DisjSub]
val nc = new ConstraintHandler:
def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx) =
cs += lhs -> rhs
def commit(ds: DisjSub) = dss += ds
val (nc, dss, cs) = constraintCollector
val eff = freshVar(new TempSymbol(N, "eff"))
val (t, e) = typeSplit(cons, res)(using nestCtx1, nc)
val (e, p) = typeSplitBr(cons, res.get, false)(using nestCtx1, nc)
k.foreach(k => c.commit(DisjSub(LinkedHashSet.from(k), dss.toList, (e, eff) :: cs.toList)))
(t, eff)
val (altsTy, altsEff) = typeSplit(alts, res)(using nestCtx2)
(eff, p)
val (altsTy, altsEff) = typeSplit(alts, res, (Ls(sty -> clsTy) :: p).flatMap(u => path.map(u ++ _)))(using nestCtx2)
val allEff = scrutineeEff | (consEff | altsEff)
// (sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff)
(res.get, allEff)
case _ =>
error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil)
Expand All @@ -380,10 +402,35 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
nestCtx += name -> termTy
val (tailTy, tailEff) = typeSplit(tail, sign)(using nestCtx)
(tailTy, termEff | tailEff)
case Split.Else(alts) => sign match
case S(sign) => ascribe(alts, sign)
case _ => typeCheck(alts)
case Split.End => (Bot, Bot)
case Split.Else(alts) =>
val p = path.map: u =>
val m = LinkedHashMap.empty[Type, Type]
u.foreach { case (t, c) => m.updateWith(t)(_.map(_ | c).orElse(S(c))) }
m.flatMap { case (t, c) => Type.disjoint(NegType(c), t) }.toList
if p.exists(_.isEmpty) then
sign match
case S(sign) => ascribe(alts, sign)
case _ => typeCheck(alts)
else
val (nc, dss, cs) = constraintCollector
val res = sign.orElse(S(freshVar(new TempSymbol(N, "res")))).get
val eff = freshVar(new TempSymbol(N, "eff"))
val (_, e) = ascribe(alts, res)(using c = nc)
p.foreach(_.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))).foreach: k =>
c.commit(DisjSub(LinkedHashSet.from(k), dss.toList, (e, eff) :: cs.toList)))
(res, e)
case Split.End =>
val ss = path.flatMap(_.map(_._1 -> Bot))
val p = path.map: u =>
val m = LinkedHashMap.empty[Type, Type]
u.foreach { case (t, c) => m.updateWith(t)(_.map(_ | c).orElse(S(c))) }
m.flatMap { case (t, c) => Type.disjoint(NegType(c), t) }.toList
if p.exists(_.isEmpty) then
ss.foreach { case (x, y) => constrain(x, y) }
else
p.foreach(_.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))).foreach: k =>
c.commit(DisjSub(LinkedHashSet.from(k), Nil, ss)))
(Bot, Bot)

// * Note: currently, the returned type is not used or useful, but it could be in the future
private def ascribe(lhs: Term, rhs: GeneralType)(using ctx: BbCtx, scope: Scope, c: ConstraintHandler): (GeneralType, Type) =
Expand Down
134 changes: 78 additions & 56 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -320,68 +320,90 @@ object Type:
val (q, p) = discriminant(r.toBasic)
((u & q).toBasic, (w & p).toBasic)
case a => (Top, a)
def in(a: BasicType, b: BasicType)(prev: Set[BasicType -> BasicType])
(using c: MutMap[BasicType -> BasicType, Opt[Set[Set[InfVar->BasicType]]]])
: Opt[Set[Set[InfVar->BasicType]]] = (a.simp.toBasic, b.simp.toBasic) match
case (a, NegType(b)) => disjointImpl(a, b.toBasic)(prev)
case (v: InfVar, b) =>
val p = prev + (v -> NegType(b))
val k = v.state.lowerBounds.map(lb => in(lb.toBasic, b)(p))
if k.exists(_.isEmpty) then N
else S(k.flatten.flatten.toSet + Set(v -> NegType(b)))
case (ComposedType(p, q, true), b) =>
val u = in(p.toBasic, b)(prev)
val w = in(q.toBasic, b)(prev)
u.flatMap(u => w.map(u ++ _))
case (ComposedType(p, q, false), b) =>
(in(p.toBasic, b)(prev), in(q.toBasic, b)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case (a: ClassLikeType, b: ClassLikeType) if a.name.uid === b.name.uid => S(Set.empty)
case (a: ClassLikeType, ComposedType(p, q, true)) =>
(in(a, p.toBasic)(prev), in(a, q.toBasic)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case _ => N
def disjointImpl(a: BasicType, b: BasicType)(prev: Set[BasicType -> BasicType])
(using c: MutMap[BasicType -> BasicType, Opt[Set[Set[InfVar->BasicType]]]])
: Opt[Set[Set[InfVar->BasicType]]] =
if !prev.contains(a -> b) then c.getOrElseUpdate(a -> b, {
(a.simp.toBasic, b.simp.toBasic) match
case (Bot, _) | (_, Bot) => S(Set.empty)
case (ClassLikeType(a, _), ClassLikeType(b, _)) if a.uid =/= b.uid => S(Set.empty)
case (RcdType(u), RcdType(w)) if u.nonEmpty && w.nonEmpty =>
val um = u.toMap
val wm = w.toMap
val k = um.keySet & wm.keySet
val ur = RcdType((um -- k).toList)
val wr = RcdType((wm -- k).toList)
val ud = disjointImpl(ur, ur)(prev)
val wd = disjointImpl(wr, wr)(prev)
if ud.exists(_.isEmpty) || wd.exists(_.isEmpty) then
S(Set.empty)
else
val d = k.flatMap(k => disjointImpl(um(k).toBasic, wm(k).toBasic)(prev)) ++ Ls(ud, wd).flatten
if d.isEmpty then N
else if d.exists(_.isEmpty) then S(Set.empty)
else S(d.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))))
case (_: RcdType, _: FunType) | (_: FunType, _: RcdType) => S(Set.empty)
case (_: ClassLikeType, _: FunType) | (_: FunType, _: ClassLikeType) => S(Set.empty)
case (ComposedType(p, q, true), _) =>
val u = disjointImpl(p.toBasic, b)(prev)
val w = disjointImpl(q.toBasic, b)(prev)
u.flatMap(u => w.map(u ++ _))
case (_, ComposedType(p, q, true)) =>
val u = disjointImpl(a, p.toBasic)(prev)
val w = disjointImpl(a, q.toBasic)(prev)
u.flatMap(u => w.map(u ++ _))
case (ComposedType(p, q, false), _) =>
(disjointImpl(p.toBasic, b)(prev), disjointImpl(q.toBasic, b)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case (_, ComposedType(p, q, false)) =>
(disjointImpl(a, p.toBasic)(prev), disjointImpl(a, q.toBasic)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case (v: InfVar, _) =>
val p = prev + (v -> b)
val k = v.state.lowerBounds.map(lb => disjointImpl(lb.toBasic, b)(p))
if k.exists(_.isEmpty) then N
else S(k.flatten.flatten.toSet + Set(v -> b))
case (_, v: InfVar) =>
val p = prev + (a -> v)
val k = v.state.lowerBounds.map(lb => disjointImpl(a, lb.toBasic)(p))
if k.exists(_.isEmpty) then N
else S(k.flatten.flatten.toSet + Set(v -> a))
case (NegType(t),_) => t.!.simp.toBasic match
case NegType(_) => N
case a => disjointImpl(a, b)(prev)
case (_, NegType(t)) => t.!.simp.toBasic match
case NegType(_) => N
case a => disjointImpl(a, b)(prev)
case _ => N
(a, b) match
case (NegType(a), b) => in(b, a.toBasic)(prev)
case (a, NegType(b)) => in(a, b.toBasic)(prev)
case _ => (a.simp.toBasic, b.simp.toBasic) match
case (Bot, _) | (_, Bot) => S(Set.empty)
case (ClassLikeType(a, _), ClassLikeType(b, _)) if a.uid =/= b.uid => S(Set.empty)
case (RcdType(u), RcdType(w)) if u.nonEmpty && w.nonEmpty =>
val um = u.toMap
val wm = w.toMap
val k = um.keySet & wm.keySet
val ur = RcdType((um -- k).toList)
val wr = RcdType((wm -- k).toList)
val ud = disjointImpl(ur, ur)(prev)
val wd = disjointImpl(wr, wr)(prev)
if ud.exists(_.isEmpty) || wd.exists(_.isEmpty) then
S(Set.empty)
else
val d = k.flatMap(k => disjointImpl(um(k).toBasic, wm(k).toBasic)(prev)) ++ Ls(ud, wd).flatten
if d.isEmpty then N
else if d.exists(_.isEmpty) then S(Set.empty)
else S(d.reduce((x, y) => y.flatMap(y => x.map(_ ++ y))))
case (_: RcdType, _: FunType) | (_: FunType, _: RcdType) => S(Set.empty)
case (_: ClassLikeType, _: FunType) | (_: FunType, _: ClassLikeType) => S(Set.empty)
case (ComposedType(p, q, true), b) =>
val u = disjointImpl(p.toBasic, b)(prev)
val w = disjointImpl(q.toBasic, b)(prev)
u.flatMap(u => w.map(u ++ _))
case (a, ComposedType(p, q, true)) =>
val u = disjointImpl(a, p.toBasic)(prev)
val w = disjointImpl(a, q.toBasic)(prev)
u.flatMap(u => w.map(u ++ _))
case (ComposedType(p, q, false), b) =>
(disjointImpl(p.toBasic, b)(prev), disjointImpl(q.toBasic, b)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case (a, ComposedType(p, q, false)) =>
(disjointImpl(a, p.toBasic)(prev), disjointImpl(a, q.toBasic)(prev)) match
case (N, w) => w
case (u, N) => u
case (S(u), S(w)) => S(u.flatMap(u => w.map(u ++ _)))
case (v: InfVar, b) =>
val p = prev + (v -> b)
val k = v.state.lowerBounds.map(lb => disjointImpl(lb.toBasic, b)(p))
if k.exists(_.isEmpty) then N
else S(k.flatten.flatten.toSet + Set(v -> b))
case (a, v: InfVar) =>
val p = prev + (a -> v)
val k = v.state.lowerBounds.map(lb => disjointImpl(a, lb.toBasic)(p))
if k.exists(_.isEmpty) then N
else S(k.flatten.flatten.toSet + Set(v -> a))
case _ => N
}) else S(Set.empty)
def disjoint(a: Type, b: Type): Opt[Set[Set[InfVar->BasicType]]] =
disjointImpl(a.simp.toBasic, b.simp.toBasic)(Set.empty)(using c = MutMap.empty)
disjointImpl(a.toBasic, b.toBasic)(Set.empty)(using c = MutMap.empty)


// * Poly types can not be used as type arguments
Expand Down
13 changes: 7 additions & 6 deletions hkmc2/shared/src/test/mlscript/bbml/bbBasics.mls
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ foofoo
//│ Type: ['x, 'res, 'eff] -> 'x ->{'eff} 'res
//│ Where:
//│ 'x#Int ∨ ∧ ⊥<:'eff ∧ Int ∧ 'x<:Int ∧ Int<:Int ∧ Str<:'res
//│ 'x#¬Int ∨ ∧ 'x<:⊥

fun foofoo(x) =
let t = x + 1 in "foo"
Expand All @@ -177,7 +178,7 @@ f().Printer#f(42)
:e
f().Printer#f("oops")
//│ ╔══[ERROR] Type error in string literal with expected type 'T
//│ ║ l.178: f().Printer#f("oops")
//│ ║ l.179: f().Printer#f("oops")
//│ ║ ^^^^^^
//│ ╟── because: cannot constrain Str <: 'T
//│ ╟── because: cannot constrain Str <: 'T
Expand All @@ -196,7 +197,7 @@ let ip = new Printer(foofoo) in ip.Printer#f(42)
:e
let ip = new Printer(foofoo) in ip.Printer#f("42")
//│ ╔══[ERROR] Type error in string literal with expected type 'T
//│ ║ l.197: let ip = new Printer(foofoo) in ip.Printer#f("42")
//│ ║ l.198: let ip = new Printer(foofoo) in ip.Printer#f("42")
//│ ║ ^^^^
//│ ╟── because: cannot constrain Str <: 'T
//│ ╟── because: cannot constrain Str <: 'T
Expand Down Expand Up @@ -227,7 +228,7 @@ let tf = new TFun(inc) in tf.TFun#f(1)
:e
let tf = new TFun(inc) in tf.TFun#f("1")
//│ ╔══[ERROR] Type error in string literal with expected type 'T
//│ ║ l.228: let tf = new TFun(inc) in tf.TFun#f("1")
//│ ║ l.229: let tf = new TFun(inc) in tf.TFun#f("1")
//│ ║ ^^^
//│ ╟── because: cannot constrain Str <: 'T
//│ ╟── because: cannot constrain Str <: 'T
Expand Down Expand Up @@ -283,10 +284,10 @@ if 1 is Foo then 0 else 1
fun test(x) =
if x is Int then x + 1 else 0
test
//│ Type: ['x, 'res, 'eff] -> 'x ->{'eff} 'res
//│ Type: ['x, 'res, 'eff, 'eff1] -> 'x ->{'eff} 'res
//│ Where:
//│ 'x#Int ∨ ∧ ⊥<:'eff ∧ Int ∧ 'x<:Int ∧ Int<:Int ∧ Int<:'res
//│ Int <: 'res
//│ 'x#¬Int ∨ ∧ ⊥<:'eff1 ∧ Int<:'res

test(1)
//│ Type: Int
Expand Down Expand Up @@ -357,7 +358,7 @@ throw new Error("oops")
:e
throw 42
//│ ╔══[ERROR] Type error in throw
//│ ║ l.358: throw 42
//│ ║ l.359: throw 42
//│ ║ ^^
//│ ╙── because: cannot constrain Int <: Error
//│ Type: ⊥
Expand Down
Loading
Loading