Skip to content

Commit 822f146

Browse files
Merge pull request #2677 from enwask/cached-scopes
compiler: Reapply "Cache `Scope` + `Dependence` instances"
2 parents c7c5277 + 71b7388 commit 822f146

File tree

8 files changed

+212
-96
lines changed

8 files changed

+212
-96
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from devito.ir.equations import OpMin, OpMax, identity_mapper
1313
from devito.ir.clusters.analysis import analyze
1414
from devito.ir.clusters.cluster import Cluster, ClusterGroup
15-
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
15+
from devito.ir.clusters.visitors import Queue, cluster_pass
16+
from devito.ir.support import Scope
1617
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1718
from devito.mpi.reduction_scheme import DistReduce
1819
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
@@ -77,7 +78,7 @@ def impose_total_ordering(clusters):
7778
return processed
7879

7980

80-
class Schedule(QueueStateful):
81+
class Schedule(Queue):
8182

8283
"""
8384
This special Queue produces a new sequence of "scheduled" Clusters, which
@@ -135,7 +136,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
135136
# `clusters` are supposed to share it
136137
candidates = prefix[-1].dim._defines
137138

138-
scope = self._fetch_scope(clusters)
139+
scope = Scope(flatten(c.exprs for c in clusters))
139140

140141
# Handle the nastiest case -- ambiguity due to the presence of both a
141142
# flow- and an anti-dependence.

devito/ir/clusters/analysis.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
1-
from devito.ir.clusters.visitors import QueueStateful
1+
from devito.ir.clusters.cluster import Cluster
2+
from devito.ir.clusters.visitors import Queue
23
from devito.ir.support import (AFFINE, PARALLEL, PARALLEL_INDEP, PARALLEL_IF_ATOMIC,
3-
SEQUENTIAL)
4+
SEQUENTIAL, Property, Scope)
5+
from devito.ir.support.space import IterationSpace
46
from devito.tools import as_tuple, flatten, timed_pass
7+
from devito.types.dimension import Dimension
58

69
__all__ = ['analyze']
710

811

12+
# Describes properties fetched by a `Detector`
13+
Properties = dict[Cluster, dict[Dimension, set[Property]]]
14+
15+
916
@timed_pass()
1017
def analyze(clusters):
11-
state = QueueStateful.State()
18+
properties: Properties = {}
1219

1320
# Collect properties
14-
clusters = Parallelism(state).process(clusters)
15-
clusters = Affiness(state).process(clusters)
21+
clusters = Parallelism().process(clusters, properties=properties)
22+
clusters = Affiness().process(clusters, properties=properties)
1623

1724
# Reconstruct Clusters attaching the discovered properties
18-
processed = [c.rebuild(properties=state.properties.get(c)) for c in clusters]
25+
processed = [c.rebuild(properties=properties.get(c)) for c in clusters]
1926

2027
return processed
2128

2229

23-
class Detector(QueueStateful):
30+
class Detector(Queue):
2431

25-
def process(self, elements):
26-
return self._process_fatd(elements, 1)
32+
def process(self, clusters: list[Cluster], properties: Properties) -> list[Cluster]:
33+
return self._process_fatd(clusters, 1, properties=properties)
2734

28-
def callback(self, clusters, prefix):
35+
def callback(self, clusters: list[Cluster], prefix: IterationSpace | None,
36+
properties: Properties) -> list[Cluster]:
2937
if not prefix:
3038
return clusters
3139

@@ -41,11 +49,19 @@ def callback(self, clusters, prefix):
4149
# Update `self.state`
4250
if retval:
4351
for c in clusters:
44-
properties = self.state.properties.setdefault(c, {})
45-
properties.setdefault(d, set()).update(retval)
52+
c_properties = properties.setdefault(c, {})
53+
c_properties.setdefault(d, set()).update(retval)
4654

4755
return clusters
4856

57+
def _callback(self, clusters: list[Cluster], dim: Dimension,
58+
prefix: IterationSpace | None) -> set[Property]:
59+
"""
60+
Callback to be implemented by subclasses. It should return a set of
61+
properties for the given dimension.
62+
"""
63+
raise NotImplementedError()
64+
4965

5066
class Parallelism(Detector):
5167

@@ -72,27 +88,27 @@ class Parallelism(Detector):
7288
the 'write' is known to be an associative and commutative increment
7389
"""
7490

