Skip to content

Commit 9707a3f

Browse files
committed
disjunctive subtyping
1 parent 63d1c8d commit 9707a3f

File tree

4 files changed

+94
-13
lines changed

4 files changed

+94
-13
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
9797
cctx.nest(bd -> v) givenIn:
9898
v.state.lowerBounds ::= bd
9999
v.state.upperBounds.foreach(ub => constrainImpl(bd, ub))
100+
v.state.disjsub.foreach:(d)=>
101+
Type.disjoint(d.disjoint(v),bd.toBasic.simp)match
102+
case N=>
103+
d.remove(v)
104+
if d.disjoint.isEmpty then
105+
d.dss.foreach(_.commit())
106+
d.cs.foreach((a,b)=>constrainImpl(a,b))
107+
case S(k)=>d.disjoint++=k
100108
case Conj(i, u, Nil) => (conj.i, conj.u) match
101109
case (_, Union(N, Nil)) =>
102110
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} ∧ ¬⊥" -> N :: Nil))
@@ -107,16 +115,37 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
107115
constrainArgs(ta1, ta2)
108116
else constrainConj(Conj(conj.i, Union(f, rest), Nil))
109117
case (int: Inter, Union(f, _ :: rest)) => constrainConj(Conj(int, Union(f, rest), Nil))
110-
case (Inter(S(FunType(args1, ret1, eff1))), Union(S(FunType(args2, ret2, eff2)), Nil)) =>
118+
case (Inter(S(FunType(args1, ret1, eff1)::Nil)), Union(S(FunType(args2, ret2, eff2)), Nil)) =>
111119
if args1.length =/= args2.length then
112120
// raise(ErrorReport(msg"Cannot constrain ${conj.i.toString()} <: ${conj.u.toString()}" -> N :: Nil))
113121
cctx.err
114122
else
115-
args1.zip(args2).foreach {
116-
case (a1, a2) => constrainImpl(a2, a1)
117-
}
118-
constrainImpl(ret1, ret2)
119-
constrainImpl(eff1, eff2)
123+
val k=args2.flatMap(x=>Type.disjoint(x,x))
124+
if k.isEmpty then
125+
args1.zip(args2).foreach {
126+
case (a1, a2) => constrainImpl(a2, a1)
127+
}
128+
constrainImpl(ret1, ret2)
129+
constrainImpl(eff1, eff2)
130+
else if !k.contains(Nil)then
131+
DisjSub(mutable.Map.from(k.flatten),Nil,(ret1,ret2)::(eff1,eff2)::args2.zip(args1)).commit()
132+
case (Inter(S(fs:Ls[FunType])), Union(S(FunType(args2, ret2, eff2)), Nil)) =>
133+
val f=fs.filter(_.args.length===args2.length)
134+
val args=f.map(_.args).transpose
135+
val k=args2.flatMap(x=>Type.disjoint(x,x))
136+
if!k.contains(Nil)then
137+
// assume distinguished by the first arg
138+
constrainImpl(args2.head,args.head.foldLeft(Bot:Type)(_|_))
139+
args.head.iterator.zip(f).foreach:(a,b)=>
140+
val s=args2.zip(b.args).tail
141+
Type.disjoint(args2.head.toBasic.simp,a.toBasic.simp)match
142+
case N =>
143+
s.foreach((x,y)=>constrainImpl(x,y))
144+
constrainImpl(b.ret,ret2)
145+
constrainImpl(b.eff,eff2)
146+
case S(k) =>
147+
val ds=DisjSub(mutable.Map.from(k),Nil,(b.ret,ret2)::(b.eff,eff2)::s)
148+
ds.commit()
120149
case _ =>
121150
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} <: ${conj.u.toString()}" -> N :: Nil))
122151
cctx.err

hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ object Conj:
7676
}){}
7777
lazy val empty: Conj = Conj(Inter.empty, Union.empty, Nil)
7878
def mkVar(v: InfVar, pol: Bool) = Conj(Inter.empty, Union.empty, (v, pol) :: Nil)
79-
def mkInter(inter: ClassLikeType | FunType) =
79+
def mkInter(inter: ClassLikeType | Ls[FunType]) =
8080
Conj(Inter(S(inter)), Union.empty, Nil)
8181
def mkUnion(union: ClassLikeType | FunType) =
8282
Conj(Inter.empty, union match {
@@ -85,18 +85,23 @@ object Conj:
8585
}, Nil)
8686

