Skip to content

Commit f580a6e

Browse files
committed
api: remove obsolete and clunky is_parameter
1 parent 6771d4c commit f580a6e

File tree

6 files changed

+18
-40
lines changed

6 files changed

+18
-40
lines changed

devito/finite_differences/differentiable.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,20 +1131,8 @@ def _(expr, x0, **kwargs):
11311131

11321132
@interp_for_fd.register(AbstractFunction)
11331133
def _(expr, x0, **kwargs):
1134-
from devito.finite_differences.derivative import Derivative
1135-
x0_expr = {d: v for d, v in x0.items() if v is not expr.indices_ref[d]}
1136-
if expr.is_parameter:
1137-
# Parameter might not have been evaluated at x0 yet
1138-
# E.g., in expressions such as u[x]*f[x] evaluated at x+dx/2
1139-
# `f` will not be evaluated at it since gather_for_diff will pick `u` as
1140-
# higher priority function.
1141-
for d, v in x0.items():
1142-
if expr.indices[d] is d:
1143-
expr = expr._subs(d, v)
1144-
return expr
1145-
elif x0_expr:
1146-
dims = tuple((d, 0) for d in x0_expr)
1147-
fd_o = tuple([expr.interp_order]*len(dims))
1148-
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)
1134+
x0_expr = {d: v for d, v in x0.items() if v.has(d)}
1135+
if x0_expr:
1136+
return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()})
11491137
else:
11501138
return expr

devito/types/dense.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ class Function(DiscreteFunction):
10191019
is_autopaddable = True
10201020

10211021
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
1022-
('space_order', 'interp_order', 'dimensions', 'is_parameter'))
1022+
('space_order', 'interp_order', 'dimensions'))
10231023

10241024
def _cache_meta(self):
10251025
# Attach additional metadata to self's cache entry
@@ -1060,12 +1060,6 @@ def __init_finalize__(self, *args, **kwargs):
10601060
# can clearly avoid that here though!
10611061
self._fd = self.function._fd
10621062

