Skip to content

Commit f5ce43a

Browse files
Merge pull request #2776 from devitocodes/guard-pairwise-or
compiler: Add Guards.pairwise_or
2 parents 72069c4 + 86dafc2 commit f5ce43a

File tree

14 files changed

+406
-29
lines changed

14 files changed

+406
-29
lines changed

devito/arch/archinfo.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,23 @@ def _detect_isa(self):
957957

958958
class Device(Platform):
959959

960+
"""
961+
A generic Device is based on the SIMT (Single Instruction, Multiple Threads)
962+
programming model. In this execution model, threads are batched together and
963+
execute the same instruction at the same time, though each thread operates on
964+
its own data. Intel, AMD, and Nvidia GPUs are all based on this model.
965+
Unfortunately they use different terminology to refer to the same or at least
966+
very similar concepts. Throughout Devito, whenever possible, we attempt to
967+
adopt a neutral terminology -- the docstrings below provide some examples.
968+
"""
969+
970+
thread_group_size = None
971+
"""
972+
A collection of threads that execute the same instruction in lockstep.
973+
The group size is a hardware-specific property. For example, this is a
974+
"warp" in Nvidia GPUs and a "wavefront" in AMD GPUs.
975+
"""
976+
960977
def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
961978
max_threads_per_block=1024, max_threads_dimx=1024,
962979
max_threads_dimy=1024, max_threads_dimz=64,
@@ -1039,6 +1056,8 @@ def march(self):
10391056

10401057
class NvidiaDevice(Device):
10411058

1059+
thread_group_size = 32
1060+
10421061
max_mem_trans_nbytes = 128
10431062

10441063
@cached_property
@@ -1102,6 +1121,8 @@ class Blackwell(Hopper):
11021121

11031122
class AmdDevice(Device):
11041123

1124+
thread_group_size = 64
1125+
11051126
max_mem_trans_nbytes = 256
11061127

11071128
@cached_property

devito/ir/cgen/printer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,15 @@ def _print_Abs(self, expr):
276276
return f"fabs({self._print(arg)})"
277277
return self._print_fmath_func('abs', expr)
278278

279+
def _print_BitwiseNot(self, expr):
280+
# Unary function, single argument
281+
arg = expr.args[0]
282+
return f'~{self._print(arg)}'
283+
284+
def _print_BitwiseBinaryOp(self, expr):
285+
arg0, arg1 = expr.args
286+
return f'{self._print(arg0)} {expr.op} {self._print(arg1)}'
287+
279288
def _print_Add(self, expr, order=None):
280289
"""
281290
Print an addition.

devito/ir/clusters/cluster.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,23 @@ def from_clusters(cls, *clusters):
6969
"""
7070
assert len(clusters) > 0
7171
root = clusters[0]
72+
73+
if len(clusters) == 1:
74+
return root
75+
7276
if not all(root.ispace.is_compatible(c.ispace) for c in clusters):
7377
raise ValueError("Cannot build a Cluster from Clusters with "
7478
"incompatible IterationSpace")
7579
if not all(root.guards == c.guards for c in clusters):
7680
raise ValueError("Cannot build a Cluster from Clusters with "
7781
"non-homogeneous guards")
7882

83+
writes = set().union(*[c.scope.writes for c in clusters])
84+
reads = set().union(*[c.scope.reads for c in clusters])
85+
if any(f._mem_shared for f in writes & reads):
86+
raise ValueError("Cannot build a Cluster from Clusters with "
87+
"read-write conflicts on shared-memory Functions")
88+
7989
exprs = chain(*[c.exprs for c in clusters])
8090
ispace = IterationSpace.union(*[c.ispace for c in clusters])
8191

devito/ir/iet/algorithms.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from collections import OrderedDict
22

3-
from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot,
4-
Section, HaloSpot, ExpressionBundle)
5-
from devito.tools import timed_pass
3+
from devito.ir.iet import (
4+
Expression, Increment, Iteration, List, Conditional, SyncSpot, Section,
5+
HaloSpot, ExpressionBundle, Switch
6+
)
7+
from devito.ir.support import GuardSwitch, GuardCaseSwitch
8+
from devito.tools import as_mapper, timed_pass
69

710
__all__ = ['iet_build']
811

@@ -29,7 +32,12 @@ def iet_build(stree):
2932
body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)
3033

3134
elif i.is_Conditional:
32-
body = Conditional(i.guard, queues.pop(i))
35+
if isinstance(i.guard, GuardSwitch):
36+
bundle, = queues.pop(i)
37+
cases, nodes = _unpack_switch_case(bundle)
38+
body = Switch(i.guard.arg, cases, nodes)
39+
else:
40+
body = Conditional(i.guard, queues.pop(i))
3341

3442
elif i.is_Iteration:
3543
if i.dim.is_Virtual:
@@ -55,3 +63,26 @@ def iet_build(stree):
5563
queues.setdefault(i.parent, []).append(body)
5664

5765
assert False
66+
67+
68+
def _unpack_switch_case(bundle):
    """
    Split an ExpressionBundle made of GuardCaseSwitch expressions into the
    ingredients of a Switch.

    The expressions in `bundle` are grouped by the `case` of their right-hand
    side. Within each group, every expression is rebuilt with the
    GuardCaseSwitch on its right-hand side replaced by that object's `arg`.

    Returns
    -------
    cases : list
        The distinct cases, in the order produced by the grouping.
    nodes : list
        One IET node per case: the rebuilt expression itself if the case has
        a single expression, a List wrapping all of them otherwise.
    """
    assert bundle.is_ExpressionBundle
    assert all(isinstance(e.rhs, GuardCaseSwitch) for e in bundle.body)

    grouped = as_mapper(bundle.body, key=lambda e: e.rhs.case)

    cases = []
    nodes = []
    for case, group in grouped.items():
        # Strip the GuardCaseSwitch wrapper off each expression in the group
        stripped = []
        for e in group:
            unwrapped = e.expr._subs(e.rhs, e.rhs.arg)
            stripped.append(e._rebuild(expr=unwrapped))

        cases.append(case)
        nodes.append(List(body=stripped) if len(stripped) > 1 else stripped[0])

    return cases, nodes

