Skip to content

Commit 5f89aed

Browse files
committed
api: enforce pow_to_mul to be un-evaluable
1 parent 24faf28 commit 5f89aed

File tree

8 files changed

+28
-18
lines changed

8 files changed

+28
-18
lines changed

devito/core/cpu.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ def _normalize_kwargs(cls, **kwargs):
3232
o['mpi'] = oo.pop('mpi')
3333
o['parallel'] = o['openmp'] # Backwards compatibility
3434

35-
# Minimum scalar type
36-
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
37-
3835
# Buffering
3936
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
4037

@@ -87,6 +84,7 @@ def _normalize_kwargs(cls, **kwargs):
8784
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
8885
o['place-transfers'] = oo.pop('place-transfers', True)
8986
o['errctl'] = oo.pop('errctl', cls.ERRCTL)
87+
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
9088

9189
# Recognised but unused by the CPU backend
9290
oo.pop('par-disabled', None)

devito/core/gpu.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ def _normalize_kwargs(cls, **kwargs):
4040
o['mpi'] = oo.pop('mpi')
4141
o['parallel'] = True
4242

43-
# Minimum scalar type
44-
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
45-
4643
# Buffering
4744
o['buf-async-degree'] = oo.pop('buf-async-degree', None)
4845

@@ -98,6 +95,7 @@ def _normalize_kwargs(cls, **kwargs):
9895
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
9996
o['place-transfers'] = oo.pop('place-transfers', True)
10097
o['errctl'] = oo.pop('errctl', cls.ERRCTL)
98+
o['scalar-min-type'] = oo.pop('scalar-min-type', cls.SCALAR_MIN_TYPE)
10199

102100
if oo:
103101
raise InvalidOperator("Unsupported optimization options: [%s]"

devito/ir/iet/visitors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,19 +629,19 @@ def visit_Lambda(self, o):
629629
if body:
630630
body.append(c.Line())
631631
body.extend(as_tuple(v))
632+
632633
captures = [str(i) for i in o.captures]
633634
decls = [i.inline() for i in self._args_decl(o.parameters)]
635+
634636
extra = []
635637
if o.special:
636638
extra.append(' ')
637639
extra.append(' '.join(str(i) for i in o.special))
638640
if o.attributes:
639641
extra.append(' ')
640642
extra.append(' '.join(f'[[{i}]]' for i in o.attributes))
641-
ccapt = ', '.join(captures)
642-
cdecls = ', '.join(decls)
643-
cextra = ''.join(extra)
644-
top = c.Line(f'[{ccapt}]({cdecls}){cextra}')
643+
644+
top = c.Line(f"[{', '.join(captures)}]({', '.join(decls)}){''.join(extra)}")
645645
return LambdaCollection([top, c.Block(body)])
646646

647647
def visit_HaloSpot(self, o):

devito/passes/iet/misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from devito.finite_differences import Max, Min
88
from devito.finite_differences.differentiable import SafeInv
9+
from devito.logger import warning
910
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
1011
FindApplications, FindNodes, FindSymbols, Transformer,
1112
Uxreplace, filter_iterations, retrieve_iteration_tree,
@@ -144,13 +145,12 @@ def generate_macros(graph, **kwargs):
144145

145146

146147
@iet_pass
147-
def _generate_macros(iet, tracker=None, langbb=None, **kwargs):
148+
def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs):
148149
# Derive the Macros necessary for the FIndexeds
149150
iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs)
150151

151152
# NOTE: sorting is necessary to ensure deterministic code generation
152153
headers = [i.header for i in tracker.values()]
153-
printer = kwargs.get('printer', CPrinter)
154154
headers = sorted((printer()._print(define), printer()._print(expr))
155155
for define, expr in headers)
156156

@@ -238,7 +238,7 @@ def _(expr, langbb):
238238
try:
239239
eps = np.finfo(expr.base.dtype).resolution**2
240240
except ValueError:
241-
print(f"Warning: dtype not recognized in SafeInv for {expr.base}")
241+
warning(f"Warning: dtype not recognized in SafeInv for {expr.base}")
242242
eps = np.finfo(np.float32).resolution**2
243243
b = Cast('b', dtype=np.float32)
244244
return (('SAFEINV(a, b)',

devito/symbolics/manipulation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from devito.symbolics.extended_sympy import DefFunction, rfunc
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
16+
from devito.symbolics.unevaluation import Mul as UMul, Pow as UPow
1617
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
1718
from devito.types.basic import Basic, Indexed
1819
from devito.types.array import ComponentAccess
@@ -337,13 +338,13 @@ def pow_to_mul(expr):
337338
# but at least we traverse the base looking for other Pows
338339
return expr.func(pow_to_mul(base), exp, evaluate=False)
339340
elif exp > 0:
340-
return Mul(*[pow_to_mul(base)]*int(exp), evaluate=False)
341+
return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
341342
else:
342343
# SymPy represents 1/x as Pow(x,-1). Also, it represents
343344
# 2/x as Mul(2, Pow(x, -1)). So we shouldn't end up here,
344345
# but just in case SymPy changes its internal conventions...
345346
posexpr = Mul(*[base]*(-int(exp)), evaluate=False)
346-
return Pow(posexpr, -1, evaluate=False)
347+
return UPow(posexpr, -1, evaluate=False)
347348
else:
348349
args = [pow_to_mul(i) for i in expr.args]
349350

devito/types/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class Grid(CartesianDiscretization, ArgProvider):
7979
shape : tuple of ints
8080
Shape of the computational domain in grid points.
8181
extent : tuple of values interpretable as dtype, default=unit box of extent 1m
82-
in all dimensions
82+
in all dimensions.
8383
Physical extent of the domain in m.
8484
origin : tuple of values interpretable as dtype, default=0.0 in all dimensions
8585
Physical coordinate of the origin of the domain.

examples/compiler/03_iet-A.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@
227227
"cell_type": "markdown",
228228
"metadata": {},
229229
"source": [
230-
"In this example, `op` is represented as a `<Callable Kernel>`. Attached to it are metadata, such as `_headers` and `_includes`, as well as the `body`, which includes the children IET nodes. Here, the body is the concatenation of an `PointerCast` and a `List` object.\n"
230+
"In this example, `op` is represented as a `<Callable Kernel>`. Attached to it are metadata, such as `headers` and `includes`, as well as the `body`, which includes the children IET nodes. Here, the body is the concatenation of an `PointerCast` and a `List` object.\n"
231231
]
232232
},
233233
{

tests/test_pickle.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from devito.types.basic import BoundSymbol, AbstractSymbol
2424
from devito.tools import EnrichedTuple
2525
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
26-
CallFromPointer, DefFunction, Cast, SizeOf)
26+
CallFromPointer, DefFunction, Cast, SizeOf,
27+
pow_to_mul)
2728
from examples.seismic import (demo_model, AcquisitionGeometry,
2829
TimeAxis, RickerSource, Receiver)
2930

@@ -609,6 +610,18 @@ def test_SizeOf(self, pickle, typ):
609610

610611
assert un == new_un
611612

613+
def test_pow_to_mul(self, pickle):
614+
grid = Grid(shape=(3,))
615+
f = Function(name='f', grid=grid)
616+
expr = pow_to_mul(f ** 2)
617+
618+
assert expr.is_Mul
619+
620+
pkl_expr = pickle.dumps(expr)
621+
new_expr = pickle.loads(pkl_expr)
622+
623+
assert new_expr.is_Mul
624+
612625

613626
class TestAdvanced:
614627

0 commit comments

Comments
 (0)