Skip to content

Commit d12129f

Browse files
committed
compiler: Add support to generate switch-case
1 parent 6265862 commit d12129f

File tree

5 files changed

+192
-12
lines changed

5 files changed

+192
-12
lines changed

devito/ir/iet/algorithms.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from collections import OrderedDict
22

3-
from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot,
4-
Section, HaloSpot, ExpressionBundle)
5-
from devito.tools import timed_pass
3+
from devito.ir.iet import (
4+
Expression, Increment, Iteration, List, Conditional, SyncSpot, Section,
5+
HaloSpot, ExpressionBundle, Switch
6+
)
7+
from devito.ir.support import GuardSwitch, GuardCaseSwitch
8+
from devito.tools import as_mapper, timed_pass
69

710
__all__ = ['iet_build']
811

@@ -29,7 +32,12 @@ def iet_build(stree):
2932
body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)
3033

3134
elif i.is_Conditional:
32-
body = Conditional(i.guard, queues.pop(i))
35+
if isinstance(i.guard, GuardSwitch):
36+
bundle, = queues.pop(i)
37+
cases, nodes = _unpack_switch_case(bundle)
38+
body = Switch(i.guard.arg, cases, nodes)
39+
else:
40+
body = Conditional(i.guard, queues.pop(i))
3341

3442
elif i.is_Iteration:
3543
if i.dim.is_Virtual:
@@ -55,3 +63,26 @@ def iet_build(stree):
5563
queues.setdefault(i.parent, []).append(body)
5664

5765
assert False
66+
67+
68+
def _unpack_switch_case(bundle):
69+
"""
70+
Helper to unpack an ExpressionBundle containing GuardCaseSwitch expressions
71+
into Switch cases and corresponding IET nodes.
72+
"""
73+
assert bundle.is_ExpressionBundle
74+
assert all(isinstance(e.rhs, GuardCaseSwitch) for e in bundle.body)
75+
76+
mapper = as_mapper(bundle.body, key=lambda e: e.rhs.case)
77+
78+
cases = list(mapper)
79+
80+
nodes = []
81+
for v in mapper.values():
82+
exprs = [e._rebuild(expr=e.expr._subs(e.rhs, e.rhs.arg)) for e in v]
83+
if len(exprs) > 1:
84+
nodes.append(List(body=exprs))
85+
else:
86+
nodes.append(*exprs)
87+
88+
return cases, nodes

devito/ir/iet/nodes.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
3131
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
3232
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
33-
'Using', 'CallableBody', 'Transfer', 'EmptyList']
33+
'Using', 'CallableBody', 'Transfer', 'EmptyList', 'Switch']
3434

3535
# First-class IET nodes
3636

@@ -406,6 +406,11 @@ def output(self):
406406
"""The Symbol/Indexed this Expression writes to."""
407407
return self.expr.lhs
408408

409+
@property
410+
def rhs(self):
411+
"""The right-hand side of the underlying expression."""
412+
return self.expr.rhs
413+
409414
@cached_property
410415
def reads(self):
411416
"""The Functions read by the Expression."""
@@ -892,6 +897,50 @@ def __repr__(self):
892897
return "<[%s] ? [%s]" % (ccode(self.condition), repr(self.then_body))
893898

894899

900+
class Switch(DoIf):
901+
902+
"""
903+
A node to express switch-case blocks.
904+
905+
Parameters
906+
----------
907+
condition : expr-like
908+
The switch condition.
909+
cases : expr-like or list of expr-like
910+
One or more case conditions; there must be one case per node in `nodes`,
911+
plus an optional default case.
912+
nodes : Node or list of Node
913+
One or more Case nodes.
914+
default : Node, optional
915+
The default case of the switch, if any.
916+
"""
917+
918+
_traversable = ['nodes', 'default']
919+
920+
def __init__(self, condition, cases, nodes, default=None):
921+
super().__init__(condition)
922+
923+
self.cases = as_tuple(cases)
924+
self.nodes = as_tuple(nodes)
925+
self.default = default
926+
927+
assert len(self.cases) == len(self.nodes)
928+
929+
def __repr__(self):
930+
return "f<Switch {ccode(self.condition)}; {self.ncases} cases>"
931+
932+
@property
933+
def ncases(self):
934+
return len(self.cases) + int(self.default is not None)
935+
936+
@property
937+
def as_mapper(self):
938+
retval = dict(zip(self.cases, self.nodes))
939+
if self.default:
940+
retval['default'] = self.default
941+
return retval
942+
943+
895944
# Second level IET nodes
896945

897946
class TimedList(List):

devito/ir/iet/visitors.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,12 @@ def visit_Conditional(self, o):
622622
else:
623623
return c.If(self.ccode(o.condition), then_body)
624624

625+
def visit_Switch(self, o):
626+
condition = self.ccode(o.condition)
627+
mapper = {k: self._visit(v) for k, v in o.as_mapper.items()}
628+
629+
return Switch(condition, mapper)
630+
625631
def visit_Iteration(self, o):
626632
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
627633

@@ -1426,6 +1432,12 @@ def visit_Conditional(self, o):
14261432
return o._rebuild(condition=condition, then_body=then_body,
14271433
else_body=else_body)
14281434

