Skip to content

Commit 3d6eff0

Browse files
EdCauntmloubout
authored andcommitted
misc: Add comparisons that give up upon encountering symbolic arguments
1 parent 89cde25 commit 3d6eff0

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

devito/ir/support/basic.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
q_constant, q_comp_acc, q_affine, q_routine, search,
1212
uxreplace)
1313
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
14-
flatten, memoized_meth, memoized_generator)
14+
flatten, memoized_meth, memoized_generator, smart_gt,
15+
smart_lt)
1516
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1617
CriticalRegion, Function, Symbol, Temp, TempArray,
1718
TBArray)
@@ -317,9 +318,6 @@ def lex_le(self, other):
317318
def lex_lt(self, other):
318319
return self.timestamp < other.timestamp
319320

320-
# Note: memoization yields mild compiler speedup. Will need to be made
321-
# thread-safe for multithreading the compiler.
322-
@memoized_meth
323321
def distance(self, other):
324322
"""
325323
Compute the distance from ``self`` to ``other``.
@@ -366,21 +364,14 @@ def distance(self, other):
366364
# Case 1: `sit` is an IterationInterval with statically known
367365
# trip count. E.g. it ranges from 0 to 3; `other` performs a
368366
# constant access at 4
369-
370-
# To avoid evaluating expensive symbolic Lt or Gt operations,
371-
# we pre-empt such operations by checking if the values to be compared
372-
# to are symbolic, and skip this case if not.
373-
if not any(isinstance(i, sympy.Basic)
374-
for i in (sit.symbolic_min, sit.symbolic_max)):
375-
376-
for v in (self[n], other[n]):
377-
try:
378-
# Note: Boolean is split to make the conditional short
379-
# circuit more frequently for mild speedup
380-
if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max):
381-
return Vector(S.ImaginaryUnit)
382-
except TypeError:
383-
pass
367+
for v in (self[n], other[n]):
368+
# Note: To avoid evaluating expensive symbolic Lt or Gt operations,
369+
# we pre-empt such operations by checking if the values to be compared
370+
# to are symbolic, and skip this case if not.
371+
# Note: Boolean is split to make the conditional short
372+
# circuit more frequently for mild speedup.
373+
if smart_lt(v, sit.symbolic_min) or smart_gt(v, sit.symbolic_max):
374+
return Vector(S.ImaginaryUnit)
384375

385376
# Case 2: `sit` is an IterationInterval over a local SubDimension
386377
# and `other` performs a constant access

devito/tools/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered',
1313
'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list',
1414
'indices_to_slices', 'indices_to_sections', 'transitive_closure',
15-
'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number']
15+
'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number',
16+
'smart_lt', 'smart_gt']
1617

1718

1819
def prod(iterable, initial=1):
@@ -346,3 +347,32 @@ def key(i):
346347
return (v, str(type(i)))
347348

348349
return sorted(items, key=key, reverse=True)
350+
351+
352+
def avoid_symbolic_relations(func):
353+
"""
354+
Decorator to avoid calculating a relation symbolically if doing so may be slow.
355+
In the case that one of the values being compared is symbolic, just give up
356+
and return False.
357+
"""
358+
def wrapper(a, b):
359+
if any(isinstance(expr, sympy.Basic) for expr in (a, b)):
360+
# An argument is symbolic, so give up and assume False
361+
return False
362+
363+
try:
364+
return func(a, b)
365+
except TypeError:
366+
return False
367+
368+
return wrapper
369+
370+
371+
@avoid_symbolic_relations
372+
def smart_lt(a, b):
373+
return bool(a < b)
374+
375+
376+
@avoid_symbolic_relations
377+
def smart_gt(a, b):
378+
return bool(a > b)

0 commit comments

Comments
 (0)