Skip to content

Commit dcb19e9

Browse files
FabioLuporini authored and mloubout committed
compiler: Improve rcompile
1 parent 5efb668 commit dcb19e9

File tree

4 files changed

+55
-30
lines changed

4 files changed

+55
-30
lines changed

conftest.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,8 +290,8 @@ def pytest_runtest_makereport(item, call):
290290

291291
def pytest_make_parametrize_id(config, val, argname):
292292
"""
293-
Prevents pytest to make obsucre parameter names (param0, param1, ...)
294-
and default to str(val) instead for better log readability.
293+
Prevents pytest from making obscure parameter names (param0, param1, ...)
294+
and default to sympy.sstr(val) instead for better log readability.
295295
"""
296296
# First see if it has a name
297297
if hasattr(val, '__name__'):

devito/operator/operator.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@
1818
switch_log_level)
1919
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
2020
from devito.ir.clusters import ClusterGroup, clusterize
21-
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
22-
MetaCall, derive_parameters, iet_build)
21+
from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction,
22+
FindSymbols, MetaCall, derive_parameters, iet_build)
2323
from devito.ir.support import AccessMode, SymbolRegistry
2424
from devito.ir.stree import stree_build
2525
from devito.operator.profiling import create_profile
@@ -1270,9 +1270,12 @@ def rcompile(expressions, kwargs, options, target=None):
12701270

12711271
# Recursive compilation is expensive, so we cache the result because sometimes
12721272
# it is called multiple times for the same input
1273-
compiled = RCompiles(expressions, cls).compile(**kwargs)
1273+
irs, byproduct0 = RCompiles(expressions, cls).compile(**kwargs)
12741274

1275-
return compiled
1275+
key = lambda i: isinstance(i, (EntryFunction, DeviceFunction))
1276+
byproduct = byproduct0.filter(key)
1277+
1278+
return irs, byproduct
12761279

12771280

12781281
# *** Misc helpers

devito/passes/iet/definitions.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,6 @@ def __init__(self, rcompile=None, sregistry=None, platform=None,
9191
self.rcompile = rcompile
9292
self.sregistry = sregistry
9393
self.platform = platform
94-
self.alloc_mapped = (options or {}).get('alloc_mapped', True)
9594

9695
def _alloc_object_on_low_lat_mem(self, site, obj, storage):
9796
"""
@@ -178,10 +177,6 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
178177
"""
179178
Allocate a mapped Array in the host high bandwidth memory.
180179
"""
181-
if not self.alloc_mapped:
182-
# Mapped array assumed to be preallocated, likely from rcompile
183-
return
184-
185180
decl = Definition(obj)
186181

187182
sizeof_dtypeN = SizeOf(obj.indexed._C_typedata)

devito/passes/iet/engine.py

Lines changed: 46 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
1-
from collections import OrderedDict, defaultdict
1+
from collections import defaultdict
22
from functools import partial, singledispatch, wraps
33

44
import numpy as np
55
from sympy import Mul
66

77
from devito.ir.iet import (
88
Call, ExprStmt, Expression, Iteration, SyncSpot, AsyncCallable, FindNodes,
9-
FindSymbols, MapNodes, MetaCall, Transformer, EntryFunction,
10-
ThreadCallable, Uxreplace, derive_parameters
9+
FindSymbols, MapNodes, MetaCall, Transformer, EntryFunction, ThreadCallable,
10+
Uxreplace, derive_parameters
1111
)
1212
from devito.ir.support import SymbolRegistry
1313
from devito.mpi.distributed import MPINeighborhood
@@ -28,7 +28,31 @@
2828
__all__ = ['Graph', 'iet_pass', 'iet_visit']
2929

3030

31-
class Graph:
31+
class Byproduct:
32+
33+
"""
34+
A Byproduct is a mutable collection of metadata produced by one or more
35+
compiler passes.
36+
37+
This metadata may be used internally or by the caller itself, typically
38+
for code generation purposes.
39+
"""
40+
41+
def __init__(self, efuncs=None, includes=None, headers=None, namespaces=None,
42+
globs=None):
43+
self.efuncs = efuncs or {}
44+
self.includes = includes or []
45+
self.headers = headers or []
46+
self.namespaces = namespaces or []
47+
self.globals = globs or []
48+
49+
@property
50+
def funcs(self):
51+
return tuple(MetaCall(v, True) for v in self.efuncs.values()
52+
if not isinstance(v, EntryFunction))
53+
54+
55+
class Graph(Byproduct):
3256

3357
"""
3458
DAG representation of a call graph.
@@ -49,16 +73,9 @@ class Graph:
4973
"""
5074

5175
def __init__(self, iet, options=None, sregistry=None, **kwargs):
52-
self.efuncs = OrderedDict([(iet.name, iet)])
53-
5476
self.sregistry = sregistry
5577

56-
self.includes = []
57-
self.headers = []
58-
self.namespaces = []
59-
self.globals = []
60-
61-
# Stash immutable information useful for one or more compiler passes
78+
super().__init__({iet.name: iet})
6279

6380
# All written user-level objects
6481
writes = FindSymbols('writes').visit(iet)
@@ -79,10 +96,6 @@ def __init__(self, iet, options=None, sregistry=None, **kwargs):
7996
def root(self):
8097
return self.efuncs[list(self.efuncs).pop(0)]
8198

82-
@property
83-
def funcs(self):
84-
return tuple(MetaCall(v, True) for v in self.efuncs.values())[1:]
85-
8699
@property
87100
def sync_mapper(self):
88101
"""
@@ -146,7 +159,7 @@ def apply(self, func, **kwargs):
146159
new_efuncs = metadata.get('efuncs', [])
147160

148161
efuncs[i] = efunc
149-
efuncs.update(OrderedDict([(i.name, i) for i in new_efuncs]))
162+
efuncs.update(dict([(i.name, i) for i in new_efuncs]))
150163

151164
# Update the parameters / arguments lists since `func` may have
152165
# introduced or removed objects
@@ -176,10 +189,24 @@ def visit(self, func, **kwargs):
176189
dag = create_call_graph(self.root.name, self.efuncs)
177190
toposort = dag.topological_sort()
178191

179-
mapper = OrderedDict([(i, func(self.efuncs[i], **kwargs)) for i in toposort])
192+
mapper = dict([(i, func(self.efuncs[i], **kwargs)) for i in toposort])
180193

181194
return mapper
182195

196+
def filter(self, key):
197+
"""
198+
Return a Byproduct containing only the Callables in the Graph
199+
for which `key` evaluates to True. The resulting object cannot be
200+
further modified by an IET pass.
201+
"""
202+
return Byproduct(
203+
efuncs={i: v for i, v in self.efuncs.items() if key(v)},
204+
includes=as_tuple(self.includes),
205+
headers=as_tuple(self.headers),
206+
namespaces=as_tuple(self.namespaces),
207+
globs=as_tuple(self.globals)
208+
)
209+
183210

184211
def iet_pass(func):
185212
if isinstance(func, tuple):
@@ -732,7 +759,7 @@ def _filter(v, efunc=None):
732759

733760
return processed
734761

735-
efuncs = OrderedDict(efuncs)
762+
efuncs = dict(efuncs)
736763
efuncs[root.name] = root._rebuild(parameters=_filter(root.parameters, root))
737764

738765
# Update all call sites to use the new signature

0 commit comments

Comments
 (0)