devito/ir/iet/nodes.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
3131
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
3232
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
33-
'Using', 'CallableBody', 'Transfer', 'EmptyList']
33+
'Using', 'CallableBody', 'Transfer', 'EmptyList', 'Switch']
3434

3535
# First-class IET nodes
3636

@@ -406,6 +406,11 @@ def output(self):
406406
"""The Symbol/Indexed this Expression writes to."""
407407
return self.expr.lhs
408408

409+
@property
def rhs(self):
    """The right-hand side of the expression wrapped by this node."""
    wrapped = self.expr
    return wrapped.rhs
413+
409414
@cached_property
410415
def reads(self):
411416
"""The Functions read by the Expression."""
@@ -892,6 +897,50 @@ def __repr__(self):
892897
return "<[%s] ? [%s]" % (ccode(self.condition), repr(self.then_body))
893898

894899

900+
class Switch(DoIf):

    """
    A node to express switch-case blocks.

    Parameters
    ----------
    condition : expr-like
        The switch condition.
    cases : expr-like or list of expr-like
        One or more case conditions; there must be exactly one case per node
        in `nodes`. The default case, if any, is not listed here -- it is
        supplied through `default`.
    nodes : Node or list of Node
        One or more nodes; the i-th node is the body of the i-th case.
    default : Node, optional
        The default case of the switch, if any.
    """

    _traversable = ['nodes', 'default']

    def __init__(self, condition, cases, nodes, default=None):
        super().__init__(condition)

        self.cases = as_tuple(cases)
        self.nodes = as_tuple(nodes)
        self.default = default

        # Cases and nodes are paired positionally (see `as_mapper`)
        assert len(self.cases) == len(self.nodes)

    def __repr__(self):
        # The `f` prefix must sit outside the quotes for interpolation to happen
        return f"<Switch {ccode(self.condition)}; {self.ncases} cases>"

    @property
    def ncases(self):
        """The number of cases, counting the default case, if any."""
        return len(self.cases) + int(self.default is not None)

    @property
    def as_mapper(self):
        """
        A mapper from each case to its node. The default case, if any, is
        stored under the key "default".
        """
        retval = dict(zip(self.cases, self.nodes))
        # NOTE(review): this tests truthiness whereas `ncases` tests
        # `is not None`; the two only disagree for a falsy non-None default
        if self.default:
            retval['default'] = self.default
        return retval
942+
943+
895944
# Second level IET nodes
896945

897946
class TimedList(List):

devito/ir/iet/visitors.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,12 @@ def visit_Conditional(self, o):
622622
else:
623623
return c.If(self.ccode(o.condition), then_body)
624624

625+
def visit_Switch(self, o):
    """
    Lower a Switch node: the condition is turned into code via `self.ccode`,
    each node in `o.as_mapper` is visited, and the results are wrapped into
    a `Switch` generable keyed by the same cases.
    """
    header = self.ccode(o.condition)

    bodies = {}
    for case, node in o.as_mapper.items():
        bodies[case] = self._visit(node)

    return Switch(header, bodies)
630+
625631
def visit_Iteration(self, o):
626632
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
627633

@@ -1426,6 +1432,12 @@ def visit_Conditional(self, o):
14261432
return o._rebuild(condition=condition, then_body=then_body,
14271433
else_body=else_body)
14281434

1435+
def visit_Switch(self, o):
    """
    Rebuild a Switch after applying `self.mapper` to its condition through
    `uxreplace` and visiting its nodes and its default.
    """
    # NOTE(review): `cases` is not passed explicitly; presumably `_rebuild`
    # carries it over unchanged -- confirm.
    replaced = {
        'condition': uxreplace(o.condition, self.mapper),
        'nodes': self._visit(o.nodes),
        'default': self._visit(o.default),
    }
    return o._rebuild(**replaced)
1440+
14291441
def visit_PointerCast(self, o):
14301442
function = self.mapper.get(o.function, o.function)
14311443
obj = self.mapper.get(o.obj, o.obj)
@@ -1500,6 +1512,31 @@ class LambdaCollection(c.Collection):
15001512
pass
15011513

15021514

1515+
class Switch(c.Generable):

    """
    A Generable emitting a switch statement.

    Parameters
    ----------
    condition
        The switched-on expression; it is interpolated as-is into the
        `switch (...)` header.
    mapper : dict
        Maps each case to an object providing `generate()`, which yields the
        lines of that case's body. The default case, if present, is encoded
        with the key "default".
    """

    def __init__(self, condition, mapper):
        self.condition = condition

        # If the `default` case is present, it is encoded with the key "default"
        self.mapper = mapper

    def generate(self):
        """Yield the switch statement one line at a time."""
        yield f"switch ({self.condition})"
        yield "{"

        for label, block in self.mapper.items():
            # Opening line of the case, then its body, each in its own braces
            # and terminated by a `break`
            yield " default: {" if label == "default" else f" case {label}: {{"
            yield from (f" {line}" for line in block.generate())
            yield " break;"
            yield " }"

        yield "}"
1538+
1539+
15031540
class MultilineCall(c.Generable):
15041541

15051542
def __init__(self, name, arguments, is_expr=False, is_indirect=False,

0 commit comments

Comments
 (0)