1063-
# Flag whether it is a parameter or a variable.
1064-
# Used at operator evaluation to evaluate the Function at the
1065-
# variable location (i.e. if the variable is staggered in x the
1066-
# parameter has to be computed at x + hx/2)
1067-
self._is_parameter = kwargs.get('parameter', kwargs.get('is_parameter', False))
1068-
10691063
def __fd_setup__(self):
10701064
"""
10711065
Dynamically add derivative short-cuts.
@@ -1076,12 +1070,8 @@ def __fd_setup__(self):
10761070
def _fd_priority(self):
10771071
return 1 if self.staggered.on_node else 2
10781072

1079-
@property
1080-
def is_parameter(self):
1081-
return self._is_parameter
1082-
10831073
def _eval_at(self, func):
1084-
if not self.is_parameter or self.staggered == func.staggered:
1074+
if self.staggered == func.staggered:
10851075
return self
10861076
mapper = {self.indices_ref[d]: func.indices_ref[d]
10871077
for d in self.dimensions

tests/test_derivatives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def test_param_stagg_inner(self):
773773
xp = x + x.spacing / 2
774774

775775
f = TimeFunction(name="f", grid=grid, space_order=space_order, staggered=y)
776-
p = Function(name="p", grid=grid, space_order=space_order, parameter=True)
776+
p = Function(name="p", grid=grid, space_order=space_order)
777777
g = TimeFunction(name="g", grid=grid, space_order=space_order, staggered=(x, y))
778778

779779
eqn = Eq(g, (p * f).dx)
@@ -798,8 +798,8 @@ def test_param_stagg_add(self):
798798
time_order=1, staggered=NODE)
799799
txy = TimeFunction(name="txy", grid=grid, space_order=space_order,
800800
time_order=1, staggered=(x, y))
801-
c11 = Function(name="c11", grid=grid, space_order=space_order, parameter=True)
802-
c66 = Function(name="c66", grid=grid, space_order=space_order, parameter=True)
801+
c11 = Function(name="c11", grid=grid, space_order=space_order)
802+
c66 = Function(name="c66", grid=grid, space_order=space_order)
803803

804804
eq0 = Eq(vx, (c66 * txy).dy)
805805
eq1 = Eq(vx, (c11 * txx).dy)

tests/test_differentiable.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,23 +109,23 @@ def test_avg_mode(ndim, io):
109109

110110
with pytest.raises(ValueError):
111111
# interp_order > space_order
112-
Function(name="a", grid=grid, parameter=True, interp_order=8, space_order=4)
112+
Function(name="a", grid=grid, interp_order=8, space_order=4)
113113
with pytest.raises(ValueError):
114114
# interp_order < 1
115-
Function(name="a", grid=grid, parameter=True, interp_order=0, space_order=4)
115+
Function(name="a", grid=grid, interp_order=0, space_order=4)
116116
with pytest.raises(TypeError):
117117
# interp_order not int
118-
Function(name="a", grid=grid, parameter=True, interp_order=2.5, space_order=4)
118+
Function(name="a", grid=grid, interp_order=2.5, space_order=4)
119119

120120
a0 = Function(name="a0", grid=grid, **kw)
121-
a = Function(name="a", grid=grid, parameter=True, **kw)
122-
b = Function(name="b", grid=grid, parameter=True, avg_mode='harmonic', **kw)
121+
a = Function(name="a", grid=grid, **kw)
122+
b = Function(name="b", grid=grid, avg_mode='harmonic', **kw)
123123

124124
a0_avg = a0._eval_at(v)
125125
a_avg = a._eval_at(v).evaluate.simplify()
126126
b_avg = b._eval_at(v).evaluate.simplify()
127127

128-
assert a0_avg == a0
128+
assert a0_avg == a0.subs(v.indices_ref.getters)
129129

130130
# Indices around the point at the center of a cell
131131
idx = list(range(-io//2 + 1, io//2 + 1))

tests/test_dse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ def test_extraction_from_lifted_ispace(self, rotate):
15881588
so = 8
15891589
grid = Grid(shape=(6, 6, 6))
15901590

1591-
f = Function(name='f', grid=grid, space_order=so, parameter=True)
1591+
f = Function(name='f', grid=grid, space_order=so)
15921592
v = TimeFunction(name="v", grid=grid, space_order=so)
15931593
v1 = TimeFunction(name="v1", grid=grid, space_order=so)
15941594
p = TimeFunction(name="p", grid=grid, space_order=so, staggered=NODE)

tests/test_staggered_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_is_param(ndim):
6868
var = Function(name="f", grid=grid, staggered=NODE)
6969
for d in dims:
7070
f = Function(name="f", grid=grid, staggered=d)
71-
f2 = Function(name="f2", grid=grid, staggered=d, parameter=True)
71+
f2 = Function(name="f2", grid=grid, staggered=d)
7272

7373
# Not a parameter stay untouched (or FD would be destroyed by _eval_at)
7474
assert f._eval_at(var).evaluate == f
@@ -91,7 +91,7 @@ def test_gather_for_diff(expr, expected):
9191
y0 = y + y.spacing/2 # noqa
9292
a = Function(name="a", grid=grid, staggered=NODE) # noqa
9393
b = Function(name="b", grid=grid, staggered=x) # noqa
94-
c = Function(name="c", grid=grid, staggered=y, parameter=True) # noqa
94+
c = Function(name="c", grid=grid, staggered=y) # noqa
9595
d = Function(name="d", grid=grid) # noqa
9696

9797
assert eval(expr) == eval(expected)
@@ -143,7 +143,7 @@ def test_staggered_div():
143143
v[0].data[:] = 5.
144144
v[1].data[:] = 5.
145145

146-
A = Function(name="A", grid=grid, space_order=4, staggred=NODE, parameter=True)
146+
A = Function(name="A", grid=grid, space_order=4, staggred=NODE)
147147
A._data_with_outhalo[:] = .5
148148

149149
av = VectorTimeFunction(name="av", grid=grid, time_order=1, space_order=4)

0 commit comments

Comments
 (0)