75-
def _callback(self, clusters, d, prefix):
91+
def _callback(self, clusters, dim, prefix):
7692
# Rule out if non-unitary increment Dimension (e.g., `t0=(time+1)%2`)
77-
if any(c.sub_iterators[d] for c in clusters):
78-
return SEQUENTIAL
93+
if any(c.sub_iterators[dim] for c in clusters):
94+
return {SEQUENTIAL}
7995

8096
# All Dimensions up to and including `i-1`
8197
prev = flatten(i.dim._defines for i in prefix[:-1])
8298

8399
is_parallel_indep = True
84100
is_parallel_atomic = False
85101

86-
scope = self._fetch_scope(clusters)
102+
scope = Scope(flatten(c.exprs for c in clusters))
87103
for dep in scope.d_all_gen():
88-
test00 = dep.is_indep(d) and not dep.is_storage_related(d)
104+
test00 = dep.is_indep(dim) and not dep.is_storage_related(dim)
89105
test01 = all(dep.is_reduce_atmost(i) for i in prev)
90106
if test00 and test01:
91107
continue
92108

93109
test1 = len(prev) > 0 and any(dep.is_carried(i) for i in prev)
94110
if test1:
95-
is_parallel_indep &= (dep.distance_mapper.get(d.root) == 0)
111+
is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0)
96112
continue
97113

98114
if dep.function in scope.initialized:
@@ -103,14 +119,14 @@ def _callback(self, clusters, d, prefix):
103119
is_parallel_atomic = True
104120
continue
105121

106-
return SEQUENTIAL
122+
return {SEQUENTIAL}
107123

108124
if is_parallel_atomic:
109-
return PARALLEL_IF_ATOMIC
125+
return {PARALLEL_IF_ATOMIC}
110126
elif is_parallel_indep:
111127
return {PARALLEL, PARALLEL_INDEP}
112128
else:
113-
return PARALLEL
129+
return {PARALLEL}
114130

115131

116132
class Affiness(Detector):
@@ -119,8 +135,11 @@ class Affiness(Detector):
119135
Detect the AFFINE Dimensions.
120136
"""
121137

122-
def _callback(self, clusters, d, prefix):
123-
scope = self._fetch_scope(clusters)
138+
def _callback(self, clusters, dim, prefix):
139+
scope = Scope(flatten(c.exprs for c in clusters))
124140
accesses = [a for a in scope.accesses if not a.is_scalar]
125-
if all(a.is_regular and a.affine_if_present(d._defines) for a in accesses):
126-
return AFFINE
141+
142+
if all(a.is_regular and a.affine_if_present(dim._defines) for a in accesses):
143+
return {AFFINE}
144+
145+
return set()

devito/ir/clusters/visitors.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from collections import defaultdict
21
from collections.abc import Iterable
32

43
from itertools import groupby
54

6-
from devito.ir.support import IterationSpace, Scope, null_ispace
7-
from devito.tools import as_tuple, flatten, timed_pass
5+
from devito.ir.support import IterationSpace, null_ispace
6+
from devito.tools import flatten, timed_pass
87

9-
__all__ = ['Queue', 'QueueStateful', 'cluster_pass']
8+
__all__ = ['Queue', 'cluster_pass']
109

1110

1211
class Queue:
@@ -113,48 +112,6 @@ def _process_fatd(self, clusters, level, prefix=None, **kwargs):
113112
return processed
114113

115114

116-
class QueueStateful(Queue):
117-
118-
"""
119-
A Queue carrying along some state. This is useful when one wants to avoid
120-
expensive re-computations of information.
121-
"""
122-
123-
class State:
124-
125-
def __init__(self):
126-
self.properties = {}
127-
self.scopes = {}
128-
129-
def __init__(self, state=None):
130-
super().__init__()
131-
self.state = state or QueueStateful.State()
132-
133-
def _fetch_scope(self, clusters):
134-
exprs = flatten(c.exprs for c in as_tuple(clusters))
135-
key = tuple(exprs)
136-
if key not in self.state.scopes:
137-
self.state.scopes[key] = Scope(exprs)
138-
return self.state.scopes[key]
139-
140-
def _fetch_properties(self, clusters, prefix):
141-
# If the situation is:
142-
#
143-
# t
144-
# x0
145-
# <some clusters>
146-
# x1
147-
# <some other clusters>
148-
#
149-
# then retain only the "common" properties, that is those along `t`
150-
properties = defaultdict(set)
151-
for c in clusters:
152-
v = self.state.properties.get(c, {})
153-
for i in prefix:
154-
properties[i.dim].update(v.get(i.dim, set()))
155-
return properties
156-
157-
158115
class Prefix(IterationSpace):
159116

160117
def __init__(self, ispace, guards, properties, syncs):

devito/ir/support/basic.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from collections.abc import Iterable
12
from itertools import chain, product
23
from functools import cached_property
4+
from typing import Callable
35

4-
from sympy import S
6+
from sympy import S, Expr
57
import sympy
68

79
from devito.ir.support.space import Backward, null_ispace
@@ -12,7 +14,7 @@
1214
uxreplace)
1315
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1416
flatten, memoized_meth, memoized_generator, smart_gt,
15-
smart_lt)
17+
smart_lt, CacheInstances)
1618
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1719
CriticalRegion, Function, Symbol, Temp, TempArray,
1820
TBArray)
@@ -246,7 +248,7 @@ def __eq__(self, other):
246248
self.ispace == other.ispace)
247249

248250
def __hash__(self):
249-
return super().__hash__()
251+
return hash((self.access, self.mode, self.timestamp, self.ispace))
250252

251253
@property
252254
def function(self):
@@ -624,7 +626,7 @@ def is_imaginary(self):
624626
return S.ImaginaryUnit in self.distance
625627

626628

627-
class Dependence(Relation):
629+
class Dependence(Relation, CacheInstances):
628630

629631
"""
630632
A data dependence between two TimedAccess objects.
@@ -823,17 +825,26 @@ def project(self, function):
823825
return DependenceGroup(i for i in self if i.function is function)
824826

