Skip to content

Commit 2b2fa02

Browse files
committed
compiler: Add Guards.pairwise_or
1 parent 77c6edc commit 2b2fa02

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

devito/ir/support/guards.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
(e.g., Eq, Cluster, ...) should be evaluated at runtime.
55
"""
66

7+
from collections import Counter, defaultdict
78
from operator import ge, gt, le, lt
89

910
from functools import singledispatch
@@ -254,6 +255,20 @@ def xandg(self, d, guard):
254255

255256
return Guards(m)
256257

258+
def pairwise_or(self, d, *guards):
259+
m = dict(self)
260+
261+
if d in m:
262+
guards.append(m[d])
263+
264+
g = pairwise_or(*guards)
265+
if g is true:
266+
m.pop(d, None)
267+
else:
268+
m[d] = g
269+
270+
return Guards(m)
271+
257272
def impose(self, d, guard):
258273
m = dict(self)
259274

@@ -405,3 +420,72 @@ def simplify_and(relation, v):
405420
new_args.append(v)
406421

407422
return And(*(new_args + other))
423+
424+
425+
def pairwise_or(*guards):
426+
"""
427+
Given a series of guards, create a new guard that combines them by taking
428+
the OR of each subset of homogeneous components. Here, "homogeneous" means
429+
that they apply to the same variable with the same operator (e.g., given
430+
`y >= 2`, `y >= 3` is homogeneous, while `z >= 3` and `y <= 4` are not).
431+
432+
Examples
433+
--------
434+
Given:
435+
436+
g0 = {flag0 and z >= 2 and z <= 10 and y >= 3}
437+
g1 = {z >= 4 and z <= 8}
438+
g2 = {y >= 2 and y <= 5}
439+
440+
Where `flag0` is of type GuardExpr, then:
441+
442+
Return:
443+
444+
{z >= 2 and z <= 10 and y >= 2}
445+
"""
446+
errmsg = lambda g: f"Cannot handle guard component of type {type(g)}"
447+
448+
flags = Counter()
449+
mapper = defaultdict(list)
450+
451+
# Analysis
452+
for guard in guards:
453+
if guard is true or guard is None:
454+
continue
455+
elif isinstance(guard, And):
456+
components = guard.args
457+
elif isinstance(guard, GuardExpr) or guard.is_Relational:
458+
components = [guard]
459+
else:
460+
raise NotImplementedError(errmsg(guard))
461+
462+
for g in components:
463+
if isinstance(g, GuardExpr):
464+
flags[g] += 1
465+
elif g.is_Relational and g.lhs.is_Symbol and g.rhs.is_Number:
466+
mapper[(g.lhs, type(g))].append(g.rhs)
467+
else:
468+
raise NotImplementedError(errmsg(g))
469+
470+
# Synthesis
471+
guard = true
472+
for (s, op), v in mapper.items():
473+
if len(v) < len(guards):
474+
# Not all guards contributed to this component; cannot simplify
475+
pass
476+
elif op in (Ge, Gt):
477+
guard = And(guard, op(s, min(v)))
478+
else:
479+
guard = And(guard, op(s, max(v)))
480+
481+
for flag, v in flags.items():
482+
if v == len(guards):
483+
guard = And(guard, flag)
484+
elif flag.initvalue.free_symbols & guard.free_symbols:
485+
# We still lack the logic to properly handle this case
486+
raise NotImplementedError(errmsg(flag))
487+
else:
488+
# Safe to ignore
489+
pass
490+
491+
return guard

tests/test_symbolics.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import pytest
55
import numpy as np
66

7-
from sympy import And, Expr, Number, Symbol
7+
from sympy import And, Expr, Number, Symbol, true
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max, Real, Imag, Conj, SubDomain, configuration)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
13-
from devito.ir.support.guards import GuardExpr, simplify_and
13+
from devito.ir.support.guards import GuardExpr, simplify_and, pairwise_or
1414
from devito.mpi.halo_scheme import HaloTouch
1515
from devito.symbolics import (
1616
retrieve_functions, retrieve_indexed, evalrel, CallFromPointer, Cast, # noqa
@@ -671,6 +671,33 @@ def test_guard_expr_Le_Ge_mixed():
671671
assert v11 is And(g5, g6)
672672

673673

674+
def test_guard_pairwise_or():
675+
grid = Grid(shape=(3, 3, 3))
676+
x, y, z = grid.dimensions
677+
678+
flag = GuardExpr('flag', initvalue=And(x >= 4, x <= 14))
679+
680+
g0 = And(flag, z >= 8, z <= 39)
681+
g1 = And(z >= 8, z <= 40)
682+
g2 = y >= 9
683+
v0 = pairwise_or(g0, g1, g2)
684+
assert v0 is true
685+
686+
g3 = And(z >= 7, z <= 40)
687+
g4 = And(z >= 8, z <= 42, y >= 9)
688+
v1 = pairwise_or(g0, g3, g4)
689+
assert v1 == And(z >= 7, z <= 42)
690+
691+
# Some unsupported cases
692+
g5 = And(flag, z >= 9, x >= 5)
693+
g6 = x >= 3
694+
with pytest.raises(NotImplementedError):
695+
pairwise_or(g5, g6)
696+
g7 = And(z <= y)
697+
with pytest.raises(NotImplementedError):
698+
pairwise_or(g0, g7)
699+
700+
674701
def test_canonical_ordering_of_weights():
675702
grid = Grid(shape=(3, 3, 3))
676703
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)