Skip to content

Commit 5a6d48e

Browse files
authored
Merge pull request #2530 from devitocodes/custom-coeff-spacing2
api: Relax custom coefficient spacing enforcement
2 parents e8bf1fb + afdde39 commit 5a6d48e

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

devito/finite_differences/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,5 +339,5 @@ def process_weights(weights, expr, dim):
339339
return shape[weights.dimensions.index(wdim)], wdim, False
340340
else:
341341
# Adimensional weight from custom coeffs need to be multiplied by h^order
342-
scale = not all(sympify(w).has(dim.spacing) for w in weights if w != 0)
342+
scale = all(sympify(w).is_Number for w in weights)
343343
return len(list(weights)), None, scale

tests/test_symbolic_coefficients.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,31 @@ def test_backward_compat_mixed(self):
372372
assert '7.0*f(x + 3*h_x)' in str(eqe.rhs)
373373
assert '0.5*g(x + h_x)' in str(eqe.rhs)
374374
assert 'g(x + 2*h_x)' not in str(eqe.rhs)
375+
376+
def test_backward_compat_array_of_func(self):
377+
grid = Grid(shape=(11, 11, 11))
378+
x, _, _ = grid.dimensions
379+
hx = x.spacing
380+
381+
f = Function(name='f', grid=grid, space_order=16, coefficients='symbolic')
382+
383+
# Define stencil coefficients.
384+
weights = Function(name="w", space_order=0, shape=(9,), dimensions=(x,))
385+
wdx = [weights[0]]
386+
for iq in range(1, weights.shape[0]):
387+
wdx.append(weights[iq])
388+
wdx.insert(0, weights[iq])
389+
390+
# Plain numbers for comparison
391+
wdxn = np.random.rand(17)
392+
393+
# Number with spacing
394+
wdxns = wdxn / hx
395+
396+
dexev = f.dx(weights=wdx).evaluate
397+
dexevn = f.dx(weights=wdxn).evaluate
398+
dexevns = f.dx(weights=wdxns).evaluate
399+
400+
assert all(a.as_coefficient(1/hx) for a in dexevn.args)
401+
assert all(a.as_coefficient(1/hx) for a in dexevns.args)
402+
assert all(not a.as_coefficient(1/hx) for a in dexev.args)

0 commit comments

Comments
 (0)