8787
// * Some(ClassType) -> C[in D_i out D_i], Some(FunType) -> D_1 ->{D_2} D_3, None -> Top
88-
final case class Inter(v: Opt[ClassLikeType | FunType]) extends NormalForm:
88+
final case class Inter(v: Opt[ClassLikeType | Ls[FunType]]) extends NormalForm:
8989
def isTop: Bool = v.isEmpty
9090
def merge(other: Inter): Option[Inter] = (v, other.v) match
9191
case (S(ClassLikeType(cls1, targs1)), S(ClassLikeType(cls2, targs2))) if cls1.uid === cls2.uid =>
9292
S(Inter(S(ClassLikeType(cls1, targs1.lazyZip(targs2).map(_ & _)))))
9393
case (S(_: ClassLikeType), S(_: ClassLikeType)) => N
94-
case (S(FunType(a1, r1, e1)), S(FunType(a2, r2, e2))) =>
95-
S(Inter(S(FunType(a1.lazyZip(a2).map(_ | _), r1 & r2, e1 & e2))))
94+
// case (S(FunType(a1, r1, e1)), S(FunType(a2, r2, e2))) =>
95+
// S(Inter(S(FunType(a1.lazyZip(a2).map(_ | _), r1 & r2, e1 & e2))))
96+
case (S(a:Ls[FunType]),S(b:Ls[FunType]))=>S(Inter(S(a++b)))
9697
case (S(v), N) => S(Inter(S(v)))
9798
case (N, v) => S(Inter(v))
9899
case _ => N
99-
def toBasic: BasicType = v.getOrElse(Top)
100+
def toBasic: BasicType = v match
101+
case N=>Top
102+
case S(x:ClassLikeType)=>x
103+
case S(Nil)=>Top
104+
case S(x:Ls[FunType])=>x.reduce[Type](_&_).toBasic
100105
def toDnf(using TL): Disj = Disj(Conj(this, Union(N, Nil), Nil) :: Nil)
101106
override def show(using Scope): Str =
102107
toBasic.show
@@ -182,7 +187,7 @@ object NormalForm:
182187
case Bot => Disj.bot
183188
case v: InfVar => Disj(Conj.mkVar(v, true) :: Nil)
184189
case ct: ClassLikeType => Disj(Conj.mkInter(ct.toNorm) :: Nil)
185-
case ft: FunType => Disj(Conj.mkInter(ft.toNorm) :: Nil)
190+
case ft: FunType => Disj(Conj.mkInter(Ls(ft.toNorm)) :: Nil)
186191
case ComposedType(lhs, rhs, pol) =>
187192
if pol then union(dnf(lhs), dnf(rhs)) else inter(dnf(lhs), dnf(rhs))
188193
case NegType(ty) => neg(ty)

hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import mlscript.utils.*, shorthands.*
55
import syntax.*
66
import semantics.*, semantics.Term.*
77
import utils.*
8-
import scala.collection.mutable.{Set => MutSet}
8+
import scala.collection.mutable.{Set => MutSet, Map => MutMap}
99
import utils.Scope
1010
import Elaborator.State
1111

@@ -281,6 +281,13 @@ object Type:
281281
then lhs | rhs
282282
else lhs & rhs
283283
def mkNegType(ty: Type): Type = ty.!
284+
def disjoint(a:Type,b:Type):Opt[Ls[InfVar->BasicType]]=(a,b)match
285+
case (Bot,_)|(_,Bot)=>S(Nil)
286+
case (ClassLikeType(a,_),ClassLikeType(b,_))if a.uid=/=b.uid=>S(Nil)
287+
case (a:ClassLikeType,v:InfVar)=>S(Ls(v->a))
288+
case (v:InfVar,a:ClassLikeType)=>S(Ls(v->a))
289+
case _=>N
290+
284291

285292
// * Poly types can not be used as type arguments
286293
case class PolyType(tvs: Ls[InfVar], outer: Opt[InfVar], body: GeneralType) extends GeneralType:
@@ -388,3 +395,10 @@ case class PolyFunType(args: Ls[GeneralType], ret: GeneralType, eff: Type) exten
388395
class VarState:
389396
var lowerBounds: Ls[Type] = Nil
390397
var upperBounds: Ls[Type] = Nil
398+
val disjsub: MutSet[DisjSub] = MutSet.empty
399+
400+
case class DisjSub(disjoint:MutMap[InfVar,BasicType],dss:Ls[DisjSub],cs:Ls[Type->Type]):
401+
def commit()=disjoint.keys.foreach(_.state.disjsub+=this)
402+
def remove(v:InfVar)=
403+
v.state.disjsub-=this
404+
disjoint-=v
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
:bbml
2+
//│ Type: ⊤
3+
4+
//│ Type: ⊤
5+
6+
fun andt(x)=x&&true
7+
fun k(f:Nothing->Bool)=1
8+
fun ap(f)=x=>f(x)
9+
//│ Type: ⊤
10+
11+
k(andt)
12+
//│ Type: Int
13+
14+
k(ap(andt))
15+
//│ Type: Int
16+
17+
fun id:(Int->Int)&(Bool->Bool)
18+
//│ Type: ⊤
19+
20+
id(1)
21+
//│ Type: Int
22+
23+
fun ap1(f)=f(1)
24+
ap1(id)
25+
//│ Type: Int
26+
27+
ap(id)(1)
28+
//│ Type: Int
29+
30+
:todo
31+
x=>id(x)
32+
//│ Type: (Int ∨ Bool) ->{⊥} ⊥
33+

0 commit comments

Comments
 (0)