|
34 | 34 | from tsfc.driver import compile_expression_dual_evaluation |
35 | 35 | from tsfc.ufl_utils import extract_firedrake_constants, hash_expr |
36 | 36 |
|
37 | | -from firedrake.utils import IntType, ScalarType, known_pyop2_safe, tuplify |
| 37 | +from firedrake.utils import IntType, ScalarType, cached_property, known_pyop2_safe, tuplify |
38 | 38 | from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir |
39 | 39 | from firedrake.ufl_expr import Argument, Coargument, action |
40 | 40 | from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh |
@@ -155,6 +155,57 @@ def options(self) -> InterpolateOptions: |
155 | 155 | """ |
156 | 156 | return self._options |
157 | 157 |
|
| 158 | + @cached_property |
| 159 | + def _interpolator(self): |
| 160 | + """Access the numerical interpolator. |
| 161 | +
|
| 162 | + Returns |
| 163 | + ------- |
| 164 | + Interpolator |
| 165 | + An appropriate :class:`Interpolator` subclass for this |
| 166 | + interpolation expression. |
| 167 | + """ |
| 168 | + arguments = self.arguments() |
| 169 | + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) |
| 170 | + if len(arguments) == 2 and has_mixed_arguments: |
| 171 | + return MixedInterpolator(self) |
| 172 | + |
| 173 | + operand, = self.ufl_operands |
| 174 | + target_mesh = self.target_space.mesh() |
| 175 | + |
| 176 | + try: |
| 177 | + source_mesh = extract_unique_domain(operand) or target_mesh |
| 178 | + except ValueError: |
| 179 | + raise NotImplementedError( |
| 180 | + "Interpolating an expression with no arguments defined on multiple meshes is not implemented yet." |
| 181 | + ) |
| 182 | + |
| 183 | + try: |
| 184 | + target_mesh = target_mesh.unique() |
| 185 | + source_mesh = source_mesh.unique() |
| 186 | + except RuntimeError: |
| 187 | + return MixedInterpolator(self) |
| 188 | + |
| 189 | + submesh_interp_implemented = ( |
| 190 | + all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) |
| 191 | + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] |
| 192 | + and target_mesh.topological_dimension == source_mesh.topological_dimension |
| 193 | + ) |
| 194 | + if target_mesh is source_mesh or submesh_interp_implemented: |
| 195 | + return SameMeshInterpolator(self) |
| 196 | + |
| 197 | + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): |
| 198 | + if isinstance(source_mesh.topology, VertexOnlyMeshTopology): |
| 199 | + return VomOntoVomInterpolator(self) |
| 200 | + if target_mesh.geometric_dimension != source_mesh.geometric_dimension: |
| 201 | + raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.") |
| 202 | + return SameMeshInterpolator(self) |
| 203 | + |
| 204 | + if has_mixed_arguments or len(self.target_space) > 1: |
| 205 | + return MixedInterpolator(self) |
| 206 | + |
| 207 | + return CrossMeshInterpolator(self) |
| 208 | + |
158 | 209 |
|
159 | 210 | @PETSc.Log.EventDecorator() |
160 | 211 | def interpolate(expr: Expr, V: WithGeometry | BaseForm, **kwargs) -> Interpolate: |
@@ -353,46 +404,7 @@ def get_interpolator(expr: Interpolate) -> Interpolator: |
353 | 404 | An appropriate :class:`Interpolator` subclass for the given |
354 | 405 | interpolation expression. |
355 | 406 | """ |
356 | | - arguments = expr.arguments() |
357 | | - has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) |
358 | | - if len(arguments) == 2 and has_mixed_arguments: |
359 | | - return MixedInterpolator(expr) |
360 | | - |
361 | | - operand, = expr.ufl_operands |
362 | | - target_mesh = expr.target_space.mesh() |
363 | | - |
364 | | - try: |
365 | | - source_mesh = extract_unique_domain(operand) or target_mesh |
366 | | - except ValueError: |
367 | | - raise NotImplementedError( |
368 | | - "Interpolating an expression with no arguments defined on multiple meshes is not implemented yet." |
369 | | - ) |
370 | | - |
371 | | - try: |
372 | | - target_mesh = target_mesh.unique() |
373 | | - source_mesh = source_mesh.unique() |
374 | | - except RuntimeError: |
375 | | - return MixedInterpolator(expr) |
376 | | - |
377 | | - submesh_interp_implemented = ( |
378 | | - all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) |
379 | | - and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] |
380 | | - and target_mesh.topological_dimension == source_mesh.topological_dimension |
381 | | - ) |
382 | | - if target_mesh is source_mesh or submesh_interp_implemented: |
383 | | - return SameMeshInterpolator(expr) |
384 | | - |
385 | | - if isinstance(target_mesh.topology, VertexOnlyMeshTopology): |
386 | | - if isinstance(source_mesh.topology, VertexOnlyMeshTopology): |
387 | | - return VomOntoVomInterpolator(expr) |
388 | | - if target_mesh.geometric_dimension != source_mesh.geometric_dimension: |
389 | | - raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.") |
390 | | - return SameMeshInterpolator(expr) |
391 | | - |
392 | | - if has_mixed_arguments or len(expr.target_space) > 1: |
393 | | - return MixedInterpolator(expr) |
394 | | - |
395 | | - return CrossMeshInterpolator(expr) |
| 407 | + return expr._interpolator |
396 | 408 |
|
397 | 409 |
|
398 | 410 | class DofNotDefinedError(Exception): |
|
0 commit comments