Skip to content

Commit 62cb284

Browse files
authored
Merge Match Arms (#273)
* merge match arms in lowering * add missing cases in blocktransformer * lazy val flatten for block, keep Begin * move trivial intermediate stmts out for merging match arms
1 parent 8f16d1a commit 62cb284

File tree

16 files changed

+482
-201
lines changed

16 files changed

+482
-201
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ sealed abstract class Block extends Product with AutoLocated:
4949
case Begin(sub, rst) => sub.size + rst.size
5050
case Assign(_, _, rst) => 1 + rst.size
5151
case AssignField(_, _, _, rst) => 1 + rst.size
52+
case AssignDynField(_, _, _, _, rst) => 1 + rst.size
5253
case Match(_, arms, dflt, rst) =>
5354
1 + arms.map(_._2.size).sum + dflt.map(_.size).getOrElse(0) + rst.size
5455
case Define(_, rst) => 1 + rst.size
@@ -132,6 +133,45 @@ sealed abstract class Block extends Product with AutoLocated:
132133
case _ => super.applyBlock(b)
133134

134135
(transformer.applyBlock(this), defns)
136+
137+
lazy val flatten: Block =
138+
// traverses a Block like a list, flatten `Begin`s using an accumulator
139+
// returns the flattend but reversed Block (with the dummy tail `End("for flatten only")`) and the actual tail of the Block
140+
def getReversedFlattenAndTrueTail(b: Block, acc: Block): (Block, BlockTail) = b match
141+
case Match(scrut, arms, dflt, rest) => getReversedFlattenAndTrueTail(rest, Match(scrut, arms, dflt, acc))
142+
case Label(label, body, rest) => getReversedFlattenAndTrueTail(rest, Label(label, body, acc))
143+
case Begin(sub, rest) =>
144+
val (firstBlockRev, firstTail) = getReversedFlattenAndTrueTail(sub, acc)
145+
firstTail match
146+
case _: End => getReversedFlattenAndTrueTail(rest, firstBlockRev)
147+
// if the tail of `sub` is not `End`, ignore the `rest` of this `Begin`
148+
case _ => firstBlockRev -> firstTail
149+
case TryBlock(sub, finallyDo, rest) => getReversedFlattenAndTrueTail(rest, TryBlock(sub, finallyDo, acc))
150+
case Assign(lhs, rhs, rest) => getReversedFlattenAndTrueTail(rest, Assign(lhs, rhs, acc))
151+
case a@AssignField(lhs, nme, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignField(lhs, nme, rhs, acc)(a.symbol))
152+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignDynField(lhs, fld, arrayIdx, rhs, acc))
153+
case Define(defn, rest) => getReversedFlattenAndTrueTail(rest, Define(defn, acc))
154+
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) => getReversedFlattenAndTrueTail(rest, HandleBlock(lhs, res, par, args, cls, handlers, body, acc))
155+
case t: BlockTail => acc -> t
156+
157+
// reverse the Block returnned from the previous function,
158+
// which does not contain `Begin` (except for the nested ones),
159+
// and whose tail must be the dummy `End("for flatten only")`
160+
def rev(b: Block, t: Block): Block = b match
161+
case Match(scrut, arms, dflt, rest) => rev(rest, Match(scrut, arms, dflt, t))
162+
case Label(label, body, rest) => rev(rest, Label(label, body, t))
163+
case TryBlock(sub, finallyDo, rest) => rev(rest, TryBlock(sub, finallyDo, t))
164+
case Assign(lhs, rhs, rest) => rev(rest, Assign(lhs, rhs, t))
165+
case a@AssignField(lhs, nme, rhs, rest) => rev(rest, AssignField(lhs, nme, rhs, t)(a.symbol))
166+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => rev(rest, AssignDynField(lhs, fld, arrayIdx, rhs, t))
167+
case Define(defn, rest) => rev(rest, Define(defn, t))
168+
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) => rev(rest, HandleBlock(lhs, res, par, args, cls, handlers, body, t))
169+
case End(msg) => t
170+
case _: BlockTail => ??? // unreachable
171+
case Begin(sub, rest) => ??? // unreachable
172+
173+
val (flattenRev, actualTail) = getReversedFlattenAndTrueTail(this, End("for flatten only"))
174+
rev(flattenRev, actualTail)
135175

