diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala index b65d7427d6..5b6c699d9b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala @@ -4,12 +4,300 @@ package codegen import utils.* import hkmc2.Message.MessageContext -class Instrumentation(using Raise) extends BlockTransformer(new SymbolSubst()): - def transform(prgm: Program) = Program(prgm.imports, applyBlock(prgm.main)) - - override def applyDefn(d: Defn)(k: Defn => Block): Block = d match - case defn: ClsLikeDefn => - if defn.isym.defn.exists(_.hasStagedModifier.isDefined) && defn.companion.isDefined - then raise(WarningReport(msg"`staged` keyword doesn't do anything currently." -> defn.sym.toLoc :: Nil)) - super.applyDefn(defn)(k) - case b => super.applyDefn(b)(k) +import scala.collection.mutable.HashMap +import scala.util.chaining._ + +import mlscript.utils.*, shorthands.* + +import semantics.* +import semantics.Elaborator.State + +import syntax.{Literal, Tree} + +// it should be possible to cache some common constructions (End, Option) into the context +// this avoids having to rebuild the same shapes everytime they are needed + +// transform Block to Block IR so that it can be instrumented in mlscript +class InstrumentationImpl(using State): + type ArgWrappable = Path | Symbol + type Context = HashMap[Path, Path] + + def asArg(x: ArgWrappable): Arg = + x match + case p: Path => p.asArg + case l: Symbol => l.asPath.asArg + + // null and undefined are missing + def toValue(lit: Str | Int | BigDecimal | Bool): Value = + val l = lit match + case i: Int => Tree.IntLit(i) + case b: Bool => Tree.BoolLit(b) + case s: Str => Tree.StrLit(s) + case n: BigDecimal => Tree.DecLit(n) + Value.Lit(l) + + extension [A, B](ls: Ls[(A => B) => B]) + def collectApply(f: Ls[A] => B): B = + // defer applying k while prepending new elements to the list + ls.foldRight((_: Ls[A] => B)(Nil))((headCont, tailCont) => + k => + headCont: head => + tailCont: tail => + k(head :: tail) + )(f) + + // helpers for constructing Block + + def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = + // TODO: skip assignment if res: Path? + val sym = new TempSymbol(N, symName) + Assign(sym, res, k(sym.asPath)) + + def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + assign(Tuple(false, elems.map(asArg)), symName)(k) + + def ctor(cls: Path, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + assign(Instantiate(false, cls, args.map(asArg)), symName)(k) + + // isMlsFun is probably always true? + def call(fun: Path, args: Ls[ArgWrappable], isMlsFun: Bool = true, symName: Str = "tmp")(k: Path => Block): Block = + assign(Call(fun, args.map(asArg))(isMlsFun, false, false), symName)(k) + + // helpers for instrumenting Block + + def blockMod(name: Str) = summon[State].blockSymbol.asPath.selSN(name) + def optionMod(name: Str) = summon[State].optionSymbol.asPath.selSN(name) + + def blockCtor(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + ctor(blockMod(name), args, symName)(k) + def optionSome(arg: ArgWrappable, symName: Str = "tmp")(k: Path => Block): Block = + ctor(optionMod("Some"), Ls(arg), symName)(k) + def optionNone(symName: Str = "tmp")(k: Path => Block): Block = + assign(optionMod("None"), symName)(k) + + def blockCall(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + call(blockMod(name), args, symName = symName)(k) + + // linking functions defined in MLscipt + + def fnPrintCode(p: Path)(k: Block): Block = + // discard result, we only care about side effect + blockCall("printCode", Ls(p))(_ => k) + + def fnConcat(p1: Path, p2: Path, symName: String = "concat")(k: Path => Block): Block = + blockCall("concat", Ls(p1, p2), symName)(k) + + // transformation helpers + + def transformSymbol(sym: Symbol, symName: Str = "sym")(k: Path => Block): Block = + sym match + case clsSym: ClassSymbol => + transformParamsOpt(clsSym.defn.get.paramsOpt): paramsOpt => + blockCtor("ClassSymbol", Ls(toValue(sym.nme), paramsOpt), symName)(k) + case t: TermSymbol if t.defn.exists(_.sym.asCls.isDefined) => + transformSymbol(t.defn.get.sym.asCls.get, symName)(k) + case _ => blockCtor("Symbol", Ls(toValue(sym.nme)), symName)(k) + + def transformOption[A](xOpt: Opt[A], f: A => (Path => Block) => Block)(k: Path => Block): Block = + xOpt match + case S(x) => f(x)(optionSome(_)(k)) + case N => optionNone()(k) + + // instrumentation rules + + def ruleEnd(symName: String = "end")(k: Path => Block): Block = + blockCtor("End", Ls(), symName)(k) + + def ruleBranches(x: Path, p: Path, arms: Ls[Case -> Block], dflt: Opt[Block], symName: String = "branches")(using Context)(k: (Path, Context) => Block): Block = + def applyRuleBranch(cse: Case, block: Block)(f: Path => Context => Block)(ctx: Context): Block = + transformCase(cse): cse => + transformBlock(block)(using ctx.clone() += p -> x): (y, ctx) => + // TODO: use Arm type instead of Tup + tuple(Ls(cse, y), "branch"): cde => + f(cde)(ctx.clone() -= p) + + (arms.map(applyRuleBranch).collectApply(_: Ls[Path] => Context => Block)(summon)): arms => + ctx => + tuple(arms): arms => + ruleEnd(): e => + // TODO: use transformOption here + def dfltStaged(k: (Path, Context) => Block) = + dflt match + case S(dflt) => + transformBlock(dflt)(using ctx.clone() += p -> x): (dflt, ctx) => + optionSome(dflt)(k(_, ctx.clone() -= p)) + case N => optionNone()(k(_, ctx)) + dfltStaged: (dflt, ctx) => + blockCtor("Match", Ls(x, arms, dflt, e), symName)(k(_, ctx)) + + // transformations of Block + + def transformPath(p: Path)(using ctx: Context)(k: Path => Block): Block = + // rulePath + ctx.get(p).map(k).getOrElse: + p match + case Value.Ref(l, disamb) => + transformSymbol(disamb.getOrElse(l)): sym => + blockCtor("ValueRef", Ls(sym), "var")(k) + case l: Value.Lit => + blockCtor("ValueLit", Ls(l), "lit")(k) + case s @ Select(p, i @ Tree.Ident(name)) => + transformPath(p): x => + val sym = + if s.symbol.isDefined + then transformSymbol(s.symbol.get) + else blockCtor("Symbol", Ls(toValue(name))) + sym: sym => + blockCtor("Select", Ls(x, sym), "sel")(k) + case DynSelect(qual, fld, arrayIdx) => + transformPath(qual): x => + transformPath(fld): y => + blockCtor("DynSelect", Ls(x, y, toValue(arrayIdx)), "dynsel")(k) + case _ => ??? // not supported + + def transformResult(r: Result)(using Context)(k: Path => Block): Block = + r match + case p: Path => transformPath(p)(k) + case Tuple(mut, elems) => + assert(!mut, "mutable tuple not supported") + transformArgs(elems): xs => + tuple(xs.map(_._1)): codes => + blockCtor("Tuple", Ls(codes), "tup")(k) + case Instantiate(mut, cls, args) => + assert(!mut, "mutable instantiation not supported") + transformArgs(args): xs => + transformPath(cls): cls => + tuple(xs.map(_._1)): codes => + blockCtor("Instantiate", Ls(cls, codes), "inst")(k) + case Call(fun, args) => + transformPath(fun): fun => + transformArgs(args): args => + tuple(args.map(_._1)): tup => + blockCtor("Call", Ls(fun, tup), "app")(k) + case _ => ??? // not supported + + def transformArg(a: Arg)(using Context)(k: ((Path, Bool)) => Block): Block = + val Arg(spread, value) = a + transformOption(spread, bool => assign(toValue(bool))): spreadStaged => + transformPath(value): value => + blockCtor("Arg", Ls(spreadStaged, value)): cde => + k(cde, spread.isDefined) + + def transformArgs(args: Ls[Arg])(using Context)(k: Ls[(Path, Bool)] => Block): Block = + args.map(transformArg).collectApply(k) + + def transformParamList(ps: ParamList)(k: Path => Block) = + ps.params.map(p => transformSymbol(p.sym)).collectApply(tuple(_)(k)) + + def transformParamsOpt(pOpt: Opt[ParamList])(k: Path => Block) = + transformOption(pOpt, transformParamList)(k) + + def transformCase(cse: Case)(using Context)(k: Path => Block): Block = + cse match + case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit)))(k) + case Case.Cls(cls, path) => + transformSymbol(cls): cls => + transformPath(path): path => + blockCtor("Cls", Ls(cls, path))(k) + case Case.Tup(len, inf) => blockCtor("Tup", Ls(len, inf).map(toValue))(k) + case Case.Field(name, safe) => ??? // not supported + + // ruleBlk? + def transformBlock(b: Block)(using Context)(k: Path => Block): Block = + transformBlock(b)((p, _) => k(p)) + + def transformBlock(b: Block)(using ctx: Context)(k: (Path, Context) => Block): Block = + b match + case Return(res, implct) => + transformResult(res): x => + blockCtor("Return", Ls(x, toValue(implct)), "return")(k(_, ctx)) + case Assign(x, r, b) => + transformResult(r): y => + transformSymbol(x): xSym => + blockCtor("ValueRef", Ls(xSym)): xStaged => + (Assign(x, xStaged, _)): + given Context = ctx.clone() += x.asPath -> xStaged + transformBlock(b): (z, ctx) => + blockCtor("Assign", Ls(xSym, y, z), "assign")(k(_, ctx)) + case Define(cls: ClsLikeDefn, rest) => + assert(cls.companion.isEmpty, "nested module not supported") + (Define(cls, _)): + transformBlock(rest): p => + transformSymbol(cls.isym): c => + optionNone(): none => // TODO: handle companion object + blockCtor("ClsLikeDefn", Ls(c, none)): cls => + blockCtor("Define", Ls(cls, p)): p => + ruleEnd(): end => + fnPrintCode(p)(k(end, ctx)) + case End(_) => ruleEnd()(k(_, ctx)) + case Match(p, ks, dflt, rest) => + transformPath(p): x => + ruleBranches(x, p, ks, dflt): (stagedMatch, ctx) => + transformBlock(rest)(using ctx): (z, ctx) => + fnConcat(stagedMatch, z, "match")(k(_, ctx)) + case Begin(sub, rest) => + // TODO: This is untested as there is no test case that generates the Begin block yet + transformBlock(sub): (sub, ctx) => + transformBlock(rest)(using ctx): (rest, ctx) => + fnConcat(sub, rest)(k(_, ctx)) + case Scoped(syms, body) => + syms.toList.map(transformSymbol(_)).collectApply: symsStaged => + tuple(symsStaged): tup => + transformBlock(body): (body, ctx) => + blockCtor("Scoped", Ls(tup, body))(k(_, ctx)) + case Label(labelSymbol, loop, body, rest) => + transformSymbol(labelSymbol): labelSymbol => + transformBlock(body): (body, ctx) => + transformBlock(rest)(using ctx): (rest, ctx) => + blockCtor("Label", Ls(labelSymbol, toValue(loop), body, rest))(k(_, ctx)) + case Break(labelSymbol) => + transformSymbol(labelSymbol): labelSymbol => + blockCtor("Break", Ls(labelSymbol))(k(_, ctx)) + case _ => ??? // not supported + + // f.owner returns an InnerSymbol, but we need BlockMemberSymbol of the module to call the function + // so we pass modSym instead + def transformFunDefn(modSym: BlockMemberSymbol, f: FunDefn): (FunDefn, Block) = + val genSym = BlockMemberSymbol(f.sym.nme + "_gen", Nil, true) + val sym = modSym.asPath.selSN(genSym.nme) + // NOTE: this debug printing only works for top-level modules, nested modules don't work + // TODO: remove it. only for test + val debug = blockCtor("ValueLit", Ls(Value.Lit(Tree.UnitLit(false)))): undef => + // TODO: put correct parameters instead of undefined + f.params.map(ps => List.fill(ps.params.length)(undef)) + .foldRight((p: Path) => fnPrintCode(p)(End()))((args, cont) => call(_, args)(cont))(sym) + + val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_gen")) + val args = f.params.flatMap(_.params).map(_.sym) + val newBody = + given Context = HashMap(args.map(s => Value.Ref(s, N) -> Value.Ref(s, N))*) + transformBlock(f.body)(Return(_, false)) + val newFun = f.copy(sym = genSym, dSym = dSym, body = newBody)(false) + (newFun, debug) + +// TODO: rename as InstrumentationTransformer? +class Instrumentation(using State) extends BlockTransformer(new SymbolSubst()): + val impl = new InstrumentationImpl + + def concat(b1: Block, b2: Block): Block = + b1.mapTail { + case _: End => b2 + case _ => ??? + } + + override def applyBlock(b: Block): Block = + super.applyBlock(b) match + // find modules with staged annotation + case Define(c: ClsLikeDefn, rest) if c.companion.exists(_.isym.defn.exists(_.hasStagedModifier.isDefined)) => + val sym = c.sym.subst + val companion = c.companion.get + val (stagedMethods, debugPrintCode) = companion.methods + .map(impl.transformFunDefn(sym, _)) + .unzip + val newCtor = impl.transformBlock(companion.ctor)(using new HashMap())(_ => End()) + val newCompanion = companion.copy(methods = companion.methods ++ stagedMethods, ctor = newCtor) + val newModule = c.copy(sym = sym, companion = S(newCompanion)) + // debug is printed after definition + val debugBlock = debugPrintCode.foldRight(rest)(concat) + Define(newModule, debugBlock) + case b => b diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index eec24fd8a6..b8849527e4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -20,7 +20,7 @@ import semantics.Term.{Throw => _, *} import semantics.Elaborator.{State, Ctx, ctx} import syntax.{Literal, Tree} -import hkmc2.syntax.Fun +import hkmc2.syntax.{Fun, Keyword} abstract class TailOp extends (Result => Block) @@ -264,14 +264,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx): mod.classCompanion match case S(comp) => comp.defn.getOrElse(wat("Module companion without definition", mod.companion)) case N => - ClassDef.Plain(mod.owner, syntax.Cls, new ClassSymbol(Tree.DummyTypeDef(syntax.Cls), mod.sym.id), + val clsSymb = new ClassSymbol(Tree.DummyTypeDef(syntax.Cls), mod.sym.id) + val stagedAnnots = mod.annotations.collect { + case Annot.Modifier(Keyword.`staged`) => Annot.Modifier(Keyword.`staged`) + } + val newDefn = ClassDef.Plain(mod.owner, syntax.Cls, clsSymb, mod.bsym, Nil, N, ObjBody(Blk(Nil, UnitVal())), S(mod.sym), - Nil, + stagedAnnots ) + clsSymb.defn = S(newDefn) + newDefn case _ => _defn reportAnnotations(defn, defn.extraAnnotations) val bufferableAnnots = defn.annotations.flatMap: @@ -1058,7 +1064,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val merged = MergeMatchArmTransformer.applyBlock(bufferable) val staged = - if config.stageCode then Instrumentation(using summon).applyBlock(merged) + if config.stageCode then Instrumentation().applyBlock(merged) else merged val res = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index d36932b598..4cd069ff8c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -101,8 +101,9 @@ object Printer: case Select(qual, name) => val docQual = mkDocument(qual) doc"${docQual}.${name.name}" + case DynSelect(qual, fld, ai) => + doc"${mkDocument(qual)}.(${mkDocument(fld)})" case x: Value => mkDocument(x) - case _ => TODO(path) def mkDocument(result: Result)(using Raise, Scope): Document = result match case Call(fun, args) => doc"${mkDocument(fun)}(${args.map(mkDocument).mkDocument(", ")})" diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 8711617ac1..fd709741b1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -257,7 +257,7 @@ object Elaborator: val prettyPrintSymbol = TempSymbol(N, "prettyPrint") val termSymbol = TempSymbol(N, "Term") val blockSymbol = TempSymbol(N, "Block") - val shapeSymbol = TempSymbol(N, "Shape") + val optionSymbol = TempSymbol(N, "option") val wasmSymbol = TempSymbol(N, "wasm") val effectSigSymbol = ClassSymbol(DummyTypeDef(syntax.Cls), Ident("EffectSig")) val nonLocalRetHandlerTrm = diff --git a/hkmc2/shared/src/test/mlscript-compile/Block.mls b/hkmc2/shared/src/test/mlscript-compile/Block.mls index e69de29bb2..360f4723c9 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Block.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Block.mls @@ -0,0 +1,189 @@ +import "./Predef.mls" +import "./Option.mls" +import "./StrOps.mls" + +open Predef +open StrOps +open Option + +module Block with... + +type Opt[A] = Option[A] + +// dependancies referenced in Block classes, referencing implementation in Term.mls + +type Literal = null | undefined | Str | Int | Num | Bool + +type ParamList = Array[Symbol] + +class Symbol(val name: Str) +type Ident = Symbol +// this is so that we're able to retrieve information about the class from the symbol +class ClassSymbol(val name: Str, val paramsOpt: Opt[ParamList]) extends Symbol(name) + +class Arg(val spread: Opt[Bool], val value: Path) + +class Case with + constructor + Lit(val lit: Literal) + Cls(val cls: Symbol, val path: Path) + Tup(val len: Int, val inf: Bool) + +class Result with + constructor + Call(val _fun: Path, val args: Array[Arg]) + Instantiate(val cls: Path, val args: Array[Arg]) // assume immutable + Tuple(val elems: Array[Arg]) // assume immutable + +class Path extends Result with + constructor + Select(val qual: Path, val name: Ident) + DynSelect(val qual: Path, val fld: Path, val arrayIdx: Bool) // is arrayIdx used? + ValueRef(val l: Symbol) + ValueLit(val lit: Literal) + +class Defn with + constructor + ValDefn(val sym: Symbol, val rhs: Path) + ClsLikeDefn(val sym: ClassSymbol, val companion: Opt[ClsLikeBody]) // companion unused + FunDefn(val sym: Symbol, val params: Array[ParamList], val body: Block, val stage: Bool) + +class ClsLikeBody(val isym: Symbol, val methods: Array[FunDefn], val publicFields: Array[[Symbol, Symbol]]) // unused + +class Block with + constructor + Match(val scrut: Path, val arms: Array[[Case, Block]], val dflt: Opt[Block], val rest: Block) + Return(val res: Result, val implct: Bool) + Assign(val lhs: Symbol, val rhs: Result, val rest: Block) + Define(val defn: Defn, val rest: Block) + // TODO: [fyp] handle Scoped, Label, Break nodes + Scoped(val symbols: Array[Symbol], val rest: Block) + Label(val labelSymbol: Symbol, val loop: Bool, val body: Block, val rest: Block) + Break(val labelSymbol: Symbol) + End() + +fun concat(b1: Block, b2: Block) = if b1 is + Match(scrut, arms, dflt, rest) then Match(scrut, arms, dflt, concat(rest, b2)) + Return(res, implct) then b2 // discard return? + Assign(lhs, rhs, rest) then Assign(lhs, rhs, concat(rest, b2)) + Define(defn, rest) then Define(defn, concat(rest, b2)) + Scoped(symbols, rest) then Scoped(symbols, concat(rest, b2)) + Label(labelSymbol, loop, body, rest) then Label(labelSymbol, loop, body, concat(rest, b2)) + Break then ??? // unreachable + End() then b2 + +fun showBool(b: Bool) = if b then "true" else "false" + +fun showLiteral(l: Literal) = + if l is + undefined then "undefined" + null then "null" + String then "\"" + l.toString() + "\"" + else l.toString() + +fun showSymbol(s: Symbol) = + // console.log("printing " + s) + if s is + ClassSymbol(name, args) then + "ClassSymbol(" + "\"" + name + "\"" + + if args + is Some(args) then ":[" + args.map(showSymbol).join(", ") + "]" + is None then "" + + ")" + _ then "Symbol(" + "\"" + s.name + "\"" + ")" + +fun showIdent(i: Ident) = showSymbol(i) + +fun showPath(p: Path): Str = + if p is + Select(qual, name) then + "Select(" + showPath(qual) + ", " + showIdent(name) + ")" + DynSelect(qual, fld, arrayIdx) then + "DynSelect(" + showPath(qual) + ", " + showPath(fld) + ", " + showBool(arrayIdx) + ")" + ValueRef(l) then + "Ref(" + showSymbol(l) + ")" + ValueLit(lit) then + "Lit(" + showLiteral(lit) + ")" + +fun showArg(arg: Arg) = + if arg.spread is + Some(true) then "..." + Some(false) then ".." + else "" + + showPath(arg.value) + +fun showArgs(args: Array[Arg]) = + "[" + args.map(showArg).join(", ") + "]" + +// Case (match arm patterns) +fun showCase(c: Case): Str = + if c is + Lit(lit) then "Lit(" + showLiteral(lit) + ")" + Cls(cls, path) then "Cls(" + showSymbol(cls) + ", " + showPath(path) + ")" + Tup(len, inf) then "Tup(" + len + ", " + inf + ")" + +fun showResult(r: Result): Str = + if r is + Path then showPath(r) + Call(f, args) then "Call(" + showPath(f) + ", " + showArgs(args) + ")" + Instantiate(cls, args) then "Instantiate(" + showPath(cls) + ", " + showArgs(args) + ")" + Tuple(elems) then "Tuple(" + showArgs(elems) + ")" + +fun showParamList(ps: ParamList) = + "[" + ps.map(s => showSymbol(s)).join(", ") + "]" + +fun showDefn(d: Defn): Str = + if d is + ValDefn(sym, rhs) then + "ValDefn(" + showSymbol(sym) + ", " + showPath(rhs) + ")" + FunDefn(sym, params, body, stage) then + "FunDefn(" + showSymbol(sym) + ", " + + "(" + params.map(showParamList) + "), " + + showBlock(body) + ", " + + stage + ")" + ClsLikeDefn(sym, companion) then + // TODO: print rest of the arguments + "ClsLikeDefn(" + showSymbol(sym) + ", " + "TODO" + ")" + +fun showOptBlock(ob: Opt[Block]) = + if ob is Some(b) then showBlock(b) else "None" + +fun showArm(pair: Case -> Block) = + if pair is [cse, body] then showCase(cse) + " -> " + showBlock(body) else "" + +fun showBlock(b: Block): Str = + if b is + Match(scrut, arms, dflt, rest) then + "Match(" + + showPath(scrut) + ", " + + "[" + arms.map(showArm).join(", ") + "], " + + showOptBlock(dflt) + ", " + + showBlock(rest) + ")" + Return(res, implct) then + "Return(" + showResult(res) + ", " + showBool(implct) + ")" + Assign(lhs, rhs, rest) then + "Assign(" + showSymbol(lhs) + ", " + showResult(rhs) + ", " + showBlock(rest) + ")" + Define(defn, rest) then + "Define(" + showDefn(defn) + ", " + showBlock(rest) + ")" + Scoped(symbols, rest) then + "Scoped([" + symbols.map(showSymbol).join(", ") + "], " + showBlock(rest) + ")" + Label(labelSymbol, loop, body, rest) then + "Label(" + showSymbol(labelSymbol) + ", " + showBool(loop) + ", " + showBlock(body) + ", " + showBlock(rest) + + ")" + Break(labelSymbol) then + "Break(" + showSymbol(labelSymbol) + ")" + End() then "End" + +fun show(x) = + if x is + Symbol then showSymbol(x) + Path then showPath(x) + Result then showResult(x) + Case then showCase(x) + Defn then showDefn(x) + Block then showBlock(x) + else + "" + +fun printCode(x) = print(show(x)) + +fun compile(p: Block) = ??? \ No newline at end of file diff --git a/hkmc2/shared/src/test/mlscript-compile/Shape.mls b/hkmc2/shared/src/test/mlscript-compile/Shape.mls index e69de29bb2..0b724cfcd7 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Shape.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Shape.mls @@ -0,0 +1,120 @@ +import "./Block.mls" +import "./Option.mls" + +open Block { Literal, Symbol, ClassSymbol, showSymbol } +open Option + +type Shape = Shape.Shape + +fun isPrimitiveType(sym: Symbol) = + if sym.name is + "Str" then true + "Int" then true + "Num" then true + "Bool" then true + else false + +fun isPrimitiveTypeOf(sym: Symbol, l: Literal) = + if [sym.name, l] is + ["Str", l] and l is Str then true + ["Int", i] and i is Int then true + ["Num", n] and n is Num then true + ["Bool", b] and b is Bool then true + else false + +module Shape with... + +class Shape with + constructor + Dyn() + Lit(val l: Literal) + Arr(val shapes: Array[Shape], val inf: Bool) + Class(val sym: ClassSymbol, val params: Array[Shape]) + +fun show(s: Shape) = + if s is + Dyn then "Dyn" + Lit(lit) then "Lit(" + Block.showLiteral(lit) + ")" + Arr(shapes, inf) then "Arr([" + shapes.map(show).join(", ") + "], " + inf + ")" + Class(sym, params) then "Class(" + showSymbol(sym) + ", [" + params.map(show).join(", ") + "])" + +fun zipMrg[A](a: Array[A], b: Array[A]): Array[A] = + a.map((a, i, _) => mrg2(a, b.at(i))) + +// TODO: remove, this is no longer in use +fun mrg2(s1: Shape, s2: Shape) = + if s1 == s2 then s1 + else if [s1, s2] is + [Lit(l), Class(sym, params)] + and isPrimitiveTypeOf(sym, l) + then Class(sym, params) + [Class(sym1, ps), Class(sym2, s2)] + and sym1.name == sym2.name + then Class(sym1, ps.map(p => [p.0, zipMrg(p.1, s2)])) + [Arr(s1, false), Arr(s2, false)] + and s1.length == s2.length + then Arr(zipMrg(s1, s2), false) + else Dyn() + +fun mrg(s1: Array[Shape]) = + s1.reduceRight((acc, s, _, _) => mrg2(s, acc)) + +fun sel(s1: Shape, s2: Shape): Array[Shape] = + if [s1, s2] is + [Class(sym, params), Lit(n)] and n is Str + and sym.args is Some(args) + and args.find(_ == n) + == () then [] + is i then [params.(i)] + [Dyn, Lit(n)] and n is Str + then [Dyn()] + [Arr(shapes, false), Lit(n)] and n is Int + then [shapes.(n)] + [Arr(shapes, false), Dyn] then + shapes + [Arr(shapes, true), _] then [Dyn()] // TODO + [Dyn, Lit(n)] and n is Int + then [Dyn()] + [Dyn, Dyn] + then [Dyn()] + else [] // TODO: return no possibility instead of err? + +fun static(s: Shape) = + if s is + Dyn then false + Lit(l) then not (l is Str and isPrimitiveType(l)) // redundant bracket? + Class(_, params) then params.every(static) + Arr(shapes, false) then shapes.every(static) + Arr(shapes, true) then false // TODO + +open Block { Case } + +fun silh(p: Case): Shape = if p is + Block.Lit(l) then Lit(l) + Block.Cls(sym, path) then + val size = if sym.args is Some(i) then i else 0 + Class(sym, Array(size).fill(Dyn)) + Block.Tup(n, inf) then Arr(Array(n).fill(Dyn), inf) + +// TODO: use Option instead, since all of them return at most one shape +fun filter(s: Shape, p: Case): Array[Shape] = + if [s, p] is + [Lit(l1), Block.Lit(l2)] and l1 == l2 then [s] + [Lit(l), Block.Cls(c, _)] and isPrimitiveTypeOf(c, l) then [s] + [Arr(ls, false), Block.Tup(n, false)] and ls.length == n then [s] + [Arr(ls, true), _] then [s] // TODO + [_, Block.Tup(ls, true)] then [s] // TODO + [Class(c1, _), Block.Cls(c2, _)] and c1.name == c2.name then [s] + [Dyn, _] then [silh(p)] + else [] + +fun rest(s: Shape, p: Case): Array[Shape] = + if [s, p] is + [Lit(l1), Block.Lit(l2)] and l1 == l2 then [] + [Lit(l), Block.Cls(c, _)] and isPrimitiveTypeOf(c, l) then [] + [Arr(ls, false), Block.Tup(n, false)] and ls.length == n then [] + [Arr(ls, true), _] then [s] // TODO + [_, Block.Tup(ls, true)] then [s] // TODO + [Class(c1, _), Block.Cls(c2, _)] and c1.name == c2.name then [] + [Dyn, _] then [s] + else [s] \ No newline at end of file diff --git a/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls new file mode 100644 index 0000000000..37d277bf6f --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/block-staging/Functions.mls @@ -0,0 +1,91 @@ +:js +:staging + +val x = [1, 2, 3] +staged module Expressions with + fun lit() = 1 + fun assign() = + let x = 42 + let y = x + y + fun tup1() = [1, 2] + fun tup2() = [1, ..x] + fun dynsel() = [1].(0) + fun match1() = + if 9 is + Bool then 1 + 8 then 2 + Int then 3 + 9 then 4 + else 0 + fun match2() = + if [...x] is + [] then 1 + [1, 2] then 2 + [a, ...] then 3 + else 0 +//│ > Return(Lit(1), false) +//│ > Scoped([Symbol("x"), Symbol("y")], Assign(Symbol("x"), Lit(42), Assign(Symbol("y"), Ref(Symbol("x")), Return(Ref(Symbol("y")), false)))) +//│ > Return(Tuple([Lit(1), Lit(2)]), false) +//│ > Return(Tuple([Lit(1), ..Ref(Symbol("x"))]), false) +//│ > Scoped([Symbol("tmp")], Assign(Symbol("tmp"), Tuple([Lit(1)]), Return(DynSelect(Ref(Symbol("tmp")), Lit(0), false), false))) +//│ > Scoped([Symbol("scrut")], Assign(Symbol("scrut"), Lit(9), Match(Ref(Symbol("scrut")), [Cls(ClassSymbol("Bool"), Select(Ref(Symbol("runtime")), Symbol("unreachable"))) -> Return(Lit(1), false), Lit(8) -> Return(Lit(2), false), Cls(ClassSymbol("Int"), Select(Ref(Symbol("runtime")), Symbol("unreachable"))) -> Return(Lit(3), false)], Return(Lit(0), false), End))) +//│ > Scoped([Symbol("element0$"), Symbol("element1$"), Symbol("scrut"), Symbol("a"), Symbol("tmp"), Symbol("middleElements")], Label(Symbol("split_root$"), false, Label(Symbol("split_1$"), false, Label(Symbol("split_2$"), false, Assign(Symbol("scrut"), Tuple([...Ref(Symbol("x"))]), Match(Ref(Symbol("scrut")), [Tup(0, false) -> Assign(Symbol("tmp"), Lit(1), Break(Symbol("split_root$"))), Tup(2, false) -> Assign(Symbol("element0$"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("get")), [Ref(Symbol("scrut")), Lit(0)]), Assign(Symbol("element1$"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("get")), [Ref(Symbol("scrut")), Lit(1)]), Match(Ref(Symbol("element0$")), [Lit(1) -> Match(Ref(Symbol("element1$")), [Lit(2) -> Assign(Symbol("tmp"), Lit(2), Break(Symbol("split_root$")))], Match(Ref(Symbol("scrut")), [Tup(1, true) -> Assign(Symbol("middleElements"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("slice")), [Ref(Symbol("scrut")), Lit(1), Lit(0)]), Assign(Symbol("a"), Ref(Symbol("element0$")), Break(Symbol("split_1$"))))], Break(Symbol("split_2$")), End), End)], Match(Ref(Symbol("scrut")), [Tup(1, true) -> Assign(Symbol("middleElements"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("slice")), [Ref(Symbol("scrut")), Lit(1), Lit(0)]), Assign(Symbol("a"), Ref(Symbol("element0$")), Break(Symbol("split_1$"))))], Break(Symbol("split_2$")), End), End))), Tup(1, true) -> Assign(Symbol("element0$"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("get")), [Ref(Symbol("scrut")), Lit(0)]), Assign(Symbol("middleElements"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("slice")), [Ref(Symbol("scrut")), Lit(1), Lit(0)]), Assign(Symbol("a"), Ref(Symbol("element0$")), Break(Symbol("split_1$")))))], Break(Symbol("split_2$")), End)), Assign(Symbol("tmp"), Lit(0), Break(Symbol("split_root$")))NaN, Assign(Symbol("tmp"), Lit(3), End)NaN, Return(Ref(Symbol("tmp")), false)NaN) +//│ x = [1, 2, 3] + +class Outside(a) +staged module ClassInstrumentation with + class Inside(a, b) + class NoArg + fun inst1() = new Outside(1) + fun inst2() = new NoArg + fun app1() = Outside(1) + fun app2() = Inside(1, 2) +//│ > Define(ClsLikeDefn(ClassSymbol("NoArg"), TODO), End) +//│ > Define(ClsLikeDefn(ClassSymbol("Inside":[Symbol("a"), Symbol("b")]), TODO), End) +//│ > Return(Instantiate(Ref(ClassSymbol("Outside":[Symbol("a")])), [Lit(1)]), false) +//│ > Return(Instantiate(Select(Ref(Symbol("ClassInstrumentation")), ClassSymbol("NoArg")), []), false) +//│ > Return(Call(Ref(ClassSymbol("Outside":[Symbol("a")])), [Lit(1)]), false) +//│ > Return(Call(Select(Ref(Symbol("ClassInstrumentation")), ClassSymbol("Inside":[Symbol("a"), Symbol("b")])), [Lit(1), Lit(2)]), false) + +staged module Arguments with + fun f(x) = + x = 1 + x + fun g(x)(y, z)() = z +//│ > Assign(Symbol("x"), Lit(1), Return(Ref(Symbol("x")), false)) +//│ > Return(Lit(undefined), false) + +staged module OtherBlocks with + fun scope() = + scope.locally of ( + let a = 1 + a + ) + fun breakAndLabel() = + if 1 is + 2 then 0 + 3 then 0 + else 0 +//│ > Scoped([Symbol("a")], Assign(Symbol("a"), Lit(1), Return(Call(Select(Select(Ref(Symbol("OtherBlocks")), Symbol("scope")), Symbol("locally")), [Ref(Symbol("a"))]), false))) +//│ > Scoped([Symbol("scrut"), Symbol("tmp")], Label(Symbol("split_root$"), false, Label(Symbol("split_1$"), false, Assign(Symbol("scrut"), Lit(1), Match(Ref(Symbol("scrut")), [Lit(2) -> Break(Symbol("split_1$")), Lit(3) -> Break(Symbol("split_1$"))], Break(Symbol("split_1$")), End)), Assign(Symbol("tmp"), Lit(0), End)NaN, Return(Ref(Symbol("tmp")), false)NaN) + +// debug printing fails, collision with class name? +:fixme +class A() +staged module A with + fun f() = 1 +//│ ═══[RUNTIME ERROR] TypeError: A1.f_gen is not a function + +// debug printing fails, unable to reference the class when calling the instrumented function +:fixme +module A with + staged module B with + fun f() = 1 +//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'B' +//│ ╟── which references the symbol introduced here +//│ ║ l.83: staged module B with +//│ ║ ^^^^^^ +//│ ║ l.84: fun f() = 1 +//│ ╙── ^^^^^^^^^^^^^^^ +//│ ═══[RUNTIME ERROR] ReferenceError: B is not defined diff --git a/hkmc2/shared/src/test/mlscript/block-staging/PrintCode.mls b/hkmc2/shared/src/test/mlscript/block-staging/PrintCode.mls new file mode 100644 index 0000000000..bff7052850 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/block-staging/PrintCode.mls @@ -0,0 +1,8 @@ +:staging +:js + +import "../../mlscript-compile/Block.mls" + +Block.printCode(Block.FunDefn(Block.Symbol("f"), [[Block.Symbol("x")]], Block.Return(Block.ValueLit(1), false), false)) +//│ > FunDefn(Symbol("f"), ([Symbol("x")]), Return(Lit(1), false), false) + diff --git a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls b/hkmc2/shared/src/test/mlscript/block-staging/Syntax.mls similarity index 71% rename from hkmc2/shared/src/test/mlscript/staging/Syntax.mls rename to hkmc2/shared/src/test/mlscript/block-staging/Syntax.mls index 94accabd06..0eaa38428e 100644 --- a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls +++ b/hkmc2/shared/src/test/mlscript/block-staging/Syntax.mls @@ -18,11 +18,6 @@ staged fun f() = 0 :js :slot -:staging -:w staged module A -//│ ╔══[WARNING] `staged` keyword doesn't do anything currently. -//│ ║ l.23: staged module A -//│ ╙── ^^^^^^^^ //│ Pretty Lowered: //│ define staged class A in set block$res = undefined in end diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index a134549583..48ab073eec 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -33,7 +33,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val runtimeNme = baseScp.allocateName(Elaborator.State.runtimeSymbol) val termNme = baseScp.allocateName(Elaborator.State.termSymbol) val blockNme = baseScp.allocateName(Elaborator.State.blockSymbol) - val shapeNme = baseScp.allocateName(Elaborator.State.shapeSymbol) + val optionNme = baseScp.allocateName(Elaborator.State.optionSymbol) val definitionMetadataNme = baseScp.allocateName(Elaborator.State.definitionMetadataSymbol) val prettyPrintNme = baseScp.allocateName(Elaborator.State.prettyPrintSymbol) @@ -61,7 +61,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if importQQ.isSet then importRuntimeModule(termNme, termFile) if stageCode.isSet then importRuntimeModule(blockNme, blockFile) - importRuntimeModule(shapeNme, shapeFile) + importRuntimeModule(optionNme, optionFile) h private var hostCreated = false diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index c3a4070953..eb6e0c33c3 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -21,7 +21,7 @@ abstract class MLsDiffMaker extends DiffMaker: val runtimeFile: io.Path = predefFile.up / "Runtime.mjs" // * Contains MLscript runtime definitions val termFile: io.Path = predefFile.up / "Term.mjs" // * Contains MLscript runtime term definitions val blockFile: io.Path = predefFile.up / "Block.mjs" // * Contains MLscript runtime block definitions - val shapeFile: io.Path = predefFile.up / "Shape.mjs" // * Contains MLscript runtime shape definitions + val optionFile: io.Path = predefFile.up / "Option.mjs" // * Contains MLscipt runtime option definition val wd = file.up @@ -160,12 +160,6 @@ abstract class MLsDiffMaker extends DiffMaker: PrefixApp(Keywrd(`import`), StrLit(predefFile.toString)) :: Open(Ident("Predef")) :: Nil) - if stageCode.isSet then - given Config = mkConfig - processTrees( - PrefixApp(Keywrd(`import`), StrLit(blockFile.toString)) - :: PrefixApp(Keywrd(`import`), StrLit(shapeFile.toString)) - :: Nil) super.init()