Skip to content

Commit 99b9e76

Browse files
authored
Core compiler optimization: fix MASSIVE performance bugs (#404)
... SMU example compile time (core compile time only, with hot cache including generator results) goes from ~minute to a bit under a second. Fixes two major performance bugs: - ConstProp used prevent parameters without types declared from propagating by rejecting it from the ready queue, repeatedly. This now adds the parameter declaration as a dependency. - (probably) DependencyGraph used to subtract the entire values keySet, which would get massive, this now does a filterInPlace on incoming dependencies which should be much smaller Also changes a few things from Set to Iterable, which may marginally improve performance
1 parent 3f6711d commit 99b9e76

File tree

6 files changed

+49
-46
lines changed

6 files changed

+49
-46
lines changed

compiler/src/main/scala/edg/compiler/Compiler.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class Compiler private (
258258
// Returns all errors, by scanning the design tree for errors and adding errors accumulated through the compile
259259
// process
260260
def getErrors(): Seq[CompilerError] = {
261-
val pendingErrors = elaboratePending.getMissingValue.map { missingNode =>
261+
val pendingErrors = elaboratePending.getMissingValues.map { missingNode =>
262262
CompilerError.Unelaborated(missingNode, elaboratePending.nodeMissing(missingNode))
263263
}.toSeq
264264

@@ -1628,15 +1628,18 @@ class Compiler private (
16281628
val partialCompileIgnoredRecords = mutable.Set[ElaborateRecord]()
16291629

16301630
// repeat as long as there is work ready, and all the ready work isn't marked to be ignored
1631-
var readyList = Set[ElaborateRecord]()
1631+
var readyList = Iterable[ElaborateRecord]()
16321632
do {
1633-
readyList = elaboratePending.getReady -- partialCompileIgnoredRecords
1633+
// TODO this is kind of ugly and expensive in that it keeps filtering the ignored records
1634+
// ideally this should be done at the getReady side, but these also need to be restored
1635+
// when the compiler forks
1636+
readyList = elaboratePending.getReady.filter(!partialCompileIgnoredRecords.contains(_))
16341637
readyList.foreach { elaborateRecord =>
16351638
try {
16361639
elaborateRecord match {
16371640
case elaborateRecord @ ElaborateRecord.ExpandBlock(blockPath, blockClass, blockProgress) =>
16381641
if (partial.blocks.contains(blockPath) || partial.classes.contains(blockClass)) {
1639-
partialCompileIgnoredRecords.add(elaborateRecord)
1642+
partialCompileIgnoredRecords += elaborateRecord
16401643
} else {
16411644
expandBlock(blockPath, blockProgress)
16421645
elaboratePending.setValue(elaborateRecord, None)
@@ -1653,7 +1656,7 @@ class Compiler private (
16531656
case elaborateRecord @ ElaborateRecord.Parameter(root, rootClasses, postfix, param) =>
16541657
val container = resolveBlock(root).asInstanceOf[wir.HasParams]
16551658
if (paramMatchesPartial(root, rootClasses, postfix)) {
1656-
partialCompileIgnoredRecords.add(elaborateRecord)
1659+
partialCompileIgnoredRecords += elaborateRecord
16571660
} else {
16581661
constProp.addDeclaration(root ++ postfix, param)
16591662
elaboratePending.setValue(elaborateRecord, None)

compiler/src/main/scala/edg/compiler/CompilerError.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ sealed trait CompilerError {
1313
}
1414

1515
object CompilerError {
16-
case class Unelaborated(record: ElaborateRecord, missing: Set[ElaborateRecord]) extends CompilerError {
16+
case class Unelaborated(record: ElaborateRecord, missing: Iterable[ElaborateRecord]) extends CompilerError {
1717
// These errors may be redundant with below, but provides dependency data
1818
override def toString: String = s"Unelaborated missing dependencies $record:\n" +
1919
s"${missing.map(x => s"- $x").mkString("\n")}"
@@ -165,7 +165,7 @@ object CompilerError {
165165
root: DesignPath,
166166
constrName: String,
167167
value: expr.ValueExpr,
168-
missing: Set[IndirectDesignPath]
168+
missing: Iterable[IndirectDesignPath]
169169
) extends AssertionError {
170170
override def toString: String =
171171
s"Unevaluated assertion: $root.$constrName: missing ${missing.mkString(", ")} in ${ExprToString.apply(value)}"

compiler/src/main/scala/edg/compiler/ConstProp.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class ConstProp() {
4242
// Assign statements are added to the dependency graph only when arrays are ready
4343
// This is the authoritative source for the state of any param - in the graph (and its dependencies), or value solved
4444
// CONNECTED_LINK has an empty value but indicates that the path was resolved in that data structure
45+
// NAME has an empty value but indicates declaration (existence in paramTypes)
4546
private val params = DependencyGraph[IndirectDesignPath, ExprValue]()
4647
// Parameter types are used to track declared parameters
4748
// Undeclared parameters cannot have values set, but can be forced (though the value is not effective until declared)
@@ -106,16 +107,9 @@ class ConstProp() {
106107
connectedLink.setValue(ready, DesignPath())
107108
}
108109

109-
var readyList = Set[IndirectDesignPath]()
110+
var readyList = Iterable[IndirectDesignPath]()
110111
do {
111-
// ignore params where we haven't seen the decl yet, to allow forced-assign when the block is expanded
112-
// TODO support this for all params, including indirect ones (eg, name)
113-
readyList = params.getReady.filter { elt =>
114-
DesignPath.fromIndirectOption(elt) match {
115-
case Some(elt) => paramTypes.keySet.contains(elt.asIndirect)
116-
case None => true
117-
}
118-
}
112+
readyList = params.getReady
119113
readyList.foreach { constrTarget =>
120114
val assign = paramAssign(constrTarget)
121115
new ExprEvaluatePartial(getValue, assign.root).map(assign.value) match {
@@ -165,6 +159,7 @@ class ConstProp() {
165159
case _ => throw new NotImplementedError(s"Unknown param declaration / init $decl")
166160
}
167161
paramTypes.put(target.asIndirect, paramType)
162+
params.setValue(target.asIndirect + IndirectStep.Name, BooleanValue(false)) // dummy value
168163
update()
169164
}
170165

@@ -196,14 +191,19 @@ class ConstProp() {
196191
require(target.splitConnectedLink.isEmpty, "cannot set CONNECTED_LINK")
197192
val paramSourceRecord = (root, constrName, targetExpr)
198193

194+
// ignore params where we haven't seen the decl yet, to allow forced-assign when the block is expanded
195+
val paramTypesDep = DesignPath.fromIndirectOption(target) match {
196+
case Some(path) => Seq(path.asIndirect + IndirectStep.Name)
197+
case None => Seq() // has indirect step, no direct decl
198+
}
199199
if (forced) {
200200
require(!forcedParams.contains(target), s"attempt to re-force $target")
201201
forcedParams.add(target)
202202
require(
203203
!params.valueDefinedAt(target),
204204
s"forced value must be set before value is resolved, prior ${paramSource(target)}"
205205
)
206-
params.addNode(target, Seq(), overwrite = true) // forced can overwrite other records
206+
params.addNode(target, paramTypesDep, overwrite = true) // forced can overwrite other records
207207
} else {
208208
if (!forcedParams.contains(target)) {
209209
if (params.nodeDefinedAt(target)) { // TODO add propagated assign
@@ -213,7 +213,7 @@ class ConstProp() {
213213
)
214214
return // first set "wins"
215215
}
216-
params.addNode(target, Seq())
216+
params.addNode(target, paramTypesDep)
217217
} else {
218218
return // ignored - param was forced, discard the new assign
219219
}

compiler/src/main/scala/edg/util/DependencyGraph.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import scala.collection.mutable
88
*/
99
class DependencyGraph[KeyType, ValueType] {
1010
private val values = mutable.HashMap[KeyType, ValueType]()
11-
private val inverseDeps = mutable.HashMap[KeyType, mutable.Set[KeyType]]()
11+
private val inverseDeps = mutable.HashMap[KeyType, mutable.ArrayBuffer[KeyType]]()
1212
private val deps = mutable.HashMap[KeyType, mutable.Set[KeyType]]() // cache structure tracking undefined deps
1313
private val ready = mutable.Set[KeyType]()
1414

@@ -27,22 +27,23 @@ class DependencyGraph[KeyType, ValueType] {
2727

2828
// Adds a node in the graph. May only be called once per node.
2929
def addNode(node: KeyType, dependencies: Seq[KeyType], overwrite: Boolean = false): Unit = {
30+
val dependenciesSet = dependencies.to(mutable.Set)
3031
deps.get(node) match {
3132
case Some(prevDeps) =>
3233
require(overwrite, s"reinsertion of dependency for node $node <- $dependencies without overwrite=true")
3334
// TODO can this requirement be eliminated?
34-
require(prevDeps.subsetOf(dependencies.toSet), "update of dependencies without being a superset of prior")
35+
require(prevDeps.forall(dependencies.contains(_)), "update of dependencies without being a superset of prior")
3536
case None => // nothing if no previous dependencies
3637
}
3738
require(
3839
!values.isDefinedAt(node),
3940
s"reinsertion of dependency for node with value $node = ${values(node)} <- $dependencies"
4041
)
41-
val remainingDeps = (dependencies.toSet -- values.keySet).to(mutable.Set)
42+
val remainingDeps = dependenciesSet.filterInPlace(!values.contains(_))
4243

4344
deps.put(node, remainingDeps)
4445
for (dependency <- remainingDeps) {
45-
inverseDeps.getOrElseUpdate(dependency, mutable.Set()) += node
46+
inverseDeps.getOrElseUpdate(dependency, mutable.ArrayBuffer()) += node
4647
}
4748

4849
if (overwrite && ready.contains(node)) {
@@ -60,7 +61,7 @@ class DependencyGraph[KeyType, ValueType] {
6061

6162
// Returns missing dependencies for a node, or empty if the node is ready or has a value assigned
6263
// Node must exist, or this will exception out
63-
def nodeMissing(node: KeyType): Set[KeyType] = deps(node).toSet
64+
def nodeMissing(node: KeyType): Iterable[KeyType] = deps(node)
6465

6566
// Clears a node from ready without setting a value in the graph.
6667
// Useful to stop propagation at some point, but without crashing.
@@ -77,13 +78,12 @@ class DependencyGraph[KeyType, ValueType] {
7778
require(!values.isDefinedAt(node), s"redefinition of $node (prior value ${values(node)}, new value $value)")
7879
deps.put(node, mutable.Set())
7980
values.put(node, value)
80-
if (ready.contains(node)) {
81-
ready -= node
82-
}
81+
ready -= node
8382

8483
// See if the update caused anything else to be ready
85-
for (inverseDep <- inverseDeps.getOrElse(node, mutable.Set())) {
86-
val remainingDeps = deps(inverseDep) -= node
84+
for (inverseDep <- inverseDeps.getOrElse(node, mutable.ArrayBuffer())) {
85+
val remainingDeps = deps(inverseDep)
86+
remainingDeps -= node
8787
if (remainingDeps.isEmpty && !values.isDefinedAt(inverseDep)) {
8888
ready += inverseDep
8989
}
@@ -95,13 +95,13 @@ class DependencyGraph[KeyType, ValueType] {
9595
}
9696

9797
// Returns all the KeyTypes that don't have values and have satisfied dependencies.
98-
def getReady: Set[KeyType] = {
99-
ready.toSet
98+
def getReady: Iterable[KeyType] = {
99+
ready
100100
}
101101

102102
// Returns all the KeyTypes that have no values. NOT a fast operation. Includes items in the ready list.
103-
def getMissingValue: Set[KeyType] = {
104-
deps.keySet.toSet -- values.keySet
103+
def getMissingValues: Iterable[KeyType] = {
104+
deps.keys.toSet -- values.keys
105105
}
106106

107107
def knownValueKeys: Iterable[KeyType] = {

compiler/src/test/scala/edg/util/DependencyGraphTest.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class DependencyGraphTest extends AnyFlatSpec {
4141
dep.nodeMissing(1) should equal(Set(0))
4242
dep.setValue(0, 0)
4343
dep.getReady should equal(Set(1))
44-
dep.nodeMissing(1) should equal(Set())
44+
dep.nodeMissing(1) shouldBe empty
4545
}
4646

4747
it should "track multiple dependencies, and add to ready when all set" in {
@@ -57,7 +57,7 @@ class DependencyGraphTest extends AnyFlatSpec {
5757
dep.nodeMissing(3) should equal(Set(2))
5858
dep.setValue(2, 0)
5959
dep.getReady should equal(Set(3))
60-
dep.nodeMissing(3) should equal(Set())
60+
dep.nodeMissing(3) shouldBe empty
6161
}
6262

6363
it should "track a chain of dependencies" in {
@@ -105,22 +105,22 @@ class DependencyGraphTest extends AnyFlatSpec {
105105

106106
it should "return getMissing" in {
107107
val dep = DependencyGraph[Int, Int]()
108-
dep.getMissingValue shouldBe empty
108+
dep.getMissingValues shouldBe empty
109109
dep.addNode(1, Seq(0))
110-
dep.getMissingValue should equal(Set(1))
110+
dep.getMissingValues should equal(Set(1))
111111
dep.setValue(1, 1)
112-
dep.getMissingValue shouldBe empty
112+
dep.getMissingValues shouldBe empty
113113
}
114114

115115
it should "return getMissing including ready nodes" in {
116116
val dep = DependencyGraph[Int, Int]()
117-
dep.getMissingValue shouldBe empty
117+
dep.getMissingValues shouldBe empty
118118
dep.addNode(1, Seq(0))
119-
dep.getMissingValue should equal(Set(1))
120-
dep.getReady should equal(Set())
119+
dep.getMissingValues should equal(Set(1))
120+
dep.getReady shouldBe empty
121121
dep.setValue(0, 0)
122122
dep.getReady should equal(Set(1))
123-
dep.getMissingValue should equal(Set(1)) // test ready and missing
123+
dep.getMissingValues should equal(Set(1)) // test ready and missing
124124
}
125125

126126
it should "prevent reinsertion of a node" in {
@@ -154,20 +154,20 @@ class DependencyGraphTest extends AnyFlatSpec {
154154
dep.addNode(10, Seq(0))
155155
dep.addNode(10, Seq(0, 1), overwrite = true)
156156
dep.setValue(0, 0)
157-
dep.getReady should equal(Set()) // should still be blocked on 1
157+
dep.getReady shouldBe empty // should still be blocked on 1
158158

159159
dep.addNode(10, Seq(1, 2), overwrite = true) // 0 should no longer be required
160160

161161
dep.setValue(1, 1)
162-
dep.getReady should equal(Set())
162+
dep.getReady shouldBe empty
163163
dep.setValue(2, 2)
164164
dep.getReady should equal(Set(10))
165165

166166
dep.addNode(10, Seq(1, 2), overwrite = true) // should be a nop
167167
dep.getReady should equal(Set(10))
168168

169169
dep.addNode(10, Seq(3), overwrite = true)
170-
dep.getReady should equal(Set()) // should no longer be ready
170+
dep.getReady shouldBe empty // should no longer be ready
171171
}
172172

173173
it should "return nodeDefinedAt and valueDefinedAt for dependencies" in {
@@ -197,12 +197,12 @@ class DependencyGraphTest extends AnyFlatSpec {
197197

198198
dep1.setValue(0, 0)
199199
dep1.getReady should equal(Set(1))
200-
dep1.nodeMissing(1) should equal(Set())
200+
dep1.nodeMissing(1) shouldBe empty
201201
dep2.getReady shouldBe empty
202202
dep2.nodeMissing(1) should equal(Set(0))
203203

204204
dep2.setValue(0, 0)
205205
dep2.getReady should equal(Set(1))
206-
dep2.nodeMissing(1) should equal(Set())
206+
dep2.nodeMissing(1) shouldBe empty
207207
}
208208
}
502 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)