diff --git a/effekt/jvm/src/test/scala/effekt/core/MonoTests.scala b/effekt/jvm/src/test/scala/effekt/core/MonoTests.scala new file mode 100644 index 000000000..10383090a --- /dev/null +++ b/effekt/jvm/src/test/scala/effekt/core/MonoTests.scala @@ -0,0 +1,112 @@ +package effekt +package core + + +abstract class AbstractMonoTests extends CorePhaseTests(Mono) { + import TypeArg.* + + implicit def stringBaseT(name: String): Base = Base(Id(name), List()) + + val BaseTInt: Base = "Int" + val BaseTString: Base = "String" + val BaseTChar: Base = "Char" + val BaseTBool: Base = "Bool" + val BaseTDouble: Base = "Double" + + val fnId: Map[String, FunctionId] = Map( + "a" -> Id("a"), + "b" -> Id("b"), + "c" -> Id("c"), + "d" -> Id("d"), + "e" -> Id("e"), + "f" -> Id("f"), + ) +} + +class MonoTests extends AbstractMonoTests { + + import TypeArg.* + + test("simple polymorphic function") { + val constraints = List( + Constraint(Vector(BaseTInt), fnId("a")), + Constraint(Vector(BaseTString), fnId("a")) + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set(Vector(BaseTInt), Vector(BaseTString)) + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } + + test("calling other polymorphic function") { + val constraints = List( + Constraint(Vector(Var(fnId("b"), 0)), fnId("a")), + Constraint(Vector(BaseTInt), fnId("a")), + Constraint(Vector(BaseTString), fnId("b")), + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set(Vector(BaseTInt), Vector(BaseTString)), + fnId("b") -> Set(Vector(BaseTString)), + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } + + test("polymorphic function with multiple type args") { + val constraints = List( + Constraint(Vector(BaseTInt, BaseTString), fnId("a")), + Constraint(Vector(BaseTBool, BaseTChar), fnId("a")), + Constraint(Vector(BaseTBool, BaseTString), fnId("a")), + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set( + Vector(BaseTInt, BaseTString), + Vector(BaseTBool, BaseTChar), + Vector(BaseTBool, BaseTString), + ) + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } + + test("calling other polymorphic function with type args swapped") { + val constraints = List( + Constraint(Vector(Var(fnId("b"), 1), Var(fnId("b"), 0)), fnId("a")), + Constraint(Vector(BaseTString, BaseTBool), fnId("b")), + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set(Vector(BaseTBool, BaseTString)), + fnId("b") -> Set(Vector(BaseTString, BaseTBool)), + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } + + test("recursive polymorphic function") { + val constraints = List( + Constraint(Vector(Var(fnId("a"), 0)), fnId("a")), + Constraint(Vector(BaseTInt), fnId("a")), + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set(Vector(BaseTInt)), + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } + + test("mutually recursive polymorphic functions") { + val constraints = List( + Constraint(Vector(Var(fnId("b"), 0)), fnId("a")), + Constraint(Vector(Var(fnId("a"), 0)), fnId("b")), + Constraint(Vector(BaseTInt), fnId("a")), + Constraint(Vector(BaseTString), fnId("b")), + ) + val expectedSolved: Solution = Map( + fnId("a") -> Set(Vector(BaseTInt), Vector(BaseTString)), + fnId("b") -> Set(Vector(BaseTInt), Vector(BaseTString)), + ) + + assertEquals(solveConstraints(constraints), expectedSolved) + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/DeadCodeElimination.scala b/effekt/shared/src/main/scala/effekt/core/DeadCodeElimination.scala new file mode 100644 index 000000000..466e683c7 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/DeadCodeElimination.scala @@ -0,0 +1,20 @@ +package effekt.core + +import effekt.PhaseResult.CoreTransformed +import effekt.context.Context +import effekt.core.optimizer.Deadcode +import effekt.Phase + +object DeadCodeElimination extends Phase[CoreTransformed, CoreTransformed] { + val phaseName: String = "deadcode-elimination" + + def run(input: CoreTransformed)(using Context): Option[CoreTransformed] = + input match { + case CoreTransformed(source, tree, mod, core) => + val term = Context.ensureMainExists(mod) + val dce = Context.timed("deadcode-elimination", source.name) { + Deadcode.remove(term, core) + } + Some(CoreTransformed(source, tree, mod, dce)) + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/Mono.scala b/effekt/shared/src/main/scala/effekt/core/Mono.scala new file mode 100644 index 000000000..e4684aefd --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/Mono.scala @@ -0,0 +1,409 @@ +package effekt +package core + +import effekt.context.Context +import effekt.lexer.TokenKind +import effekt.context.assertions.asDataType + +object Mono extends Phase[CoreTransformed, CoreTransformed] { + + override val phaseName: String = "mono" + + override def run(input: CoreTransformed)(using Context): Option[CoreTransformed] = { + input match { + case CoreTransformed(source, tree, mod, core) => { + core match { + case ModuleDecl(path, includes, declarations, externs, definitions, exports) => { + // Find constraints in the definitions + val monoFindContext = MonoFindContext() + var constraints = findConstraints(definitions)(using monoFindContext) + constraints = constraints ++ declarations.flatMap(findConstraints(_)(using monoFindContext)) + // println("Constraints") + // constraints.foreach(c => println(c)) + // println() + + // Solve collected constraints + val solution = solveConstraints(constraints) + // println("Solved") + // solution.foreach(println) + // println() + + // Monomorphize existing definitions + var monoNames: MonoNames = Map.empty + solution.foreach((funId, targs) => + targs.foreach(vb => + monoNames += ((funId, vb) -> freshMonoName(funId, vb)) + ) + ) + + var monoContext = MonoContext(solution, monoNames) + val monoDecls = declarations flatMap (monomorphize(_)(using monoContext)) + val monoDefs = monomorphize(definitions)(using monoContext) + // monoDecls.foreach(decl => println(util.show(decl))) + // println() + // monoDefs.foreach(defn => println(util.show(defn))) + val newModuleDecl = ModuleDecl(path, includes, monoDecls, externs, monoDefs, exports) + return Some(CoreTransformed(source, tree, mod, newModuleDecl)) + } + } + } + } + Some(input) + } +} + +type FunctionId = Id +case class Constraint(lower: Vector[TypeArg], upper: FunctionId) +type Constraints = List[Constraint] + +type Solution = Map[FunctionId, Set[Vector[TypeArg.Base]]] +type MonoNames = Map[(FunctionId, Vector[TypeArg.Base]), FunctionId] + +enum TypeArg { + case Base(val tpe: Id, targs: List[TypeArg]) + case Var(funId: FunctionId, pos: Int) +} + +// Type Id -> Var +type TypeParams = Map[Id, TypeArg.Var] + +class MonoFindContext { + var typingContext: TypeParams = Map() + + def extendTypingContext(tparam: Id, index: Int, functionId: FunctionId) = + typingContext += (tparam -> TypeArg.Var(functionId, index)) +} + +case class MonoContext(solution: Solution, names: MonoNames) { + var replacementTparams: Map[Id, TypeArg.Base] = Map.empty +} + +def findConstraints(definitions: List[Toplevel])(using MonoFindContext): Constraints = + definitions flatMap findConstraints + +def findConstraints(definition: Toplevel)(using ctx: MonoFindContext): Constraints = definition match + case Toplevel.Def(id, BlockLit(tparams, cparams, vparams, bparams, body)) => + tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, id)) + findConstraints(body) + case Toplevel.Def(id, block) => ??? + case Toplevel.Val(id, tpe, binding) => ??? + +def findConstraints(declaration: Declaration)(using ctx: MonoFindContext): Constraints = declaration match + // Maybe[T] { Just[](x: T) } + case Data(id, tparams, constructors) => + tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, id)) + constructors.map{ constr => + val arity = tparams.size // + constr.tparams.size + val constructorArgs = (0 until arity).map(index => + TypeArg.Var(constr.id, index) // Just.0 + ).toVector // < Just.0 > + Constraint(constructorArgs, id) // < Just.0 > <: Maybe + } + case Interface(id, tparams, properties) => + tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, id)) + List.empty + +def findConstraints(block: Block)(using ctx: MonoFindContext): Constraints = block match + case BlockVar(id, annotatedTpe: BlockType.Interface, annotatedCapt) => findConstraints(annotatedTpe) + case BlockVar(id, annotatedTpe: BlockType.Function, annotatedCapt) => findConstraints(annotatedTpe, id) + case BlockLit(tparams, cparams, vparams, bparams, body) => findConstraints(body) + case Unbox(pure) => findConstraints(pure) + case New(impl) => findConstraints(impl) + +def findConstraints(blockType: BlockType.Interface)(using ctx: MonoFindContext): Constraints = blockType match + case BlockType.Interface(name, targs) => + List(Constraint(targs.map(findId).toVector, name)) + +def findConstraints(blockType: BlockType.Function, fnId: Id)(using ctx: MonoFindContext): Constraints = blockType match + case BlockType.Function(tparams, cparams, vparams, bparams, result) => + tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, fnId)) + List() + +def findConstraints(impl: Implementation)(using ctx: MonoFindContext): Constraints = impl match + case Implementation(interface, operations) => + findConstraints(interface) ++ + (operations flatMap findConstraints) + +def findConstraints(operation: Operation)(using ctx: MonoFindContext): Constraints = operation match + case Operation(name, tparams, cparams, vparams, bparams, body) => + tparams.zipWithIndex.foreach(ctx.extendTypingContext(_, _, name)) + findConstraints(body) + +def findConstraints(constructor: Constructor)(using ctx: MonoFindContext): Constraints = constructor match + case Constructor(id, tparams, List()) => List.empty + case Constructor(id, tparams, fields) => + List(Constraint(((fields map (_.tpe)) map findId).toVector, id)) + +def findConstraints(stmt: Stmt)(using ctx: MonoFindContext): Constraints = stmt match + case Let(id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body) + case Return(expr) => findConstraints(expr) + case Val(id, annotatedTpe, binding, body) => findConstraints(binding) ++ findConstraints(body) + case Var(ref, init, capture, body) => findConstraints(body) + case App(callee: BlockVar, targs, vargs, bargs) => + List(Constraint(targs.map(findId).toVector, callee.id)) ++ vargs.flatMap(findConstraints) ++ bargs.flatMap(findConstraints) + // TODO: Very specialized, but otherwise passing an id that matches in monomorphize is hard + // although I'm not certain any other case can even happen + // TODO: part 2, also update the implementation in monomorphize if changing this + case App(Unbox(ValueVar(id, annotatedType)), targs, vargs, bargs) => + List(Constraint(targs.map(findId).toVector, id)) ++ vargs.flatMap(findConstraints) ++ bargs.flatMap(findConstraints) + case Invoke(callee: BlockVar, method, methodTpe, targs, vargs, bargs) => + List(Constraint(targs.map(findId).toVector, callee.id)) ++ vargs.flatMap(findConstraints) ++ bargs.flatMap(findConstraints) + case Reset(body) => findConstraints(body) + case If(cond, thn, els) => findConstraints(cond) ++ findConstraints(thn) ++ findConstraints(els) + case Def(id, block, body) => findConstraints(block) ++ findConstraints(body) + case Shift(prompt, body) => findConstraints(prompt) ++ findConstraints(body) + case Match(scrutinee, clauses, default) => clauses.map(_._2).flatMap(findConstraints) ++ findConstraints(default) + case Resume(k, body) => findConstraints(k) ++ findConstraints(body) + case Get(id, annotatedTpe, ref, annotatedCapt, body) => findConstraints(body) + case Put(ref, annotatedCapt, value, body) => findConstraints(value) ++ findConstraints(body) + case Alloc(id, init, region, body) => findConstraints(init) ++ findConstraints(body) + case Region(body) => findConstraints(body) + case Hole(span) => List.empty + case o => println(o); ??? + +def findConstraints(opt: Option[Stmt])(using ctx: MonoFindContext): Constraints = opt match + case None => List.empty + case Some(stmt) => findConstraints(stmt) + +def findConstraints(expr: Expr)(using ctx: MonoFindContext): Constraints = expr match + // TODO: + // Technically targs should still flow + // Just don't monomorphize + case DirectApp(b, targs, vargs, bargs) => List.empty + case PureApp(b, targs, vargs) => List.empty + case ValueVar(id, annotatedType) => List.empty + case Literal(value, annotatedType) => List.empty + case Make(data, tag, targs, vargs) => + List(Constraint(data.targs.map(findId).toVector, tag)) // <: Just + case Box(b, annotatedCapture) => List.empty + +def findId(vt: ValueType)(using ctx: MonoFindContext): TypeArg = vt match + // TODO: What is the correct TypeArg for Boxed + case ValueType.Boxed(tpe, capt) => ??? + case ValueType.Data(name, targs) => TypeArg.Base(name, targs map findId) + case ValueType.Var(name) => ctx.typingContext(name) + +def solveConstraints(constraints: Constraints): Solution = + var solved: Solution = Map() + + val groupedConstraints = constraints.groupBy(c => c.upper) + val vecConstraints = groupedConstraints.map((sym, constraints) => (sym -> constraints.map(c => c.lower))) + + while (true) { + val previousSolved = solved + vecConstraints.foreach((sym, tas) => + val sol = solveConstraints(sym).map(bs => bs.toVector) + solved += (sym -> sol) + ) + if (previousSolved == solved) return solved + } + + def solveConstraints(funId: FunctionId): Set[List[TypeArg.Base]] = + val filteredConstraints = vecConstraints(funId) + var nbs: Set[List[TypeArg.Base]] = Set.empty + filteredConstraints.foreach(b => + var l: List[List[TypeArg.Base]] = List(List.empty) + def listFromIndex(ind: Int) = if (ind >= l.length) List.empty else l(ind) + b.foreach({ + case TypeArg.Base(tpe, targs) => l = productAppend(l, List(TypeArg.Base(tpe, targs))) + case TypeArg.Var(funId, pos) => + val funSolved = solved.getOrElse(funId, Set.empty) + val posArgs = funSolved.map(v => v(pos)) + l = posArgs.zipWithIndex.map((base, ind) => listFromIndex(ind) :+ base).toList + }) + nbs ++= l + ) + nbs + + solved + +def productAppend[A](ls: List[List[A]], rs: List[A]): List[List[A]] = + rs.flatMap(r => ls.map(l => l :+ r)) + +def monomorphize(definitions: List[Toplevel])(using ctx: MonoContext): List[Toplevel] = + var newDefinitions: List[Toplevel] = List.empty + definitions.foreach(definition => newDefinitions ++= monomorphize(definition)) + newDefinitions + +def monomorphize(toplevel: Toplevel)(using ctx: MonoContext): List[Toplevel] = toplevel match + case Toplevel.Def(id, BlockLit(List(), cparams, vparams, bparams, body)) => + List(Toplevel.Def(id, BlockLit(List.empty, cparams, vparams, bparams, monomorphize(body)))) + case Toplevel.Def(id, BlockLit(tparams, cparams, vparams, bparams, body)) => + val monoTypes = ctx.solution(id).toList + monoTypes.map(baseTypes => + val replacementTparams = tparams.zip(baseTypes).toMap + ctx.replacementTparams ++= replacementTparams + Toplevel.Def(ctx.names(id, baseTypes), BlockLit(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(body))) + ) + case Toplevel.Def(id, block) => ??? + case Toplevel.Val(id, tpe, binding) => ??? + +def monomorphize(decl: Declaration)(using ctx: MonoContext): List[Declaration] = decl match + case Data(id, List(), constructors) => List(decl) + case Data(id, tparams, constructors) => + val monoTypes = ctx.solution.getOrElse(id, Set.empty).toList + monoTypes.map(baseTypes => + val replacementTparams = tparams.zip(baseTypes).toMap + ctx.replacementTparams ++= replacementTparams + val newConstructors = constructors map { + case Constructor(id, tparams, fields) => Constructor(id, tparams, fields map monomorphize) + } + Declaration.Data(ctx.names(id, baseTypes), List.empty, newConstructors) + ) + case Interface(id, List(), properties) => List(decl) + case Interface(id, tparams, properties) => + val monoTypes = ctx.solution.getOrElse(id, Set.empty).toList + monoTypes.map(baseTypes => + val replacementTparams = tparams.zip(baseTypes).toMap + ctx.replacementTparams ++= replacementTparams + Declaration.Interface(ctx.names(id, baseTypes), List.empty, properties) + ) + +def monomorphize(block: Block)(using ctx: MonoContext): Block = block match + case b: BlockLit => monomorphize(b) + case b: BlockVar => monomorphize(b) + case New(impl) => New(monomorphize(impl)) + case Unbox(pure) => Unbox(monomorphize(pure)) + +def monomorphize(impl: Implementation)(using ctx: MonoContext): Implementation = impl match + case Implementation(interface, operations) => Implementation(monomorphize(interface), operations.map(monomorphize)) + +def monomorphize(interface: BlockType.Interface)(using ctx: MonoContext): BlockType.Interface = interface match + case BlockType.Interface(name, targs) => + val replacementData = replacementDataFromTargs(name, targs) + BlockType.Interface(replacementData.name, replacementData.targs) + +def monomorphize(operation: Operation)(using ctx: MonoContext): Operation = operation match + case Operation(name, tparams, cparams, vparams, bparams, body) => + Operation(name, List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(body)) + +def monomorphize(block: BlockLit)(using ctx: MonoContext): BlockLit = block match + case BlockLit(tparams, cparams, vparams, bparams, body) => + BlockLit(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(body)) + +def monomorphize(block: BlockVar)(using ctx: MonoContext): BlockVar = block match + case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(id, monomorphize(annotatedTpe), annotatedCapt) + +def monomorphize(field: Field)(using ctx: MonoContext): Field = field match + case Field(id, tpe) => Field(id, monomorphize(tpe)) + +def monomorphize(blockVar: BlockVar, replacementId: FunctionId)(using ctx: MonoContext): BlockVar = blockVar match + case BlockVar(id, BlockType.Function(List(), cparams, vparams, bparams, result), annotatedCapt) => blockVar + // TODO: What is in annotated captures. Does it need to be handled? + case BlockVar(id, BlockType.Function(tparams, cparams, vparams, bparams, result), annotatedCapt) => + val monoAnnotatedTpe = BlockType.Function(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result)) + BlockVar(replacementId, monoAnnotatedTpe, annotatedCapt) + case o => println(o); ??? + +def monomorphize(stmt: Stmt)(using ctx: MonoContext): Stmt = stmt match + case Return(expr) => Return(monomorphize(expr)) + case Val(id, annotatedTpe, binding, body) => + Val(id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body)) + case Var(ref, init, capture, body) => + Var(ref, monomorphize(init), capture, monomorphize(body)) + case App(callee: BlockVar, targs, vargs, bargs) => + val replacementData = replacementDataFromTargs(callee.id, targs) + App(monomorphize(callee, replacementData.name), List.empty, vargs map monomorphize, bargs map monomorphize) + // TODO: Highly specialized, see todo in findConstraints for info + // change at the same time as findConstraints + case App(Unbox(ValueVar(id, annotatedTpe)), targs, vargs, bargs) => + val replacementData = replacementDataFromTargs(id, targs) + App(Unbox(ValueVar(id, monomorphize(annotatedTpe))), List.empty, vargs map monomorphize, bargs map monomorphize) + case Let(id, annotatedTpe, binding, body) => Let(id, monomorphize(annotatedTpe), monomorphize(binding), monomorphize(body)) + case If(cond, thn, els) => If(monomorphize(cond), monomorphize(thn), monomorphize(els)) + case Invoke(Unbox(pure), method, methodTpe, targs, vargs, bargs) => + Invoke(Unbox(monomorphize(pure)), method, methodTpe, List.empty, vargs map monomorphize, bargs map monomorphize) + case Invoke(BlockVar(id, annotatedTpe, annotatedCapt), method, methodTpe, targs, vargs, bargs) => + Invoke(BlockVar(id, monomorphize(annotatedTpe), annotatedCapt), method, methodTpe, List.empty, vargs map monomorphize, bargs map monomorphize) + // TODO: Monomorphizing here throws an error complaining about a missing implementation + // Not sure what is missing, altough it does works like this + case Reset(body) => Reset(body) + case Def(id, block, body) => Def(id, monomorphize(block), monomorphize(body)) + case Shift(prompt, body) => Shift(monomorphize(prompt), monomorphize(body)) + case Match(scrutinee, clauses, default) => + val monoClauses = clauses.map((id, blockLit) => (id, monomorphize(blockLit))) + Match(monomorphize(scrutinee), monoClauses, monomorphize(default)) + case Get(id, annotatedTpe, ref, annotatedCapt, body) => + Get(id, monomorphize(annotatedTpe), ref, annotatedCapt, monomorphize(body)) + case Put(ref, annotatedCapt, value, body) => + Put(ref, annotatedCapt, monomorphize(value), monomorphize(body)) + case Alloc(id, init, region, body) => + Alloc(id, monomorphize(init), region, monomorphize(body)) + case Region(body) => Region(monomorphize(body)) + case Hole(span) => Hole(span) + case o => println(o); ??? + +def monomorphize(opt: Option[Stmt])(using ctx: MonoContext): Option[Stmt] = opt match + case None => None + case Some(stmt) => Some(monomorphize(stmt)) + +def monomorphize(expr: Expr)(using ctx: MonoContext): Expr = expr match + case DirectApp(b, targs, vargs, bargs) => + val replacementData = replacementDataFromTargs(b.id, targs) + DirectApp(monomorphize(b, replacementData.name), List.empty, vargs map monomorphize, bargs map monomorphize) + case Literal(value, annotatedType) => + Literal(value, monomorphize(annotatedType)) + case PureApp(b, targs, vargs) => + val replacementData = replacementDataFromTargs(b.id, targs) + PureApp(monomorphize(b, replacementData.name), List.empty, vargs map monomorphize) + case Make(data, tag, targs, vargs) => + Make(replacementDataFromTargs(data.name, data.targs), tag, List.empty, vargs map monomorphize) + case Box(b, annotatedCapture) => + // TODO: Does this need other handling? + Box(b, annotatedCapture) + case o => println(o); ??? + +def monomorphize(pure: Pure)(using ctx: MonoContext): Pure = pure match + case ValueVar(id, annotatedType) => ValueVar(id, monomorphize(annotatedType)) + case PureApp(b, targs, vargs) => + val replacementData = replacementDataFromTargs(b.id, targs) + PureApp(monomorphize(b, replacementData.name), List.empty, vargs map monomorphize) + case Literal(value, annotatedType) => Literal(value, monomorphize(annotatedType)) + case Make(data, tag, targs, vargs) => + val replacementData = replacementDataFromTargs(data.name, data.targs) + Make(replacementData, tag, List.empty, vargs map monomorphize) + case o => println(o); ??? + +def monomorphize(valueParam: ValueParam)(using ctx: MonoContext): ValueParam = valueParam match + case ValueParam(id, tpe) => ValueParam(id, monomorphize(tpe)) + +def monomorphize(blockParam: BlockParam)(using ctx: MonoContext): BlockParam = blockParam match + // TODO: Same question as in block + case BlockParam(id, tpe, capt) => BlockParam(id, monomorphize(tpe), capt) + +def monomorphize(blockType: BlockType)(using ctx: MonoContext): BlockType = blockType match + case BlockType.Function(tparams, cparams, vparams, bparams, result) => + BlockType.Function(List.empty, cparams, vparams map monomorphize, bparams map monomorphize, monomorphize(result)) + case b: BlockType.Interface => monomorphize(b) + +def monomorphize(valueType: ValueType)(using ctx: MonoContext): ValueType = valueType match + case ValueType.Var(name) => ValueType.Var(ctx.replacementTparams(name).tpe) + case ValueType.Data(name, targs) => replacementDataFromTargs(name, targs) + case o => println(o); ??? + +var monoCounter = 0 +def freshMonoName(baseId: Id, tpe: TypeArg.Base): Id = + monoCounter += 1 + Id(baseId.name.name + tpe.tpe.name.name + monoCounter) + +def freshMonoName(baseId: Id, tpes: Vector[TypeArg.Base]): Id = + if (tpes.length == 0) return baseId + + monoCounter += 1 + val tpesString = tpes.map(tpe => tpe.tpe.name.name).mkString + Id(baseId.name.name + tpesString + monoCounter) + +def replacementDataFromTargs(id: FunctionId, targs: List[ValueType])(using ctx: MonoContext): ValueType.Data = + if (targs.isEmpty) return ValueType.Data(id, targs) + // TODO: Incredibly hacky, resume did not seem to appear when finding constraints + // it does show up while monomorphizing which caused an error + // this seems to work for now + if (id.name.name == "Resume") return ValueType.Data(id, targs) + + def toTypeArg(vt: ValueType): TypeArg.Base = vt match + case ValueType.Data(name, targs) => TypeArg.Base(name, targs map toTypeArg) + case ValueType.Var(name) => ctx.replacementTparams(name) + case ValueType.Boxed(tpe, capt) => ??? + + val baseTypes: List[TypeArg.Base] = targs map toTypeArg + ValueType.Data(ctx.names((id, baseTypes.toVector)), List.empty) diff --git a/effekt/shared/src/main/scala/effekt/core/Transformer.scala b/effekt/shared/src/main/scala/effekt/core/Transformer.scala index 6a4d55c02..f116f32ba 100644 --- a/effekt/shared/src/main/scala/effekt/core/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/core/Transformer.scala @@ -143,6 +143,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { } }.toList ++ exports.namespaces.values.flatMap(transform) + // Add tparams separately def transform(c: symbols.Constructor)(using Context): core.Constructor = core.Constructor(c, c.tparams, c.fields.map(f => core.Field(f, transform(f.returnType)))) diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala index f494eb96c..b6d8dcc70 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala @@ -23,10 +23,10 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] { var tree = core - // (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites) - tree = Context.timed("deadcode-elimination", source.name) { - Deadcode.remove(mainSymbol, tree) - } + // (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites) + // tree = Context.timed("deadcode-elimination", source.name) { + // Deadcode.remove(mainSymbol, tree) + // } if !Context.config.optimize() then return tree; diff --git a/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala b/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala index 14ecb643b..40da90d5b 100644 --- a/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala +++ b/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala @@ -7,6 +7,8 @@ import effekt.core.optimizer import effekt.machine import kiama.output.PrettyPrinterTypes.{ Document, emptyLinks } import kiama.util.Source +import effekt.core.DeadCodeElimination +import effekt.generator.chez.DeadCodeElimination class LLVM extends Compiler[String] { @@ -52,7 +54,7 @@ class LLVM extends Compiler[String] { // ----------------------------------- object steps { // intermediate steps for VSCode - val afterCore = allToCore(Core) andThen Aggregate andThen optimizer.Optimizer andThen core.PolymorphismBoxing + val afterCore = allToCore(Core) andThen Aggregate andThen core.DeadCodeElimination andThen core.Mono andThen optimizer.Optimizer val afterMachine = afterCore andThen Machine map { case (mod, main, prog) => prog } val afterLLVM = afterMachine map { case machine.Program(decls, defns, entry) =>