1
1
package effekt
2
2
package core
3
3
4
- import effekt .context .Context
5
4
import effekt .core .substitutions .Substitution
6
- import effekt .symbols .TmpValue
7
5
8
6
import scala .collection .mutable
9
7
@@ -54,9 +52,9 @@ import scala.collection.mutable
54
52
object PatternMatchingCompiler {
55
53
56
54
/**
57
- * The conditions need to be met in sequence before the block at [[label ]] can be evaluated with given [[args ]].
55
+ * The conditions need to be met in sequence before the block at [[label ]] can be evaluated with given [[targs ]] and [[ args ]].
58
56
*/
59
- case class Clause (conditions : List [Condition ], label : BlockVar , args : List [ValueVar ])
57
+ case class Clause (conditions : List [Condition ], label : BlockVar , targs : List [ ValueType ], args : List [ValueVar ])
60
58
61
59
enum Condition {
62
60
// all of the patterns need to match for this condition to be met
@@ -71,7 +69,7 @@ object PatternMatchingCompiler {
71
69
enum Pattern {
72
70
// sub-patterns are annotated with the inferred type of the scrutinee at this point
73
71
// i.e. Cons(Some(x : TInt): Option[Int], xs: List[Option[Int]])
74
- case Tag (id : Id , patterns : List [(Pattern , ValueType )])
72
+ case Tag (id : Id , tparams : List [ Id ], patterns : List [(Pattern , ValueType )])
75
73
case Ignore ()
76
74
case Any (id : Id )
77
75
case Or (patterns : List [Pattern ])
@@ -93,21 +91,21 @@ object PatternMatchingCompiler {
93
91
// (1) Check the first clause to be matched (we can immediately handle non-pattern cases)
94
92
val patterns = headClause match {
95
93
// - The top-most clause already matches successfully
96
- case Clause (Nil , target, args) =>
97
- return core.App (target, Nil , args, Nil )
94
+ case Clause (Nil , target, targs, args) =>
95
+ return core.App (target, targs , args, Nil )
98
96
// - We need to perform a computation
99
- case Clause (Condition .Val (x, tpe, binding) :: rest, target, args) =>
100
- return core.Val (x, tpe, binding, compile(Clause (rest, target, args) :: remainingClauses))
97
+ case Clause (Condition .Val (x, tpe, binding) :: rest, target, targs, args) =>
98
+ return core.Val (x, tpe, binding, compile(Clause (rest, target, targs, args) :: remainingClauses))
101
99
// - We need to perform a computation
102
- case Clause (Condition .Let (x, tpe, binding) :: rest, target, args) =>
103
- return core.Let (x, tpe, binding, compile(Clause (rest, target, args) :: remainingClauses))
100
+ case Clause (Condition .Let (x, tpe, binding) :: rest, target, targs, args) =>
101
+ return core.Let (x, tpe, binding, compile(Clause (rest, target, targs, args) :: remainingClauses))
104
102
// - We need to check a predicate
105
- case Clause (Condition .Predicate (pred) :: rest, target, args) =>
103
+ case Clause (Condition .Predicate (pred) :: rest, target, targs, args) =>
106
104
return core.If (pred,
107
- compile(Clause (rest, target, args) :: remainingClauses),
105
+ compile(Clause (rest, target, targs, args) :: remainingClauses),
108
106
compile(remainingClauses)
109
107
)
110
- case Clause (Condition .Patterns (patterns) :: rest, target, args) =>
108
+ case Clause (Condition .Patterns (patterns) :: rest, target, targs, args) =>
111
109
patterns
112
110
}
113
111
@@ -127,7 +125,7 @@ object PatternMatchingCompiler {
127
125
def splitOnLiteral (lit : Literal , equals : (Pure , Pure ) => Pure ): core.Stmt = {
128
126
// the different literal values that we match on
129
127
val variants : List [core.Literal ] = normalized.collect {
130
- case Clause (Split (Pattern .Literal (lit, _), _, _), _, _) => lit
128
+ case Clause (Split (Pattern .Literal (lit, _), _, _), _, _, _ ) => lit
131
129
}.distinct
132
130
133
131
// for each literal, we collect the clauses that match it correctly
@@ -141,8 +139,8 @@ object PatternMatchingCompiler {
141
139
defaults = defaults :+ cl
142
140
143
141
normalized.foreach {
144
- case Clause (Split (Pattern .Literal (lit, _), restPatterns, restConds), label, args) =>
145
- addClause(lit, Clause (Condition .Patterns (restPatterns) :: restConds, label, args))
142
+ case Clause (Split (Pattern .Literal (lit, _), restPatterns, restConds), label, targs, args) =>
143
+ addClause(lit, Clause (Condition .Patterns (restPatterns) :: restConds, label, targs, args))
146
144
case c =>
147
145
addDefault(c)
148
146
variants.foreach { v => addClause(v, c) }
@@ -164,7 +162,7 @@ object PatternMatchingCompiler {
164
162
def splitOnTag () = {
165
163
// collect all variants that are mentioned in the clauses
166
164
val variants : List [Id ] = normalized.collect {
167
- case Clause (Split (p : Pattern .Tag , _, _), _, _) => p.id
165
+ case Clause (Split (p : Pattern .Tag , _, _), _, _, _ ) => p.id
168
166
}.distinct
169
167
170
168
// for each tag, we collect the clauses that match it correctly
@@ -179,7 +177,9 @@ object PatternMatchingCompiler {
179
177
180
178
// used to make up new scrutinees
181
179
val varsFor = mutable.Map .empty[Id , List [ValueVar ]]
182
- def fieldVarsFor (constructor : Id , fieldInfo : List [((Pattern , ValueType ), String )]): List [ValueVar ] =
180
+ val tvarsFor = mutable.Map .empty[Id , List [Id ]]
181
+ def fieldVarsFor (constructor : Id , tparams : List [Id ], fieldInfo : List [((Pattern , ValueType ), String )]): List [ValueVar ] =
182
+ tvarsFor.getOrElseUpdate(constructor, tparams)
183
183
varsFor.getOrElseUpdate(
184
184
constructor,
185
185
fieldInfo.map {
@@ -191,17 +191,17 @@ object PatternMatchingCompiler {
191
191
)
192
192
193
193
normalized.foreach {
194
- case Clause (Split (Pattern .Tag (constructor, patternsAndTypes), restPatterns, restConds), label, args) =>
194
+ case Clause (Split (Pattern .Tag (constructor, tparams, patternsAndTypes), restPatterns, restConds), label, targs , args) =>
195
195
// NOTE: Ideally, we would use a `DeclarationContext` here, but we cannot: we're currently in the Source->Core transformer, so we do not have all of the details yet.
196
196
val fieldNames : List [String ] = constructor match {
197
197
case c : symbols.Constructor => c.fields.map(_.name.name)
198
198
case _ => List .fill(patternsAndTypes.size) { " y" } // NOTE: Only reached in PatternMatchingTests
199
199
}
200
- val fieldVars = fieldVarsFor(constructor, patternsAndTypes.zip(fieldNames))
200
+ val fieldVars = fieldVarsFor(constructor, tparams, patternsAndTypes.zip(fieldNames))
201
201
val nestedMatches = fieldVars.zip(patternsAndTypes.map { case (pat, tpe) => pat }).toMap
202
202
addClause(constructor,
203
203
// it is important to add nested matches first, since they might include substitutions for the rest.
204
- Clause (Condition .Patterns (nestedMatches) :: Condition .Patterns (restPatterns) :: restConds, label, args))
204
+ Clause (Condition .Patterns (nestedMatches) :: Condition .Patterns (restPatterns) :: restConds, label, targs, args))
205
205
206
206
case c =>
207
207
// Clauses that don't match on that var are duplicated.
@@ -214,8 +214,9 @@ object PatternMatchingCompiler {
214
214
// (4) assemble syntax tree for the pattern match
215
215
val branches = variants.map { v =>
216
216
val body = compile(clausesFor.getOrElse(v, Nil ))
217
+ val tparams = tvarsFor(v)
217
218
val params = varsFor(v).map { case ValueVar (id, tpe) => core.ValueParam (id, tpe): core.ValueParam }
218
- val blockLit : BlockLit = BlockLit (Nil , Nil , params, Nil , body)
219
+ val blockLit : BlockLit = BlockLit (tparams , Nil , params, Nil , body)
219
220
(v, blockLit)
220
221
}
221
222
@@ -232,16 +233,17 @@ object PatternMatchingCompiler {
232
233
233
234
def branchingHeuristic (patterns : Map [ValueVar , Pattern ], clauses : List [Clause ]): ValueVar =
234
235
patterns.keys.maxBy(v => clauses.count {
235
- case Clause (ps, _, _) => ps.contains(v)
236
+ case Clause (ps, _, _, _ ) => ps.contains(v)
236
237
})
237
238
238
239
/**
239
240
* Substitutes AnyPattern and removes wildcards.
240
241
*/
241
242
def normalize (clause : Clause ): Clause = clause match {
242
- case Clause (conditions, label, args) =>
243
+ case Clause (conditions, label, targs, args) =>
243
244
val (normalized, substitution) = normalize(Map .empty, conditions, Map .empty)
244
- Clause (normalized, label, args.map(v => substitution.getOrElse(v.id, v)))
245
+ // TODO also substitute types?
246
+ Clause (normalized, label, targs, args.map(v => substitution.getOrElse(v.id, v)))
245
247
}
246
248
247
249
@@ -309,8 +311,8 @@ object PatternMatchingCompiler {
309
311
// -----------------------------
310
312
311
313
def show (cl : Clause ): String = cl match {
312
- case Clause (conditions, label, args) =>
313
- s " case ${conditions.map(show).mkString(" ; " )} => ${util.show(label.id)}${args.map(x => util.show(x)).mkString(" (" , " , " , " )" )}"
314
+ case Clause (conditions, label, targs, args) =>
315
+ s " case ${conditions.map(show).mkString(" ; " )} => ${util.show(label.id)}${targs.map(x => util.show(x))}${ args.map(x => util.show(x)).mkString(" (" , " , " , " )" )}"
314
316
}
315
317
316
318
def show (c : Condition ): String = c match {
@@ -321,7 +323,7 @@ object PatternMatchingCompiler {
321
323
}
322
324
323
325
def show (p : Pattern ): String = p match {
324
- case Pattern .Tag (id, patterns) => util.show(id) + patterns.map { case (p, tpe) => show(p) }.mkString(" (" , " , " , " )" )
326
+ case Pattern .Tag (id, tparams, patterns) => util.show(id) + tparams.map(util.show).mkString( " [ " , " , " , " ] " ) + patterns.map { case (p, tpe) => show(p) }.mkString(" (" , " , " , " )" )
325
327
case Pattern .Ignore () => " _"
326
328
case Pattern .Any (id) => util.show(id)
327
329
case Pattern .Or (patterns) => patterns.map(show).mkString(" | " )
0 commit comments