Skip to content

Commit 8022ca3

Browse files
EdCauntmloubout
authored andcommitted
misc: Make avoid_symbolic decorator more generic
1 parent 3d6eff0 commit 8022ca3

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

devito/tools/utils.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from collections.abc import Iterable
3-
from functools import reduce
3+
from functools import reduce, wraps
44
from itertools import chain, combinations, groupby, product, zip_longest
55
from operator import attrgetter, mul
66
import types
@@ -349,30 +349,36 @@ def key(i):
349349
return sorted(items, key=key, reverse=True)
350350

351351

352-
def avoid_symbolic_relations(func):
352+
def avoid_symbolic(default_val):
353353
"""
354-
Decorator to avoid calculating a relation symbolically if doing so may be slow.
355-
In the case that one of the values being compared is symbolic, just give up
356-
and return False.
354+
Decorator to avoid calling a function where doing so will result in symbolic
355+
computation being performed. For use if symbolic computation may be slow. In
356+
the case that an arg is symbolic, just give up and return a default value.
357357
"""
358-
def wrapper(a, b):
359-
if any(isinstance(expr, sympy.Basic) for expr in (a, b)):
360-
# An argument is symbolic, so give up and assume False
361-
return False
358+
def _avoid_symbolic(func):
359+
@wraps(func)
360+
def wrapper(*args):
361+
if any(isinstance(expr, sympy.Basic) for expr in args):
362+
# An argument is symbolic, so give up and assume default
363+
return default_val
362364

363-
try:
364-
return func(a, b)
365-
except TypeError:
366-
return False
365+
try:
366+
return func(*args)
367+
except TypeError:
368+
return default_val
369+
370+
return wrapper
367371

368-
return wrapper
372+
return _avoid_symbolic
369373

370374

371-
@avoid_symbolic_relations
375+
@avoid_symbolic(False)
372376
def smart_lt(a, b):
377+
"""An Lt that gives up and returns False if supplied a symbolic argument"""
373378
return bool(a < b)
374379

375380

376-
@avoid_symbolic_relations
381+
@avoid_symbolic(False)
377382
def smart_gt(a, b):
383+
"""A Gt that gives up and returns False if supplied a symbolic argument"""
378384
return bool(a > b)

0 commit comments

Comments
 (0)