136176
end Block
137177

hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ class BlockTransformer(subst: SymbolSubst):
8484
if (l2 is l) && (res2 is res) && (par2 is par) && (args2 is args) &&
8585
(cls2 is cls) && (hdr2 is hdr) && (bod2 is bod) && (rst2 is rst)
8686
then b else HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
87+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
88+
val lhs2 = applyPath(lhs)
89+
val fld2 = applyPath(fld)
90+
val rhs2 = applyResult(rhs)
91+
val rest2 = applyBlock(rest)
92+
if (lhs2 is lhs) && (fld2 is fld) && (rhs2 is rhs) && (rest2 is rest)
93+
then b
94+
else AssignDynField(lhs2, fld2, arrayIdx, rhs2, rest2)
95+
8796

8897
def applyResult2(r: Result)(k: Result => Block): Block = k(applyResult(r))
8998

@@ -104,6 +113,12 @@ class BlockTransformer(subst: SymbolSubst):
104113
val sym2 = p.symbol.mapConserve(_.subst)
105114
if (qual2 is qual) && (sym2 is p.symbol) then p else Select(qual2, name)(sym2)
106115
case v: Value => applyValue(v)
116+
case DynSelect(qual, fld, ai) =>
117+
val qual2 = applyPath(qual)
118+
val fld2 = applyPath(fld)
119+
if (qual2 is qual) && (fld2 is fld)
120+
then p
121+
else DynSelect(qual2, fld2, ai)
107122

108123
def applyValue(v: Value): Value = v match
109124
case Value.Ref(l) =>

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
589589
case N => res
590590
case S(sts) => StackSafeTransform(sts.stackLimit).transformTopLevel(res)
591591

592-
if lowerHandlers then HandlerLowering().translateTopLevel(stackSafe)
593-
else stackSafe
592+
MergeMatchArmTransformer.applyBlock(
593+
if lowerHandlers then HandlerLowering().translateTopLevel(stackSafe)
594+
else stackSafe
595+
)
594596

595597
def program(main: st): Program =
596598
def go(acc: Ls[Local -> Str], trm: st): Program =
@@ -732,3 +734,42 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
732734
)
733735

734736

737+
object TrivialStatementsAndMatch:
738+
def unapply(b: Block): Opt[(Opt[Block => Block], Match)] =
739+
def handleAssignAndMatch(
740+
assign: Block => Block,
741+
m: Match,
742+
k: Opt[Block => Block]
743+
): Some[(Some[Block => Block], Match)] =
744+
def newK(r: Block): Block =
745+
val newR = k.getOrElse(identity: Block => Block)(r)
746+
assign(newR)
747+
S(S(newK), m)
748+
749+
b match
750+
case m: Match => S(N, m)
751+
case Assign(lhs, rhs: Path, TrivialStatementsAndMatch(k, m)) =>
752+
handleAssignAndMatch(r => Assign(lhs, rhs, r), m, k)
753+
case a@AssignField(lhs, nme, rhs: Path, TrivialStatementsAndMatch(k, m)) =>
754+
handleAssignAndMatch(r => AssignField(lhs, nme, rhs, r)(a.symbol), m, k)
755+
case AssignDynField(lhs, fld, arrayIdx, rhs: Path, TrivialStatementsAndMatch(k, m)) =>
756+
handleAssignAndMatch(r => AssignDynField(lhs, fld, arrayIdx, rhs, r), m, k)
757+
case Define(defn, TrivialStatementsAndMatch(k, m)) =>
758+
handleAssignAndMatch(r => Define(defn, r), m, k)
759+
case _ => N
760+
761+
762+
object MergeMatchArmTransformer extends BlockTransformer(new SymbolSubst()):
763+
override def applyBlock(b: Block): Block = super.applyBlock(b) match
764+
case m@Match(scrut, arms, Some(dflt), rest) =>
765+
dflt.flatten match
766+
case TrivialStatementsAndMatch(k, Match(scrutRewritten, armsRewritten, dfltRewritten, restRewritten))
767+
if (scrutRewritten === scrut) && (restRewritten.size * armsRewritten.length) < 10 =>
768+
val newArms = restRewritten match
769+
case _: End => armsRewritten
770+
case _ => armsRewritten.map:
771+
case (cse, body) =>
772+
cse -> Begin(body, restRewritten)
773+
k.getOrElse(identity: Block => Block)(Match(scrut, arms ::: newArms, dfltRewritten, rest))
774+
case _ => m
775+
case b => b

