55"""
66
77from sympy import And , Ge , Gt , Le , Lt , Mul , true
8+ from sympy .logic .boolalg import BooleanFunction
9+ import numpy as np
810
911from devito .ir .support .space import Forward , IterationDirection
1012from devito .symbolics import CondEq , CondNe
1113from devito .tools import Pickable , as_tuple , frozendict , split
12- from devito .types import Dimension
14+ from devito .types import Dimension , LocalObject
1315
1416__all__ = ['GuardFactor' , 'GuardBound' , 'GuardBoundNext' , 'BaseGuardBound' ,
15- 'BaseGuardBoundNext' , 'GuardOverflow' , 'Guards' ]
17+ 'BaseGuardBoundNext' , 'GuardOverflow' , 'Guards' , 'GuardExpr' ]
1618
1719
1820class Guard :
@@ -273,6 +275,40 @@ def filter(self, key):
273275 return Guards (m )
274276
275277
278+ class GuardExpr (LocalObject , BooleanFunction ):
279+
280+ """
281+ A boolean symbol that can be used as a guard. As such, it can be chained
282+ with other relations using the standard boolean operators (&, |, ...).
283+
284+ Being a LocalObject, a GuardExpr may carry an `initvalue`, which is
285+ the value that the guard assumes at the beginning of the scope where
286+ it is defined.
287+
288+ Through the `supersets` argument, a GuardExpr may also carry a set of
289+ GuardExprs that are known to be more restrictive than itself. This is
290+ usesful, e.g., to avoid redundant checks when chaining multiple guards
291+ together (see `simplify_and`).
292+ """
293+
294+ dtype = np .bool
295+
296+ def __init__ (self , name , liveness = 'eager' , supersets = None , ** kwargs ):
297+ super ().__init__ (name , liveness = liveness , ** kwargs )
298+
299+ self .supersets = frozenset (as_tuple (supersets ))
300+
301+ def _hashable_content (self ):
302+ return super ()._hashable_content () + (self .supersets ,)
303+
304+ __hash__ = LocalObject .__hash__
305+
306+ def __eq__ (self , other ):
307+ return (isinstance (other , GuardExpr ) and
308+ super ().__eq__ (other ) and
309+ self .supersets == other .supersets )
310+
311+
276312# *** Utils
277313
278314def simplify_and (relation , v ):
@@ -291,10 +327,18 @@ def simplify_and(relation, v):
291327 else :
292328 candidates , other = [], [relation , v ]
293329
330+ # Quick check based on GuardExpr.supersets to avoid adding `v` to `relation`
331+ # if `relation` already includes a more restrictive guard than `v`
332+ if isinstance (v , GuardExpr ):
333+ if any (a in v .supersets for a in candidates ):
334+ return relation
335+
294336 covered = False
295337 new_args = []
296338 for a in candidates :
297- if a .lhs is v .lhs :
339+ if isinstance (a , GuardExpr ) or a .lhs is not v .lhs :
340+ new_args .append (a )
341+ else :
298342 covered = True
299343 try :
300344 if type (a ) in (Gt , Ge ) and v .rhs > a .rhs :
@@ -307,8 +351,7 @@ def simplify_and(relation, v):
307351 # E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
308352 # above are not evaluable to True/False
309353 new_args .append (a )
310- else :
311- new_args .append (a )
354+
312355 if not covered :
313356 new_args .append (v )
314357
0 commit comments