1435+
def visit_Switch(self, o):
1436+
condition = uxreplace(o.condition, self.mapper)
1437+
nodes = self._visit(o.nodes)
1438+
default = self._visit(o.default)
1439+
return o._rebuild(condition=condition, nodes=nodes, default=default)
1440+
14291441
def visit_PointerCast(self, o):
14301442
function = self.mapper.get(o.function, o.function)
14311443
obj = self.mapper.get(o.obj, o.obj)
@@ -1500,6 +1512,31 @@ class LambdaCollection(c.Collection):
15001512
pass
15011513

15021514

1515+
class Switch(c.Generable):
1516+
1517+
def __init__(self, condition, mapper):
1518+
self.condition = condition
1519+
1520+
# If the `default` case is present, it is encoded with the key "default"
1521+
self.mapper = mapper
1522+
1523+
def generate(self):
1524+
yield f"switch ({self.condition})"
1525+
yield "{"
1526+
1527+
for case, body in self.mapper.items():
1528+
if case == "default":
1529+
yield " default: {"
1530+
else:
1531+
yield f" case {case}: {{"
1532+
for line in body.generate():
1533+
yield f" {line}"
1534+
yield " break;"
1535+
yield " }"
1536+
1537+
yield "}"
1538+
1539+
15031540
class MultilineCall(c.Generable):
15041541

15051542
def __init__(self, name, arguments, is_expr=False, is_indirect=False,

devito/ir/support/guards.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,25 @@
88
from operator import ge, gt, le, lt
99

1010
from functools import singledispatch
11-
from sympy import And, Ge, Gt, Le, Lt, Mul, true
11+
from sympy import And, Expr, Ge, Gt, Le, Lt, Mul, true
1212
from sympy.logic.boolalg import BooleanFunction
1313
import numpy as np
1414

1515
from devito.ir.support.space import Forward, IterationDirection
16-
from devito.symbolics import CondEq, CondNe
16+
from devito.symbolics import CondEq, CondNe, search
1717
from devito.tools import Pickable, as_tuple, frozendict, split
1818
from devito.types import Dimension, LocalObject
1919

2020
__all__ = ['GuardFactor', 'GuardBound', 'GuardBoundNext', 'BaseGuardBound',
21-
'BaseGuardBoundNext', 'GuardOverflow', 'Guards', 'GuardExpr']
21+
'BaseGuardBoundNext', 'GuardOverflow', 'Guards', 'GuardExpr',
22+
'GuardSwitch', 'GuardCaseSwitch']
2223

2324

24-
class Guard:
25+
class AbstractGuard:
26+
pass
27+
28+
29+
class Guard(AbstractGuard):
2530

2631
@property
2732
def _args_rebuild(self):
@@ -217,6 +222,35 @@ class GuardOverflowLt(BaseGuardOverflow, Lt):
217222
}
218223

219224

225+
class GuardSwitch(AbstractGuard, Expr):
226+
227+
"""
228+
A switch guard (akin to C's switch-case) that can be used to select
229+
between multiple cases at runtime.
230+
"""
231+
232+
def __new__(cls, arg, **kwargs):
233+
return Expr.__new__(cls, arg)
234+
235+
@property
236+
def arg(self):
237+
return self.args[0]
238+
239+
240+
class GuardCaseSwitch(GuardSwitch):
241+
242+
"""
243+
A case within a GuardSwitch.
244+
"""
245+
246+
def __new__(cls, arg, case, **kwargs):
247+
return Expr.__new__(cls, arg, case)
248+
249+
@property
250+
def case(self):
251+
return self.args[1]
252+
253+
220254
class Guards(frozendict):
221255

222256
"""

tests/test_iet.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from devito import (Eq, Grid, Function, TimeFunction, Operator, Dimension, # noqa
99
switchconfig)
1010
from devito.ir.iet import (
11-
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
12-
KernelLaunch, Dereference, Lambda, ElementalFunction, CGen, FindSymbols,
13-
filter_iterations, make_efunc, retrieve_iteration_tree, Transformer
11+
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration,
12+
List, KernelLaunch, Dereference, Lambda, Switch, ElementalFunction, CGen,
13+
FindSymbols, filter_iterations, make_efunc, retrieve_iteration_tree,
14+
Transformer
1415
)
1516
from devito.ir import SymbolRegistry
1617
from devito.passes.iet.engine import Graph
@@ -509,3 +510,31 @@ def test_dereference_base_plus_off():
509510
deref = Dereference(x, ptr, offset=off)
510511

511512
assert str(deref) == "float (*restrict x)[3] = (float (*)[3]) (p + offs);"
513+
514+
515+
def test_switch_case():
516+
flag = Symbol(name='flag')
517+
a = Symbol(name='a')
518+
519+
cases = [0, 1]
520+
nodes = [DummyExpr(a, 1), DummyExpr(a, 2)]
521+
default = DummyExpr(a, 0)
522+
523+
switch = Switch(flag, cases, nodes, default=default)
524+
525+
assert str(switch) == """\
526+
switch (flag)
527+
{
528+
case 0: {
529+
a = 1;
530+
break;
531+
}
532+
case 1: {
533+
a = 2;
534+
break;
535+
}
536+
default: {
537+
a = 0;
538+
break;
539+
}
540+
}"""

0 commit comments

Comments
 (0)