Skip to content

Commit a0e5ab9

Browse files
committed
Monomorphize existing definitions
1 parent e18b705 commit a0e5ab9

File tree

1 file changed

+128
-28
lines changed
  • effekt/shared/src/main/scala/effekt/core

1 file changed

+128
-28
lines changed

effekt/shared/src/main/scala/effekt/core/Mono.scala

Lines changed: 128 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,29 @@ object Mono extends Phase[CoreTransformed, CoreTransformed] {
1515
core match {
1616
case ModuleDecl(path, includes, declarations, externs, definitions, exports) => {
1717
// Find constraints in the definitions
18-
val constraints = findConstraints(definitions)(using new MonoContext)
19-
println("Constraints")
20-
constraints.foreach(c => println(c))
21-
println()
22-
23-
val solved = solveConstraints(constraints)
24-
println("Solved")
25-
solved.foreach(println)
26-
println()
27-
28-
18+
val constraints = findConstraints(definitions)(using new MonoFindContext)
19+
// println("Constraints")
20+
// constraints.foreach(c => println(c))
21+
// println()
22+
23+
// Solve collected constraints
24+
val solution = solveConstraints(constraints)
25+
// println("Solved")
26+
// solution.foreach(println)
27+
// println()
28+
29+
// Monomorphize existing definitions
30+
var monoNames: MonoNames = Map.empty
31+
solution.foreach((funId, targs) =>
32+
targs.foreach(vb =>
33+
monoNames += ((funId, vb) -> freshMonoName(funId, vb))
34+
)
35+
)
36+
37+
val monoDefs = monomorphize(definitions)(using MonoContext(solution, monoNames))
38+
// monoDefs.foreach(defn => println(util.show(defn)))
39+
val newModuleDecl = ModuleDecl(path, includes, declarations, externs, monoDefs, exports)
40+
return Some(CoreTransformed(source, tree, mod, newModuleDecl))
2941
}
3042
}
3143
}
@@ -39,6 +51,7 @@ case class Constraint(lower: Vector[TypeArg], upper: FunctionId)
3951
type Constraints = List[Constraint]
4052

4153
type Solution = Map[FunctionId, Set[Vector[TypeArg.Base]]]
54+
type MonoNames = Map[(FunctionId, Vector[TypeArg.Base]), FunctionId]
4255

4356
enum TypeArg {
4457
case Base(val tpe: Id)
@@ -48,44 +61,48 @@ enum TypeArg {
4861
// Type Id -> Var
4962
type TypeParams = Map[Id, TypeArg.Var]
5063

51-
class MonoContext {
64+
class MonoFindContext {
5265
var typingContext: TypeParams = Map()
5366

5467
def extendTypingContext(tparam: Id, index: Int, functionId: FunctionId) =
5568
typingContext += (tparam -> TypeArg.Var(functionId, index))
5669
}
5770

58-
def findConstraints(definitions: List[Toplevel])(using MonoContext): Constraints =
71+
case class MonoContext(solution: Solution, names: MonoNames) {
72+
var replacementTparams: Map[Id, TypeArg.Base] = Map.empty
73+
}
74+
75+
def findConstraints(definitions: List[Toplevel])(using MonoFindContext): Constraints =
5976
definitions flatMap findConstraints
6077

61-
def findConstraints(definition: Toplevel)(using ctx: MonoContext): Constraints = definition match
78+
def findConstraints(definition: Toplevel)(using ctx: MonoFindContext): Constraints = definition match
6279
case Toplevel.Def(id, BlockLit(tparams, cparams, vparams, bparams, body)) =>
6380
tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, id))
6481
findConstraints(body)
6582
case Toplevel.Def(id, block) => ???
6683
case Toplevel.Val(id, tpe, binding) => ???
6784

68-
def findConstraints(block: Block)(using ctx: MonoContext): Constraints = block match
85+
def findConstraints(block: Block)(using ctx: MonoFindContext): Constraints = block match
6986
case BlockVar(id, annotatedTpe, annotatedCapt) => ???
7087
case BlockLit(tparams, cparams, vparams, bparams, body) => ???
7188
case Unbox(pure) => ???
7289
case New(impl) => ???
7390

74-
def findConstraints(stmt: Stmt)(using ctx: MonoContext): Constraints = stmt match
91+
def findConstraints(stmt: Stmt)(using ctx: MonoFindContext): Constraints = stmt match
7592
case Let(id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body)
7693
case Return(expr) => findConstraints(expr)
7794
case Val(id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body)
7895
case App(callee: BlockVar, targs, vargs, bargs) => List(Constraint(targs.map(findId).toVector, callee.id))
7996
case If(cond, thn, els) => findConstraints(cond) ++ findConstraints(thn) ++ findConstraints(els)
8097
case o => println(o); ???
8198

82-
def findConstraints(expr: Expr)(using ctx: MonoContext): Constraints = expr match
99+
def findConstraints(expr: Expr)(using ctx: MonoFindContext): Constraints = expr match
83100
case DirectApp(b, List(), vargs, bargs) => List.empty
84101
case ValueVar(id, annotatedType) => List.empty
85102
case Literal(value, annotatedType) => List.empty
86103
case o => println(o); ???
87104

88-
def findId(vt: ValueType)(using ctx: MonoContext): TypeArg = vt match
105+
def findId(vt: ValueType)(using ctx: MonoFindContext): TypeArg = vt match
89106
case ValueType.Boxed(tpe, capt) => ???
90107
case ValueType.Data(name, targs) => TypeArg.Base(name)
91108
case ValueType.Var(name) => ctx.typingContext(name)
@@ -113,7 +130,6 @@ def solveConstraints(constraints: Constraints): Solution =
113130
val funSolved = solved.getOrElse(funId, solveConstraints(funId))
114131
val posArgs = funSolved.map(v => v(pos))
115132
l = posArgs.zipWithIndex.map((base, ind) => listFromIndex(ind) :+ base).toList
116-
println(l)
117133
})
118134
nbs ++= l
119135
)
@@ -124,6 +140,99 @@ def solveConstraints(constraints: Constraints): Solution =
124140
def productAppend[A](ls: List[List[A]], rs: List[A]): List[List[A]] =
125141
rs.flatMap(r => ls.map(l => l :+ r))
126142