825827

826-
class Scope:
828+
class Scope(CacheInstances):
829+
830+
# Describes a rule for dependencies
831+
Rule = Callable[[TimedAccess, TimedAccess], bool]
827832

828-
def __init__(self, exprs, rules=None):
833+
@classmethod
834+
def _preprocess_args(cls, exprs: Expr | Iterable[Expr],
835+
**kwargs) -> tuple[tuple, dict]:
836+
return (as_tuple(exprs),), kwargs
837+
838+
def __init__(self, exprs: tuple[Expr],
839+
rules: Rule | tuple[Rule] | None = None) -> None:
829840
"""
830841
A Scope enables data dependence analysis on a totally ordered sequence
831842
of expressions.
832843
"""
833-
self.exprs = as_tuple(exprs)
844+
self.exprs = exprs
834845

835846
# A set of rules to drive the collection of dependencies
836-
self.rules = as_tuple(rules)
847+
self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment]
837848
assert all(callable(i) for i in self.rules)
838849

839850
@memoized_generator
@@ -1172,12 +1183,10 @@ def d_from_access_gen(self, accesses):
11721183
Generate all flow, anti, and output dependences involving any of
11731184
the given TimedAccess objects.
11741185
"""
1175-
accesses = as_tuple(accesses)
1186+
accesses = set(as_tuple(accesses))
11761187
for d in self.d_all_gen():
1177-
for i in accesses:
1178-
if d.source == i or d.sink == i:
1179-
yield d
1180-
break
1188+
if accesses & {d.source, d.sink}:
1189+
yield d
11811190

11821191
@memoized_meth
11831192
def d_from_access(self, accesses):

devito/operator/operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from devito.symbolics import estimate_cost, subs_op_args
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
34-
split, timed_pass, timed_region, contains_val)
34+
split, timed_pass, timed_region, contains_val,
35+
CacheInstances)
3536
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3637
disk_layer)
3738
from devito.types.dimension import Thickness
@@ -245,6 +246,9 @@ def _build(cls, expressions, **kwargs):
245246
op._dtype, op._dspace = irs.clusters.meta
246247
op._profiler = profiler
247248

249+
# Clear build-scoped instance caches
250+
CacheInstances.clear_caches()
251+
248252
return op
249253

250254
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)