@@ -15,17 +15,29 @@ object Mono extends Phase[CoreTransformed, CoreTransformed] {
15
15
core match {
16
16
case ModuleDecl (path, includes, declarations, externs, definitions, exports) => {
17
17
// Find constraints in the definitions
18
- val constraints = findConstraints(definitions)(using new MonoContext )
19
- println(" Constraints" )
20
- constraints.foreach(c => println(c))
21
- println()
22
-
23
- val solved = solveConstraints(constraints)
24
- println(" Solved" )
25
- solved.foreach(println)
26
- println()
27
-
28
-
18
+ val constraints = findConstraints(definitions)(using new MonoFindContext )
19
+ // println("Constraints")
20
+ // constraints.foreach(c => println(c))
21
+ // println()
22
+
23
+ // Solve collected constraints
24
+ val solution = solveConstraints(constraints)
25
+ // println("Solved")
26
+ // solution.foreach(println)
27
+ // println()
28
+
29
+ // Monomorphize existing definitions
30
+ var monoNames : MonoNames = Map .empty
31
+ solution.foreach((funId, targs) =>
32
+ targs.foreach(vb =>
33
+ monoNames += ((funId, vb) -> freshMonoName(funId, vb))
34
+ )
35
+ )
36
+
37
+ val monoDefs = monomorphize(definitions)(using MonoContext (solution, monoNames))
38
+ // monoDefs.foreach(defn => println(util.show(defn)))
39
+ val newModuleDecl = ModuleDecl (path, includes, declarations, externs, monoDefs, exports)
40
+ return Some (CoreTransformed (source, tree, mod, newModuleDecl))
29
41
}
30
42
}
31
43
}
@@ -39,6 +51,7 @@ case class Constraint(lower: Vector[TypeArg], upper: FunctionId)
39
51
type Constraints = List [Constraint ]
40
52
41
53
type Solution = Map [FunctionId , Set [Vector [TypeArg .Base ]]]
54
+ type MonoNames = Map [(FunctionId , Vector [TypeArg .Base ]), FunctionId ]
42
55
43
56
enum TypeArg {
44
57
case Base (val tpe : Id )
@@ -48,44 +61,48 @@ enum TypeArg {
48
61
// Type Id -> Var
49
62
type TypeParams = Map [Id , TypeArg .Var ]
50
63
51
- class MonoContext {
64
+ class MonoFindContext {
52
65
var typingContext : TypeParams = Map ()
53
66
54
67
def extendTypingContext (tparam : Id , index : Int , functionId : FunctionId ) =
55
68
typingContext += (tparam -> TypeArg .Var (functionId, index))
56
69
}
57
70
58
- def findConstraints (definitions : List [Toplevel ])(using MonoContext ): Constraints =
71
+ case class MonoContext (solution : Solution , names : MonoNames ) {
72
+ var replacementTparams : Map [Id , TypeArg .Base ] = Map .empty
73
+ }
74
+
75
+ def findConstraints (definitions : List [Toplevel ])(using MonoFindContext ): Constraints =
59
76
definitions flatMap findConstraints
60
77
61
- def findConstraints (definition : Toplevel )(using ctx : MonoContext ): Constraints = definition match
78
+ def findConstraints (definition : Toplevel )(using ctx : MonoFindContext ): Constraints = definition match
62
79
case Toplevel .Def (id, BlockLit (tparams, cparams, vparams, bparams, body)) =>
63
80
tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, id))
64
81
findConstraints(body)
65
82
case Toplevel .Def (id, block) => ???
66
83
case Toplevel .Val (id, tpe, binding) => ???
67
84
68
- def findConstraints (block : Block )(using ctx : MonoContext ): Constraints = block match
85
+ def findConstraints (block : Block )(using ctx : MonoFindContext ): Constraints = block match
69
86
case BlockVar (id, annotatedTpe, annotatedCapt) => ???
70
87
case BlockLit (tparams, cparams, vparams, bparams, body) => ???
71
88
case Unbox (pure) => ???
72
89
case New (impl) => ???
73
90
74
- def findConstraints (stmt : Stmt )(using ctx : MonoContext ): Constraints = stmt match
91
+ def findConstraints (stmt : Stmt )(using ctx : MonoFindContext ): Constraints = stmt match
75
92
case Let (id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body)
76
93
case Return (expr) => findConstraints(expr)
77
94
case Val (id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body)
78
95
case App (callee : BlockVar , targs, vargs, bargs) => List (Constraint (targs.map(findId).toVector, callee.id))
79
96
case If (cond, thn, els) => findConstraints(cond) ++ findConstraints(thn) ++ findConstraints(els)
80
97
case o => println(o); ???
81
98
82
- def findConstraints (expr : Expr )(using ctx : MonoContext ): Constraints = expr match
99
+ def findConstraints (expr : Expr )(using ctx : MonoFindContext ): Constraints = expr match
83
100
case DirectApp (b, List (), vargs, bargs) => List .empty
84
101
case ValueVar (id, annotatedType) => List .empty
85
102
case Literal (value, annotatedType) => List .empty
86
103
case o => println(o); ???
87
104
88
- def findId (vt : ValueType )(using ctx : MonoContext ): TypeArg = vt match
105
+ def findId (vt : ValueType )(using ctx : MonoFindContext ): TypeArg = vt match
89
106
case ValueType .Boxed (tpe, capt) => ???
90
107
case ValueType .Data (name, targs) => TypeArg .Base (name)
91
108
case ValueType .Var (name) => ctx.typingContext(name)
@@ -113,7 +130,6 @@ def solveConstraints(constraints: Constraints): Solution =
113
130
val funSolved = solved.getOrElse(funId, solveConstraints(funId))
114
131
val posArgs = funSolved.map(v => v(pos))
115
132
l = posArgs.zipWithIndex.map((base, ind) => listFromIndex(ind) :+ base).toList
116
- println(l)
117
133
})
118
134
nbs ++= l
119
135
)
@@ -124,6 +140,99 @@ def solveConstraints(constraints: Constraints): Solution =
124
140
def productAppend [A ](ls : List [List [A ]], rs : List [A ]): List [List [A ]] =
125
141
rs.flatMap(r => ls.map(l => l :+ r))
126
142
143
+ def monomorphize (definitions : List [Toplevel ])(using ctx : MonoContext ): List [Toplevel ] =
144
+ var newDefinitions : List [Toplevel ] = List .empty
145
+ definitions.foreach(definition => newDefinitions ++= monomorphize(definition))
146
+ newDefinitions
147
+
148
+ def monomorphize (toplevel : Toplevel )(using ctx : MonoContext ): List [Toplevel ] = toplevel match
149
+ case Toplevel .Def (id, BlockLit (List (), cparams, vparams, bparams, body)) =>
150
+ List (Toplevel .Def (id, BlockLit (List .empty, cparams, vparams, bparams, monomorphize(body))))
151
+ case Toplevel .Def (id, BlockLit (tparams, cparams, vparams, bparams, body)) =>
152
+ val monoTypes = ctx.solution(id).toList
153
+ monoTypes.map(baseTypes =>
154
+ val replacementTparams = tparams.zip(baseTypes).toMap
155
+ ctx.replacementTparams ++= replacementTparams
156
+ Toplevel .Def (ctx.names(id, baseTypes), BlockLit (List .empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(body)))
157
+ )
158
+ case Toplevel .Def (id, block) => ???
159
+ case Toplevel .Val (id, tpe, binding) => ???
160
+
161
+ def monomorphize (block : Block )(using ctx : MonoContext ): Block = block match
162
+ case b : BlockVar => monomorphize(b)
163
+ case o => println(o); ???
164
+
165
+ def monomorphize (blockVar : BlockVar , replacementId : FunctionId )(using ctx : MonoContext ): BlockVar = blockVar match
166
+ case BlockVar (id, BlockType .Function (List (), cparams, vparams, bparams, result), annotatedCapt) => blockVar
167
+ // TODO: What is in annotated captures. Does it need to be handled?
168
+ case BlockVar (id, BlockType .Function (tparams, cparams, vparams, bparams, result), annotatedCapt) =>
169
+ val monoAnnotatedTpe = BlockType .Function (List .empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result))
170
+ BlockVar (replacementId, monoAnnotatedTpe, annotatedCapt)
171
+ case o => ???
172
+
173
+ def monomorphize (stmt : Stmt )(using ctx : MonoContext ): Stmt = stmt match
174
+ case Return (expr) => Return (monomorphize(expr))
175
+ case Val (id, annotatedTpe, binding, body) => Val (id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body))
176
+ case App (callee : BlockVar , targs, vargs, bargs) =>
177
+ val replacementId = replacementIdFromTargs(callee.id, targs)
178
+ App (monomorphize(callee, replacementId), List .empty, vargs map monomorphize, bargs map monomorphize)
179
+ case Let (id, annotatedTpe, binding, body) => Let (id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body))
180
+ case If (cond, thn, els) => If (monomorphize(cond), monomorphize(thn), monomorphize(els))
181
+ case o => println(o); ???
182
+
183
+ def monomorphize (expr : Expr )(using ctx : MonoContext ): Expr = expr match
184
+ case DirectApp (b, targs, vargs, bargs) =>
185
+ val replacementId = replacementIdFromTargs(b.id, targs)
186
+ DirectApp (monomorphize(b, replacementId), List .empty, vargs map monomorphize, bargs map monomorphize)
187
+ case o => println(o); ???
188
+
189
+ def monomorphize (pure : Pure )(using ctx : MonoContext ): Pure = pure match
190
+ case ValueVar (id, annotatedType) => ValueVar (id, monomorphize(annotatedType))
191
+ case PureApp (b, targs, vargs) =>
192
+ val replacementId = replacementIdFromTargs(b.id, targs)
193
+ PureApp (monomorphize(b, replacementId), List .empty, vargs map monomorphize)
194
+ case Literal (value, annotatedType) => Literal (value, monomorphize(annotatedType))
195
+ case o => println(o); ???
196
+
197
+ def monomorphize (valueParam : ValueParam )(using ctx : MonoContext ): ValueParam = valueParam match
198
+ case ValueParam (id, tpe) => ValueParam (id, monomorphize(tpe))
199
+
200
+ def monomorphize (blockParam : BlockParam )(using ctx : MonoContext ): BlockParam = blockParam match
201
+ // TODO: Same question as in block
202
+ case BlockParam (id, tpe, capt) => BlockParam (id, monomorphize(tpe), capt)
203
+
204
+ def monomorphize (blockType : BlockType )(using ctx : MonoContext ): BlockType = blockType match
205
+ case BlockType .Function (tparams, cparams, vparams, bparams, result) =>
206
+ BlockType .Function (List .empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result))
207
+ case o => println(o); ???
208
+
209
+ def monomorphize (valueType : ValueType )(using ctx : MonoContext ): ValueType = valueType match
210
+ case ValueType .Var (name) => ValueType .Var (ctx.replacementTparams(name).tpe)
211
+ case ValueType .Data (name, targs) => ValueType .Data (name, targs)
212
+ case o => println(o); ???
213
+
214
+ var monoCounter = 0
215
+ def freshMonoName (baseId : Id , tpe : TypeArg .Base ): Id =
216
+ monoCounter += 1
217
+ Id (baseId.name.name + tpe.tpe.name.name + monoCounter)
218
+
219
+ def freshMonoName (baseId : Id , tpes : Vector [TypeArg .Base ]): Id =
220
+ if (tpes.length == 0 ) return baseId
221
+
222
+ monoCounter += 1
223
+ val tpesString = tpes.map(tpe => tpe.tpe.name.name).mkString
224
+ Id (baseId.name.name + tpesString + monoCounter)
225
+
226
+ def replacementIdFromTargs (id : FunctionId , targs : List [ValueType ])(using ctx : MonoContext ): FunctionId =
227
+ if (targs.isEmpty) return id
228
+ var baseTypes : List [TypeArg .Base ] = List .empty
229
+ targs.foreach({
230
+ case ValueType .Data (name, targs) => baseTypes :+= TypeArg .Base (name)
231
+ case ValueType .Var (name) => baseTypes :+= ctx.replacementTparams(name)
232
+ case ValueType .Boxed (tpe, capt) =>
233
+ })
234
+ ctx.names((id, baseTypes.toVector))
235
+
127
236
// Old stuff
128
237
129
238
// type PolyConstraints = Map[Id, Set[PolyType]]
@@ -132,16 +241,7 @@ def productAppend[A](ls: List[List[A]], rs: List[A]): List[List[A]] =
132
241
133
242
// class MonoContext(val solvedConstraints: PolyConstraintsSolved, var monoDefs: Map[Id, Map[List[PolyType.Base], (Id, Block)]] = Map.empty)
134
243
135
- // var monoCounter = 0
136
- // def freshMonoName(baseId: Id, tpe: PolyType.Base): Id =
137
- // monoCounter += 1
138
- // Id(baseId.name.name + tpe.tpe.name.name + monoCounter)
139
244
140
- // def freshMonoName(baseId: Id, tpes: List[PolyType.Base]): Id =
141
- // monoCounter += 1
142
- // var tpesString = ""
143
- // tpes.foreach(tpe => tpesString += tpe.tpe.name.name)
144
- // Id(baseId.name.name + tpesString + monoCounter)
145
245
146
246
// // TODO: The following two are awful and surely doing redundant work.
147
247
// def generator(xs: List[Set[PolyConstraintSingle]]): List[Set[PolyConstraintSingle]] = xs.foldRight(List(Set.empty)) { (next, combinations) =>
0 commit comments