Skip to content

Commit c5ef5d1

Browse files
authored
Cache numerical Interpolator on the symbolic Interpolate expression (#4827)
* Cache numerical Interpolator on the symbolic Interpolate expression * cached_property
1 parent c84b612 commit c5ef5d1

File tree

2 files changed

+57
-42
lines changed

2 files changed

+57
-42
lines changed

firedrake/interpolation.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from tsfc.driver import compile_expression_dual_evaluation
3535
from tsfc.ufl_utils import extract_firedrake_constants, hash_expr
3636

37-
from firedrake.utils import IntType, ScalarType, known_pyop2_safe, tuplify
37+
from firedrake.utils import IntType, ScalarType, cached_property, known_pyop2_safe, tuplify
3838
from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir
3939
from firedrake.ufl_expr import Argument, Coargument, action
4040
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh
@@ -155,6 +155,57 @@ def options(self) -> InterpolateOptions:
155155
"""
156156
return self._options
157157

158+
@cached_property
159+
def _interpolator(self):
160+
"""Access the numerical interpolator.
161+
162+
Returns
163+
-------
164+
Interpolator
165+
An appropriate :class:`Interpolator` subclass for this
166+
interpolation expression.
167+
"""
168+
arguments = self.arguments()
169+
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
170+
if len(arguments) == 2 and has_mixed_arguments:
171+
return MixedInterpolator(self)
172+
173+
operand, = self.ufl_operands
174+
target_mesh = self.target_space.mesh()
175+
176+
try:
177+
source_mesh = extract_unique_domain(operand) or target_mesh
178+
except ValueError:
179+
raise NotImplementedError(
180+
"Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
181+
)
182+
183+
try:
184+
target_mesh = target_mesh.unique()
185+
source_mesh = source_mesh.unique()
186+
except RuntimeError:
187+
return MixedInterpolator(self)
188+
189+
submesh_interp_implemented = (
190+
all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh])
191+
and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
192+
and target_mesh.topological_dimension == source_mesh.topological_dimension
193+
)
194+
if target_mesh is source_mesh or submesh_interp_implemented:
195+
return SameMeshInterpolator(self)
196+
197+
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
198+
if isinstance(source_mesh.topology, VertexOnlyMeshTopology):
199+
return VomOntoVomInterpolator(self)
200+
if target_mesh.geometric_dimension != source_mesh.geometric_dimension:
201+
raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.")
202+
return SameMeshInterpolator(self)
203+
204+
if has_mixed_arguments or len(self.target_space) > 1:
205+
return MixedInterpolator(self)
206+
207+
return CrossMeshInterpolator(self)
208+
158209

159210
@PETSc.Log.EventDecorator()
160211
def interpolate(expr: Expr, V: WithGeometry | BaseForm, **kwargs) -> Interpolate:
@@ -353,46 +404,7 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
353404
An appropriate :class:`Interpolator` subclass for the given
354405
interpolation expression.
355406
"""
356-
arguments = expr.arguments()
357-
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
358-
if len(arguments) == 2 and has_mixed_arguments:
359-
return MixedInterpolator(expr)
360-
361-
operand, = expr.ufl_operands
362-
target_mesh = expr.target_space.mesh()
363-
364-
try:
365-
source_mesh = extract_unique_domain(operand) or target_mesh
366-
except ValueError:
367-
raise NotImplementedError(
368-
"Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
369-
)
370-
371-
try:
372-
target_mesh = target_mesh.unique()
373-
source_mesh = source_mesh.unique()
374-
except RuntimeError:
375-
return MixedInterpolator(expr)
376-
377-
submesh_interp_implemented = (
378-
all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh])
379-
and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
380-
and target_mesh.topological_dimension == source_mesh.topological_dimension
381-
)
382-
if target_mesh is source_mesh or submesh_interp_implemented:
383-
return SameMeshInterpolator(expr)
384-
385-
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
386-
if isinstance(source_mesh.topology, VertexOnlyMeshTopology):
387-
return VomOntoVomInterpolator(expr)
388-
if target_mesh.geometric_dimension != source_mesh.geometric_dimension:
389-
raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.")
390-
return SameMeshInterpolator(expr)
391-
392-
if has_mixed_arguments or len(expr.target_space) > 1:
393-
return MixedInterpolator(expr)
394-
395-
return CrossMeshInterpolator(expr)
407+
return expr._interpolator
396408

397409

398410
class DofNotDefinedError(Exception):

tests/firedrake/regression/test_interpolate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,12 @@ def test_interpolator_reuse(family, degree, mode):
592592
u = Function(V.dual())
593593
expr = interpolate(TestFunction(V), u)
594594

595-
I = get_interpolator(expr)
595+
Iorig = get_interpolator(expr)
596596

597597
for k in range(3):
598+
I = get_interpolator(expr)
599+
assert I is Iorig
600+
598601
u.assign(rg.uniform(u.function_space()))
599602
expected = u.dat.data.copy()
600603

0 commit comments

Comments
 (0)