Skip to content

Commit e18b705

Browse files
committed
Implement solveConstraints
1 parent c80695f commit e18b705

File tree

1 file changed

+37
-112
lines changed
  • effekt/shared/src/main/scala/effekt/core

1 file changed

+37
-112
lines changed

effekt/shared/src/main/scala/effekt/core/Mono.scala

Lines changed: 37 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ object Mono extends Phase[CoreTransformed, CoreTransformed] {
2020
constraints.foreach(c => println(c))
2121
println()
2222

23-
// val solved = solveConstraint(constraints)
24-
// println("Solved")
25-
// solved.foreach(println)
23+
val solved = solveConstraints(constraints)
24+
println("Solved")
25+
solved.foreach(println)
2626
println()
2727

2828

@@ -38,8 +38,7 @@ type FunctionId = Id
3838
case class Constraint(lower: Vector[TypeArg], upper: FunctionId)
3939
type Constraints = List[Constraint]
4040

41-
// case class SolvedConstraint(lower: Vector[TypeArg.Base], upper: FunctionId | TypeArg.Var)
42-
// type SolvedConstraints = List[SolvedConstraint]
41+
type Solution = Map[FunctionId, Set[Vector[TypeArg.Base]]]
4342

4443
enum TypeArg {
4544
case Base(val tpe: Id)
@@ -91,6 +90,39 @@ def findId(vt: ValueType)(using ctx: MonoContext): TypeArg = vt match
9190
case ValueType.Data(name, targs) => TypeArg.Base(name)
9291
case ValueType.Var(name) => ctx.typingContext(name)
9392

93+
def solveConstraints(constraints: Constraints): Solution =
94+
var solved: Solution = Map()
95+
96+
val groupedConstraints = constraints.groupBy(c => c.upper)
97+
val vecConstraints = groupedConstraints.map((sym, constraints) => (sym -> constraints.map(c => c.lower)))
98+
99+
vecConstraints.foreach((sym, tas) =>
100+
val sol = solveConstraints(sym).map(bs => bs.toVector)
101+
solved += (sym -> sol)
102+
)
103+
104+
def solveConstraints(funId: FunctionId): Set[List[TypeArg.Base]] =
105+
val filteredConstraints = vecConstraints(funId)
106+
var nbs: Set[List[TypeArg.Base]] = Set.empty
107+
filteredConstraints.foreach(b =>
108+
var l: List[List[TypeArg.Base]] = List(List.empty)
109+
def listFromIndex(ind: Int) = if (ind >= l.length) List.empty else l(ind)
110+
b.foreach({
111+
case TypeArg.Base(tpe) => l = productAppend(l, List(TypeArg.Base(tpe)))
112+
case TypeArg.Var(funId, pos) =>
113+
val funSolved = solved.getOrElse(funId, solveConstraints(funId))
114+
val posArgs = funSolved.map(v => v(pos))
115+
l = posArgs.zipWithIndex.map((base, ind) => listFromIndex(ind) :+ base).toList
116+
println(l)
117+
})
118+
nbs ++= l
119+
)
120+
nbs
121+
122+
solved
123+
124+
def productAppend[A](ls: List[List[A]], rs: List[A]): List[List[A]] =
125+
rs.flatMap(r => ls.map(l => l :+ r))
94126

95127
// Old stuff
96128

@@ -219,110 +251,3 @@ def findId(vt: ValueType)(using ctx: MonoContext): TypeArg = vt match
219251
// TODO: After solving the constraints it would be helpful to know
220252
// which functions have which tparams
221253
// so we can generate the required monomorphic functions
222-
223-
// enum PolyType {
224-
// case Base(val tpe: Id)
225-
// case Var(val sym: Id)
226-
227-
// def toSymbol: Id = this match {
228-
// case Base(tpe) => tpe
229-
// case Var(sym) => sym
230-
// }
231-
232-
// def toValueType: ValueType = this match {
233-
// case Base(tpe) => ValueType.Data(tpe, List.empty)
234-
// case Var(sym) => ValueType.Var(sym)
235-
// }
236-
// }
237-
238-
// def solveConstraints(constraints: PolyConstraints): PolyConstraintsSolved =
239-
// var solved: PolyConstraintsSolved = Map()
240-
241-
// def solveConstraint(sym: Id, types: Set[PolyType]): Set[PolyType.Base] =
242-
// var polyTypes: Set[PolyType.Base] = Set()
243-
// types.foreach {
244-
// case PolyType.Var(symbol) => polyTypes ++= solved.getOrElse(symbol, solveConstraint(symbol, constraints.getOrElse(symbol, Set())))
245-
// case PolyType.Base(tpe) => polyTypes += PolyType.Base(tpe)
246-
// }
247-
// solved += (sym -> polyTypes)
248-
// polyTypes
249-
250-
// constraints.foreach(solveConstraint)
251-
252-
// solved
253-
254-
// def combineConstraints(a: PolyConstraints, b: PolyConstraints): PolyConstraints = {
255-
// a ++ b.map { case (k, v) => k -> (v ++ a.getOrElse(k, Iterable.empty)) }
256-
// }
257-
258-
// def findConstraints(definitions: List[Toplevel]): PolyConstraints =
259-
// definitions.map(findConstraints).reduce(combineConstraints)
260-
261-
// def findConstraints(toplevel: Toplevel): PolyConstraints = toplevel match {
262-
// case Toplevel.Def(id, block) => findConstraints(block, List.empty)
263-
// case Toplevel.Val(id, tpe, binding) => ???
264-
// }
265-
266-
// def findConstraints(block: Block, targs: List[ValueType]): PolyConstraints = block match {
267-
// case BlockLit(tparam :: tparams, cparams, vparams, bparams, body) => findConstraints(body, tparam :: tparams)
268-
// case BlockLit(List(), cparams, vparams, bparams, body) => findConstraints(body, List.empty)
269-
// case BlockVar(id, annotatedTpe, annotatedCapt) => findConstraints(annotatedTpe, targs)
270-
// case New(impl) => ???
271-
// case Unbox(pure) => ???
272-
// case _ => Map.empty
273-
// }
274-
275-
// def findConstraints(stmt: Stmt, tparams: List[Id]): PolyConstraints = stmt match {
276-
// case App(callee, targs, vargs, bargs) => findConstraints(callee, targs)
277-
// case Return(expr) if !tparams.isEmpty => Map(tparams.head -> Set(findPolyType(expr.tpe)))
278-
// case Return(expr) => Map.empty
279-
// case Val(id, annotatedTpe, binding, body) => combineConstraints(findConstraints(binding, tparams), findConstraints(body, tparams))
280-
// // TODO: Let & If case is wrong, but placeholders are required as they are used in print
281-
// case Let(id, annotatedTpe, binding, body) => Map.empty
282-
// case If(cond, thn, els) => Map.empty
283-
// case o => println(o); ???
284-
// }
285-
286-
// def findConstraints(value: Val): PolyConstraints = value match {
287-
// // TODO: List.empty might be wrong
288-
// case Val(id, annotatedTpe, binding, body) => combineConstraints(findConstraints(binding, List.empty), findConstraints(body, List.empty))
289-
// }
290-
291-
// def findConstraints(blockType: BlockType, targs: List[ValueType]): PolyConstraints = blockType match {
292-
// case BlockType.Function(tparams, cparams, vparams, bparams, result) => tparams.zip(targs).map((id, tpe) => (id -> Set(findPolyType(tpe)))).toMap
293-
// case BlockType.Interface(name, targs) => ???
294-
// }
295-
296-
// def findPolyType(blockType: BlockType, targs: List[ValueType]): List[PolyType] = blockType match {
297-
// case BlockType.Function(tparams, cparams, vparams, bparams, result) => ???
298-
// case BlockType.Interface(name, targs) => ???
299-
// }
300-
301-
// def findPolyType(valueType: ValueType): PolyType = valueType match {
302-
// case ValueType.Boxed(tpe, capt) => ???
303-
// case ValueType.Data(name, targs) => PolyType.Base(name)
304-
// case ValueType.Var(name) => PolyType.Var(name)
305-
// }
306-
307-
// def hasCycle(constraints: PolyConstraints): Boolean =
308-
// var visited: Set[Id] = Set()
309-
// var recStack: Set[Id] = Set()
310-
311-
// def hasCycleHelper(vertex: Id): Boolean =
312-
// if (recStack.contains(vertex)) return true
313-
// if (visited.contains(vertex)) return false
314-
315-
// visited += vertex
316-
// recStack += vertex
317-
318-
// var cycleFound = false
319-
// constraints.getOrElse(vertex, Set()).foreach(v => cycleFound |= hasCycleHelper(v.toSymbol))
320-
321-
// recStack -= vertex
322-
323-
// cycleFound
324-
325-
// var cycleFound = false
326-
// constraints.keys.foreach(v => cycleFound |= !visited.contains(v) && hasCycleHelper(v))
327-
328-
// cycleFound

0 commit comments

Comments
 (0)