|
1 | 1 | from collections import OrderedDict |
2 | 2 | from collections.abc import Iterable |
3 | | -from functools import reduce |
| 3 | +from functools import reduce, wraps |
4 | 4 | from itertools import chain, combinations, groupby, product, zip_longest |
5 | 5 | from operator import attrgetter, mul |
6 | 6 | import types |
@@ -349,30 +349,36 @@ def key(i): |
349 | 349 | return sorted(items, key=key, reverse=True) |
350 | 350 |
|
351 | 351 |
|
352 | | -def avoid_symbolic_relations(func): |
| 352 | +def avoid_symbolic(default_val): |
353 | 353 | """ |
354 | | - Decorator to avoid calculating a relation symbolically if doing so may be slow. |
355 | | - In the case that one of the values being compared is symbolic, just give up |
356 | | - and return False. |
| 354 | + Decorator to avoid calling a function where doing so will result in symbolic |
| 355 | + computation being performed. For use if symbolic computation may be slow. In |
| 356 | + the case that an arg is symbolic, just give up and return a default value. |
357 | 357 | """ |
358 | | - def wrapper(a, b): |
359 | | - if any(isinstance(expr, sympy.Basic) for expr in (a, b)): |
360 | | - # An argument is symbolic, so give up and assume False |
361 | | - return False |
| 358 | + def _avoid_symbolic(func): |
| 359 | + @wraps(func) |
| 360 | + def wrapper(*args): |
| 361 | + if any(isinstance(expr, sympy.Basic) for expr in args): |
| 362 | + # An argument is symbolic, so give up and assume default |
| 363 | + return default_val |
362 | 364 |
|
363 | | - try: |
364 | | - return func(a, b) |
365 | | - except TypeError: |
366 | | - return False |
| 365 | + try: |
| 366 | + return func(*args) |
| 367 | + except TypeError: |
| 368 | + return default_val |
| 369 | + |
| 370 | + return wrapper |
367 | 371 |
|
368 | | - return wrapper |
| 372 | + return _avoid_symbolic |
369 | 373 |
|
370 | 374 |
|
371 | | -@avoid_symbolic_relations |
| 375 | +@avoid_symbolic(False) |
372 | 376 | def smart_lt(a, b): |
| 377 | + """An Lt that gives up and returns False if supplied a symbolic argument""" |
373 | 378 | return bool(a < b) |
374 | 379 |
|
375 | 380 |
|
376 | | -@avoid_symbolic_relations |
| 381 | +@avoid_symbolic(False) |
377 | 382 | def smart_gt(a, b): |
| 383 | + """A Gt that gives up and returns False if supplied a symbolic argument""" |
378 | 384 | return bool(a > b) |
0 commit comments