143+
def monomorphize(definitions: List[Toplevel])(using ctx: MonoContext): List[Toplevel] =
144+
var newDefinitions: List[Toplevel] = List.empty
145+
definitions.foreach(definition => newDefinitions ++= monomorphize(definition))
146+
newDefinitions
147+
148+
def monomorphize(toplevel: Toplevel)(using ctx: MonoContext): List[Toplevel] = toplevel match
149+
case Toplevel.Def(id, BlockLit(List(), cparams, vparams, bparams, body)) =>
150+
List(Toplevel.Def(id, BlockLit(List.empty, cparams, vparams, bparams, monomorphize(body))))
151+
case Toplevel.Def(id, BlockLit(tparams, cparams, vparams, bparams, body)) =>
152+
val monoTypes = ctx.solution(id).toList
153+
monoTypes.map(baseTypes =>
154+
val replacementTparams = tparams.zip(baseTypes).toMap
155+
ctx.replacementTparams ++= replacementTparams
156+
Toplevel.Def(ctx.names(id, baseTypes), BlockLit(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(body)))
157+
)
158+
case Toplevel.Def(id, block) => ???
159+
case Toplevel.Val(id, tpe, binding) => ???
160+
161+
def monomorphize(block: Block)(using ctx: MonoContext): Block = block match
162+
case b: BlockVar => monomorphize(b)
163+
case o => println(o); ???
164+
165+
def monomorphize(blockVar: BlockVar, replacementId: FunctionId)(using ctx: MonoContext): BlockVar = blockVar match
166+
case BlockVar(id, BlockType.Function(List(), cparams, vparams, bparams, result), annotatedCapt) => blockVar
167+
// TODO: What is in annotated captures. Does it need to be handled?
168+
case BlockVar(id, BlockType.Function(tparams, cparams, vparams, bparams, result), annotatedCapt) =>
169+
val monoAnnotatedTpe = BlockType.Function(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result))
170+
BlockVar(replacementId, monoAnnotatedTpe, annotatedCapt)
171+
case o => ???
172+
173+
def monomorphize(stmt: Stmt)(using ctx: MonoContext): Stmt = stmt match
174+
case Return(expr) => Return(monomorphize(expr))
175+
case Val(id, annotatedTpe, binding, body) => Val(id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body))
176+
case App(callee: BlockVar, targs, vargs, bargs) =>
177+
val replacementId = replacementIdFromTargs(callee.id, targs)
178+
App(monomorphize(callee, replacementId), List.empty, vargs map monomorphize, bargs map monomorphize)
179+
case Let(id, annotatedTpe, binding, body) => Let(id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body))
180+
case If(cond, thn, els) => If(monomorphize(cond), monomorphize(thn), monomorphize(els))
181+
case o => println(o); ???
182+
183+
def monomorphize(expr: Expr)(using ctx: MonoContext): Expr = expr match
184+
case DirectApp(b, targs, vargs, bargs) =>
185+
val replacementId = replacementIdFromTargs(b.id, targs)
186+
DirectApp(monomorphize(b, replacementId), List.empty, vargs map monomorphize, bargs map monomorphize)
187+
case o => println(o); ???
188+
189+
def monomorphize(pure: Pure)(using ctx: MonoContext): Pure = pure match
190+
case ValueVar(id, annotatedType) => ValueVar(id, monomorphize(annotatedType))
191+
case PureApp(b, targs, vargs) =>
192+
val replacementId = replacementIdFromTargs(b.id, targs)
193+
PureApp(monomorphize(b, replacementId), List.empty, vargs map monomorphize)
194+
case Literal(value, annotatedType) => Literal(value, monomorphize(annotatedType))
195+
case o => println(o); ???
196+
197+
def monomorphize(valueParam: ValueParam)(using ctx: MonoContext): ValueParam = valueParam match
198+
case ValueParam(id, tpe) => ValueParam(id, monomorphize(tpe))
199+
200+
def monomorphize(blockParam: BlockParam)(using ctx: MonoContext): BlockParam = blockParam match
201+
// TODO: Same question as in block
202+
case BlockParam(id, tpe, capt) => BlockParam(id, monomorphize(tpe), capt)
203+
204+
def monomorphize(blockType: BlockType)(using ctx: MonoContext): BlockType = blockType match
205+
case BlockType.Function(tparams, cparams, vparams, bparams, result) =>
206+
BlockType.Function(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result))
207+
case o => println(o); ???
208+
209+
def monomorphize(valueType: ValueType)(using ctx: MonoContext): ValueType = valueType match
210+
case ValueType.Var(name) => ValueType.Var(ctx.replacementTparams(name).tpe)
211+
case ValueType.Data(name, targs) => ValueType.Data(name, targs)
212+
case o => println(o); ???
213+
214+
var monoCounter = 0
215+
def freshMonoName(baseId: Id, tpe: TypeArg.Base): Id =
216+
monoCounter += 1
217+
Id(baseId.name.name + tpe.tpe.name.name + monoCounter)
218+
219+
def freshMonoName(baseId: Id, tpes: Vector[TypeArg.Base]): Id =
220+
if (tpes.length == 0) return baseId
221+
222+
monoCounter += 1
223+
val tpesString = tpes.map(tpe => tpe.tpe.name.name).mkString
224+
Id(baseId.name.name + tpesString + monoCounter)
225+
226+
def replacementIdFromTargs(id: FunctionId, targs: List[ValueType])(using ctx: MonoContext): FunctionId =
227+
if (targs.isEmpty) return id
228+
var baseTypes: List[TypeArg.Base] = List.empty
229+
targs.foreach({
230+
case ValueType.Data(name, targs) => baseTypes :+= TypeArg.Base(name)
231+
case ValueType.Var(name) => baseTypes :+= ctx.replacementTparams(name)
232+
case ValueType.Boxed(tpe, capt) =>
233+
})
234+
ctx.names((id, baseTypes.toVector))
235+
127236
// Old stuff
128237

129238
// type PolyConstraints = Map[Id, Set[PolyType]]
@@ -132,16 +241,7 @@ def productAppend[A](ls: List[List[A]], rs: List[A]): List[List[A]] =
132241

133242
// class MonoContext(val solvedConstraints: PolyConstraintsSolved, var monoDefs: Map[Id, Map[List[PolyType.Base], (Id, Block)]] = Map.empty)
134243

135-
// var monoCounter = 0
136-
// def freshMonoName(baseId: Id, tpe: PolyType.Base): Id =
137-
// monoCounter += 1
138-
// Id(baseId.name.name + tpe.tpe.name.name + monoCounter)
139244

140-
// def freshMonoName(baseId: Id, tpes: List[PolyType.Base]): Id =
141-
// monoCounter += 1
142-
// var tpesString = ""
143-
// tpes.foreach(tpe => tpesString += tpe.tpe.name.name)
144-
// Id(baseId.name.name + tpesString + monoCounter)
145245

146246
// // TODO: The following two are awful and surely doing redundant work.
147247
// def generator(xs: List[Set[PolyConstraintSingle]]): List[Set[PolyConstraintSingle]] = xs.foldRight(List(Set.empty)) { (next, combinations) =>

0 commit comments

Comments
 (0)