hkmc2/shared/src/test/mlscript-compile/Example.mjs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,12 @@ Example1 = class Example {
1212
static test(x1) {
1313
if (globalThis.Number.isInteger(x1)) {
1414
return "int"
15+
} else if (typeof x1 === 'number') {
16+
return "num"
17+
} else if (typeof x1 === 'string') {
18+
return "str"
1519
} else {
16-
if (typeof x1 === 'number') {
17-
return "num"
18-
} else {
19-
if (typeof x1 === 'string') {
20-
return "str"
21-
} else {
22-
return "other"
23-
}
24-
}
20+
return "other"
2521
}
2622
}
2723
static toString() { return "Example"; }

hkmc2/shared/src/test/mlscript-compile/Option.mjs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ Option1 = class Option {
2828
static isDefined(x) {
2929
if (x instanceof Option.Some.class) {
3030
return true
31+
} else if (x instanceof Option.None.class) {
32+
return false
3133
} else {
32-
if (x instanceof Option.None.class) {
33-
return false
34-
} else {
35-
throw new globalThis.Error("match error");
36-
}
34+
throw new globalThis.Error("match error");
3735
}
3836
}
3937
static test() {

hkmc2/shared/src/test/mlscript-compile/Predef.mjs

Lines changed: 69 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -251,92 +251,78 @@ Predef1 = class Predef {
251251
let ts, p, scrut, scrut1, scrut2, nme, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, tmp16, tmp17, tmp18, tmp19, tmp20;
252252
if (arg1 === undefined) {
253253
return "undefined"
254-
} else {
255-
if (arg1 === null) {
256-
return "null"
254+
} else if (arg1 === null) {
255+
return "null"
256+
} else if (arg1 instanceof globalThis.Array) {
257+
tmp = Predef.fold((arg11, arg2) => {
258+
return arg11 + arg2
259+
});
260+
tmp1 = Predef.interleave(", ");
261+
tmp2 = Predef.map(Predef.render);
262+
tmp3 = runtime.safeCall(tmp2(...arg1));
263+
tmp4 = runtime.safeCall(tmp1(...tmp3));
264+
return runtime.safeCall(tmp("[", ...tmp4, "]"))
265+
} else if (typeof arg1 === 'string') {
266+
return runtime.safeCall(globalThis.JSON.stringify(arg1))
267+
} else if (arg1 instanceof globalThis.Set) {
268+
tmp5 = Predef.fold((arg11, arg2) => {
269+
return arg11 + arg2
270+
});
271+
tmp6 = Predef.interleave(", ");
272+
tmp7 = Predef.map(Predef.render);
273+
tmp8 = runtime.safeCall(tmp7(...arg1));
274+
tmp9 = runtime.safeCall(tmp6(...tmp8));
275+
return runtime.safeCall(tmp5("Set{", ...tmp9, "}"))
276+
} else if (arg1 instanceof globalThis.Map) {
277+
tmp10 = Predef.fold((arg11, arg2) => {
278+
return arg11 + arg2
279+
});
280+
tmp11 = Predef.interleave(", ");
281+
tmp12 = Predef.map(Predef.render);
282+
tmp13 = runtime.safeCall(tmp12(...arg1));
283+
tmp14 = runtime.safeCall(tmp11(...tmp13));
284+
return runtime.safeCall(tmp10("Map{", ...tmp14, "}"))
285+
} else if (arg1 instanceof globalThis.Function) {
286+
p = globalThis.Object.getOwnPropertyDescriptor(arg1, "prototype");
287+
if (p instanceof globalThis.Object) {
288+
scrut = p["writable"];
289+
if (scrut === true) {
290+
tmp15 = true;
291+
} else {
292+
tmp15 = false;
293+
}
257294
} else {
258-
if (arg1 instanceof globalThis.Array) {
259-
tmp = Predef.fold((arg11, arg2) => {
260-
return arg11 + arg2
261-
});
262-
tmp1 = Predef.interleave(", ");
263-
tmp2 = Predef.map(Predef.render);
264-
tmp3 = runtime.safeCall(tmp2(...arg1));
265-
tmp4 = runtime.safeCall(tmp1(...tmp3));
266-
return runtime.safeCall(tmp("[", ...tmp4, "]"))
295+
tmp15 = false;
296+
}
297+
if (p === undefined) {
298+
tmp16 = true;
299+
} else {
300+
tmp16 = false;
301+
}
302+
scrut1 = tmp15 || tmp16;
303+
if (scrut1 === true) {
304+
scrut2 = arg1.name;
305+
if (scrut2 === "") {
306+
tmp17 = "";
267307
} else {
268-
if (typeof arg1 === 'string') {
269-
return runtime.safeCall(globalThis.JSON.stringify(arg1))
270-
} else {
271-
if (arg1 instanceof globalThis.Set) {
272-
tmp5 = Predef.fold((arg11, arg2) => {
273-
return arg11 + arg2
274-
});
275-
tmp6 = Predef.interleave(", ");
276-
tmp7 = Predef.map(Predef.render);
277-
tmp8 = runtime.safeCall(tmp7(...arg1));
278-
tmp9 = runtime.safeCall(tmp6(...tmp8));
279-
return runtime.safeCall(tmp5("Set{", ...tmp9, "}"))
280-
} else {
281-
if (arg1 instanceof globalThis.Map) {
282-
tmp10 = Predef.fold((arg11, arg2) => {
283-
return arg11 + arg2
284-
});
285-
tmp11 = Predef.interleave(", ");
286-
tmp12 = Predef.map(Predef.render);
287-
tmp13 = runtime.safeCall(tmp12(...arg1));
288-
tmp14 = runtime.safeCall(tmp11(...tmp13));
289-
return runtime.safeCall(tmp10("Map{", ...tmp14, "}"))
290-
} else {
291-
if (arg1 instanceof globalThis.Function) {
292-
p = globalThis.Object.getOwnPropertyDescriptor(arg1, "prototype");
293-
if (p instanceof globalThis.Object) {
294-
scrut = p["writable"];
295-
if (scrut === true) {
296-
tmp15 = true;
297-
} else {
298-
tmp15 = false;
299-
}
300-
} else {
301-
tmp15 = false;
302-
}
303-
if (p === undefined) {
304-
tmp16 = true;
305-
} else {
306-
tmp16 = false;
307-
}
308-
scrut1 = tmp15 || tmp16;
309-
if (scrut1 === true) {
310-
scrut2 = arg1.name;
311-
if (scrut2 === "") {
312-
tmp17 = "";
313-
} else {
314-
nme = scrut2;
315-
tmp17 = " " + nme;
316-
}
317-
tmp18 = "[function" + tmp17;
318-
return tmp18 + "]"
319-
} else {
320-
return globalThis.String(arg1)
321-
}
322-
} else {
323-
if (arg1 instanceof globalThis.Object) {
324-
return globalThis.String(arg1)
325-
} else {
326-
ts = arg1["toString"];
327-
if (ts === undefined) {
328-
tmp19 = typeof arg1;
329-
tmp20 = "[" + tmp19;
330-
return tmp20 + "]"
331-
} else {
332-
return runtime.safeCall(ts.call(arg1))
333-
}
334-
}
335-
}
336-
}
337-
}
338-
}
308+
nme = scrut2;
309+
tmp17 = " " + nme;
339310
}
311+
tmp18 = "[function" + tmp17;
312+
return tmp18 + "]"
313+
} else {
314+
return globalThis.String(arg1)
315+
}
316+
} else if (arg1 instanceof globalThis.Object) {
317+
return globalThis.String(arg1)
318+
} else {
319+
ts = arg1["toString"];
320+
if (ts === undefined) {
321+
tmp19 = typeof arg1;
322+
tmp20 = "[" + tmp19;
323+
return tmp20 + "]"
324+
} else {
325+
return runtime.safeCall(ts.call(arg1))
340326
}
341327
}
342328
}

0 commit comments

Comments
 (0)