Skip to content

Commit ad69ef7

Browse files
EdCauntmloubout
authored andcommitted
compiler: Add a check to pre-empt expensive symbolic comparisons before try-except
1 parent ca79bee commit ad69ef7

File tree

5 files changed

+18
-21
lines changed

5 files changed

+18
-21
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,6 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
156156
# Schedule Clusters over different IterationSpaces if this increases
157157
# parallelism
158158
for i in range(1, len(clusters)):
159-
# FIXME: This eats a lot of time (four seconds each time)
160-
# FIXME: Pull scope out of this
161159
if self._break_for_parallelism(scope, candidates, i):
162160
return self.callback(clusters[:i], prefix, clusters[i:] + backlog,
163161
candidates | known_break)
@@ -193,9 +191,6 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
193191
def _break_for_parallelism(self, scope, candidates, i):
194192
# `test` will be True if there's at least one data-dependence that would
195193
# break parallelism
196-
197-
# TODO: Can this loop be made to short-circuit?
198-
# TODO: Most of the time is burned in d_from_access_gen
199194
test = False
200195
for d in scope.d_from_access_gen(scope.a_query(i)):
201196
if d.is_local or d.is_storage_related(candidates):

devito/ir/support/basic.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def lex_le(self, other):
317317
def lex_lt(self, other):
318318
return self.timestamp < other.timestamp
319319

320-
# NOTE: This is called a lot with the same arguments - memoize yields mild speedup
320+
# Note: memoization yields mild compiler speedup
321321
@memoized_meth
322322
def distance(self, other):
323323
"""
@@ -365,14 +365,21 @@ def distance(self, other):
365365
# Case 1: `sit` is an IterationInterval with statically known
366366
# trip count. E.g. it ranges from 0 to 3; `other` performs a
367367
# constant access at 4
368-
for v in (self[n], other[n]):
369-
try:
370-
# NOTE: Split the boolean to make the conditional short circuit
371-
# more frequently for mild speedup
372-
if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max):
373-
return Vector(S.ImaginaryUnit)
374-
except TypeError:
375-
pass
368+
369+
# To avoid evaluating expensive symbolic Lt or Gt operations,
370+
# we pre-empt such operations by checking if the values to be compared
371+
# to are symbolic, and skip this case if not.
372+
if not any(isinstance(i, sympy.core.Basic)
373+
for i in (sit.symbolic_min, sit.symbolic_max)):
374+
375+
for v in (self[n], other[n]):
376+
try:
377+
# Note: Boolean is split to make the conditional short
378+
# circuit more frequently for mild speedup
379+
if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max):
380+
return Vector(S.ImaginaryUnit)
381+
except TypeError:
382+
pass
376383

377384
# Case 2: `sit` is an IterationInterval over a local SubDimension
378385
# and `other` performs a constant access
@@ -1174,7 +1181,6 @@ def d_from_access_gen(self, accesses):
11741181
Generate all flow, anti, and output dependences involving any of
11751182
the given TimedAccess objects.
11761183
"""
1177-
# FIXME: This seems to be a hotspot
11781184
accesses = as_tuple(accesses)
11791185
for d in self.d_all_gen():
11801186
for i in accesses:

devito/operator/operator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -967,10 +967,8 @@ def _emit_build_profiling(self):
967967
tot = timings.pop('op-compile')
968968
perf(f"Operator `{self.name}` generated in {fround(tot):.2f} s")
969969

970-
# max_hotspots = 3
971-
# threshold = 20.
972-
max_hotspots = 300
973-
threshold = 0.5
970+
max_hotspots = 3
971+
threshold = 20.
974972

975973
def _emit_timings(timings, indent=''):
976974
timings.pop('total', None)

devito/passes/clusters/cse.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def _compact(exprs, exclude):
234234

235235
# If there are ctemps in the expressions, then add any to the mapper which only
236236
# appear once
237-
# TODO: Double check this is exactly the prior behaviour
238237
if ctemps:
239238
mapper.update({e.lhs: e.rhs for e in candidates
240239
if ctemp_count[e.lhs] == 1})

devito/passes/clusters/misc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ def is_cross(source, sink):
362362
# * All ClusterGroups between `cg0` and `cg1` must precede `cg1`
363363
# * All ClusterGroups after `cg1` cannot precede `cg1`
364364

365-
# FIXME: Slow
366365
if any(i.cause & prefix for i in scope.d_anti_gen()):
367366
for cg2 in cgroups[n:cgroups.index(cg1)]:
368367
dag.add_edge(cg2, cg1)

0 commit comments

Comments
 (0)