@@ -20,9 +20,9 @@ object Mono extends Phase[CoreTransformed, CoreTransformed] {
20
20
constraints.foreach(c => println(c))
21
21
println()
22
22
23
- // val solved = solveConstraint (constraints)
24
- // println("Solved")
25
- // solved.foreach(println)
23
+ val solved = solveConstraints (constraints)
24
+ println(" Solved" )
25
+ solved.foreach(println)
26
26
println()
27
27
28
28
@@ -38,8 +38,7 @@ type FunctionId = Id
38
38
case class Constraint (lower : Vector [TypeArg ], upper : FunctionId )
39
39
type Constraints = List [Constraint ]
40
40
41
- // case class SolvedConstraint(lower: Vector[TypeArg.Base], upper: FunctionId | TypeArg.Var)
42
- // type SolvedConstraints = List[SolvedConstraint]
41
+ type Solution = Map [FunctionId , Set [Vector [TypeArg .Base ]]]
43
42
44
43
enum TypeArg {
45
44
case Base (val tpe : Id )
@@ -91,6 +90,39 @@ def findId(vt: ValueType)(using ctx: MonoContext): TypeArg = vt match
91
90
case ValueType .Data (name, targs) => TypeArg .Base (name)
92
91
case ValueType .Var (name) => ctx.typingContext(name)
93
92
93
+ def solveConstraints (constraints : Constraints ): Solution =
94
+ var solved : Solution = Map ()
95
+
96
+ val groupedConstraints = constraints.groupBy(c => c.upper)
97
+ val vecConstraints = groupedConstraints.map((sym, constraints) => (sym -> constraints.map(c => c.lower)))
98
+
99
+ vecConstraints.foreach((sym, tas) =>
100
+ val sol = solveConstraints(sym).map(bs => bs.toVector)
101
+ solved += (sym -> sol)
102
+ )
103
+
104
+ def solveConstraints (funId : FunctionId ): Set [List [TypeArg .Base ]] =
105
+ val filteredConstraints = vecConstraints(funId)
106
+ var nbs : Set [List [TypeArg .Base ]] = Set .empty
107
+ filteredConstraints.foreach(b =>
108
+ var l : List [List [TypeArg .Base ]] = List (List .empty)
109
+ def listFromIndex (ind : Int ) = if (ind >= l.length) List .empty else l(ind)
110
+ b.foreach({
111
+ case TypeArg .Base (tpe) => l = productAppend(l, List (TypeArg .Base (tpe)))
112
+ case TypeArg .Var (funId, pos) =>
113
+ val funSolved = solved.getOrElse(funId, solveConstraints(funId))
114
+ val posArgs = funSolved.map(v => v(pos))
115
+ l = posArgs.zipWithIndex.map((base, ind) => listFromIndex(ind) :+ base).toList
116
+ println(l)
117
+ })
118
+ nbs ++= l
119
+ )
120
+ nbs
121
+
122
+ solved
123
+
124
+ def productAppend [A ](ls : List [List [A ]], rs : List [A ]): List [List [A ]] =
125
+ rs.flatMap(r => ls.map(l => l :+ r))
94
126
95
127
// Old stuff
96
128
@@ -219,110 +251,3 @@ def findId(vt: ValueType)(using ctx: MonoContext): TypeArg = vt match
219
251
// TODO: After solving the constraints it would be helpful to know
220
252
// which functions have which tparams
221
253
// so we can generate the required monomorphic functions
222
-
223
- // enum PolyType {
224
- // case Base(val tpe: Id)
225
- // case Var(val sym: Id)
226
-
227
- // def toSymbol: Id = this match {
228
- // case Base(tpe) => tpe
229
- // case Var(sym) => sym
230
- // }
231
-
232
- // def toValueType: ValueType = this match {
233
- // case Base(tpe) => ValueType.Data(tpe, List.empty)
234
- // case Var(sym) => ValueType.Var(sym)
235
- // }
236
- // }
237
-
238
- // def solveConstraints(constraints: PolyConstraints): PolyConstraintsSolved =
239
- // var solved: PolyConstraintsSolved = Map()
240
-
241
- // def solveConstraint(sym: Id, types: Set[PolyType]): Set[PolyType.Base] =
242
- // var polyTypes: Set[PolyType.Base] = Set()
243
- // types.foreach {
244
- // case PolyType.Var(symbol) => polyTypes ++= solved.getOrElse(symbol, solveConstraint(symbol, constraints.getOrElse(symbol, Set())))
245
- // case PolyType.Base(tpe) => polyTypes += PolyType.Base(tpe)
246
- // }
247
- // solved += (sym -> polyTypes)
248
- // polyTypes
249
-
250
- // constraints.foreach(solveConstraint)
251
-
252
- // solved
253
-
254
- // def combineConstraints(a: PolyConstraints, b: PolyConstraints): PolyConstraints = {
255
- // a ++ b.map { case (k, v) => k -> (v ++ a.getOrElse(k, Iterable.empty)) }
256
- // }
257
-
258
- // def findConstraints(definitions: List[Toplevel]): PolyConstraints =
259
- // definitions.map(findConstraints).reduce(combineConstraints)
260
-
261
- // def findConstraints(toplevel: Toplevel): PolyConstraints = toplevel match {
262
- // case Toplevel.Def(id, block) => findConstraints(block, List.empty)
263
- // case Toplevel.Val(id, tpe, binding) => ???
264
- // }
265
-
266
- // def findConstraints(block: Block, targs: List[ValueType]): PolyConstraints = block match {
267
- // case BlockLit(tparam :: tparams, cparams, vparams, bparams, body) => findConstraints(body, tparam :: tparams)
268
- // case BlockLit(List(), cparams, vparams, bparams, body) => findConstraints(body, List.empty)
269
- // case BlockVar(id, annotatedTpe, annotatedCapt) => findConstraints(annotatedTpe, targs)
270
- // case New(impl) => ???
271
- // case Unbox(pure) => ???
272
- // case _ => Map.empty
273
- // }
274
-
275
- // def findConstraints(stmt: Stmt, tparams: List[Id]): PolyConstraints = stmt match {
276
- // case App(callee, targs, vargs, bargs) => findConstraints(callee, targs)
277
- // case Return(expr) if !tparams.isEmpty => Map(tparams.head -> Set(findPolyType(expr.tpe)))
278
- // case Return(expr) => Map.empty
279
- // case Val(id, annotatedTpe, binding, body) => combineConstraints(findConstraints(binding, tparams), findConstraints(body, tparams))
280
- // // TODO: Let & If case is wrong, but placeholders are required as they are used in print
281
- // case Let(id, annotatedTpe, binding, body) => Map.empty
282
- // case If(cond, thn, els) => Map.empty
283
- // case o => println(o); ???
284
- // }
285
-
286
- // def findConstraints(value: Val): PolyConstraints = value match {
287
- // // TODO: List.empty might be wrong
288
- // case Val(id, annotatedTpe, binding, body) => combineConstraints(findConstraints(binding, List.empty), findConstraints(body, List.empty))
289
- // }
290
-
291
- // def findConstraints(blockType: BlockType, targs: List[ValueType]): PolyConstraints = blockType match {
292
- // case BlockType.Function(tparams, cparams, vparams, bparams, result) => tparams.zip(targs).map((id, tpe) => (id -> Set(findPolyType(tpe)))).toMap
293
- // case BlockType.Interface(name, targs) => ???
294
- // }
295
-
296
- // def findPolyType(blockType: BlockType, targs: List[ValueType]): List[PolyType] = blockType match {
297
- // case BlockType.Function(tparams, cparams, vparams, bparams, result) => ???
298
- // case BlockType.Interface(name, targs) => ???
299
- // }
300
-
301
- // def findPolyType(valueType: ValueType): PolyType = valueType match {
302
- // case ValueType.Boxed(tpe, capt) => ???
303
- // case ValueType.Data(name, targs) => PolyType.Base(name)
304
- // case ValueType.Var(name) => PolyType.Var(name)
305
- // }
306
-
307
- // def hasCycle(constraints: PolyConstraints): Boolean =
308
- // var visited: Set[Id] = Set()
309
- // var recStack: Set[Id] = Set()
310
-
311
- // def hasCycleHelper(vertex: Id): Boolean =
312
- // if (recStack.contains(vertex)) return true
313
- // if (visited.contains(vertex)) return false
314
-
315
- // visited += vertex
316
- // recStack += vertex
317
-
318
- // var cycleFound = false
319
- // constraints.getOrElse(vertex, Set()).foreach(v => cycleFound |= hasCycleHelper(v.toSymbol))
320
-
321
- // recStack -= vertex
322
-
323
- // cycleFound
324
-
325
- // var cycleFound = false
326
- // constraints.keys.foreach(v => cycleFound |= !visited.contains(v) && hasCycleHelper(v))
327
-
328
- // cycleFound
0 commit comments