Skip to content

Commit f566861

Browse files
committed
Interpolate: Fix constant and coefficient ordering
1 parent 6e10ebf commit f566861

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

firedrake/assemble.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
553553
arg_expression = ufl.algorithms.extract_arguments(expression)
554554

555555
# Handle interpolation onto subfunctions
556-
if (arg_expression and len(arg_expression[0].function_space()) > 1
557-
and not isinstance(expression, firedrake.Argument)):
556+
if arg_expression and len(arg_expression[0].function_space()) > 1:
558557
assert rank == 1
559558
V = arg_expression[0].function_space()
560559
if tensor is not None:

firedrake/interpolation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,10 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11961196
if needs_cell_sizes:
11971197
cs = target_mesh.cell_sizes
11981198
parloop_args.append(cs.dat(op2.READ, cs.cell_node_map()))
1199+
1200+
for const in extract_firedrake_constants(expr):
1201+
parloop_args.append(const.dat(op2.READ))
1202+
11991203
for coefficient in coefficients:
12001204
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
12011205
coeff_mesh = extract_unique_domain(coefficient)
@@ -1224,9 +1228,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
12241228
m_ = coefficient.function_space().entity_node_map(target_mesh.topology, "cell", None, None)
12251229
parloop_args.append(coefficient.dat(op2.READ, m_))
12261230

1227-
for const in extract_firedrake_constants(expr):
1228-
parloop_args.append(const.dat(op2.READ))
1229-
12301231
parloop = op2.ParLoop(*parloop_args)
12311232
parloop_compute_callable = parloop.compute
12321233
if isinstance(tensor, op2.Mat):

tests/firedrake/regression/test_interp_dual.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_solve_interp_u(mesh):
280280

281281

282282
@pytest.mark.parametrize("family0,degree0,family1,degree1",
283-
[("DG", 1, "CG", 1),
283+
[("DG", 1, "CG", 2),
284284
("DG", 0, "RT", 1)])
285285
def test_interp_subfunction(mesh, family0, degree0, family1, degree1):
286286
V = FunctionSpace(mesh, "DG", 0)
@@ -295,7 +295,7 @@ def test_interp_subfunction(mesh, family0, degree0, family1, degree1):
295295

296296
expr = sum(w[i] for i in np.ndindex(w.ufl_shape))
297297

298-
Fw = inner(1, expr)*dx
298+
Fw = inner(1, expr)*dx(degree=0)
299299
expected = assemble(Fw)
300300

301301
IFv = Interpolate(expr, Fv)
@@ -306,5 +306,7 @@ def test_interp_subfunction(mesh, family0, degree0, family1, degree1):
306306

307307
result = assemble(IFv, tensor=tensor)
308308
assert result.function_space() == W.dual()
309-
with result.dat.vec_ro as x, expected.dat.vec_ro as y:
310-
assert np.allclose(x[:], y[:])
309+
if tensor:
310+
assert result is tensor
311+
for x, y, in zip(result.subfunctions, expected.subfunctions):
312+
assert np.allclose(x.dat.data_ro, y.dat.data_ro)

0 commit comments

Comments
 (0)