Skip to content

Commit b3eff8b

Browse files
MTCaminducer
andcommitted
Enable dt estimate for quads/hexes, extend tests for it
Co-authored-by: Andreas Kloeckner <[email protected]>
1 parent c2be523 commit b3eff8b

File tree

3 files changed

+50
-22
lines changed

3 files changed

+50
-22
lines changed

grudge/dt_utils.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,35 @@ def h_min_from_volume(
229229
def dt_geometric_factors(
230230
dcoll: DiscretizationCollection, dd: Optional[DOFDesc] = None) -> DOFArray:
231231
r"""Computes a geometric scaling factor for each cell following
232-
[Hesthaven_2008]_, section 6.4, defined as the inradius (radius of an
233-
inscribed circle/sphere).
232+
[Hesthaven_2008]_, section 6.4, For simplicial elemenents, this factor is
233+
defined as the inradius (radius of an inscribed circle/sphere). For
234+
non-simplicial elements, a mean length measure is returned.
234235
235-
Specifically, the inradius for each element is computed using the following
236-
formula from [Shewchuk_2002]_, Table 1, for simplicial cells
237-
(triangles/tetrahedra):
236+
Specifically, the inradius for each simplicial element is computed using the
237+
following formula from [Shewchuk_2002]_, Table 1 (triangles, tetrahedra):
238238
239239
.. math::
240240
241-
r_D = \frac{d V}{\sum_{i=1}^{N_{faces}} F_i},
241+
r_D = \frac{d~V}{\sum_{i=1}^{N_{faces}} F_i},
242242
243243
where :math:`d` is the topological dimension, :math:`V` is the cell volume,
244244
and :math:`F_i` are the areas of each face of the cell.
245245
246+
For non-simplicial elements, we use the following formula for a mean
247+
cell size measure:
248+
249+
.. math::
250+
251+
r_D = \frac{2~d~V}{\sum_{i=1}^{N_{faces}} F_i},
252+
253+
where :math:`d` is the topological dimension, :math:`V` is the cell volume,
254+
and :math:`F_i` are the areas of each face of the cell. Other valid choices
255+
here include the shortest, longest, average of the cell diagonals, or edges.
256+
The value returned by this routine (i.e. the cell volume divided by the
257+
average cell face area) is bounded by the extrema of the cell edge lengths,
258+
is straightforward to calculate regardless of element shape, and jibes well
259+
with the foregoing calculation for simplicial elements.
260+
246261
:arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one.
247262
Defaults to the base volume discretization if not provided.
248263
:returns: a frozen :class:`~meshmode.dof_array.DOFArray` containing the
@@ -256,11 +271,10 @@ def dt_geometric_factors(
256271
actx = dcoll._setup_actx
257272
volm_discr = dcoll.discr_from_dd(dd)
258273

274+
r_fac = dcoll.dim
259275
if any(not isinstance(grp, SimplexElementGroupBase)
260276
for grp in volm_discr.groups):
261-
raise NotImplementedError(
262-
"Geometric factors are only implemented for simplex element groups"
263-
)
277+
r_fac = 2.0*r_fac
264278

265279
if volm_discr.dim != volm_discr.ambient_dim:
266280
from warnings import warn
@@ -342,7 +356,7 @@ def dt_geometric_factors(
342356
"e,ei->ei",
343357
1/sae_i,
344358
actx.tag_axis(1, DiscretizationDOFAxisTag(), cv_i),
345-
tagged=(FirstAxisIsElementsTag(),)) * dcoll.dim
359+
tagged=(FirstAxisIsElementsTag(),)) * r_fac
346360
for cv_i, sae_i in zip(cell_vols, surface_areas)))))
347361

348362
# }}}

test/mesh_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def get_mesh(self, resolution, mesh_order):
8686

8787
class BoxMeshBuilder(MeshBuilder):
8888
ambient_dim = 2
89+
group_cls = None
8990

9091
mesh_order = 1
9192
resolutions = [4, 8, 16]
@@ -100,6 +101,7 @@ def get_mesh(self, resolution, mesh_order):
100101
return mgen.generate_regular_rect_mesh(
101102
a=self.a, b=self.b,
102103
nelements_per_axis=resolution,
104+
group_cls=self.group_cls,
103105
order=mesh_order)
104106

105107

test/test_dt_utils.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
[PytestPyOpenCLArrayContextFactory,
3434
PytestPytatoPyOpenCLArrayContextFactory])
3535

36-
from grudge import DiscretizationCollection
36+
from grudge import make_discretization_collection
3737

3838
import grudge.op as op
3939

@@ -47,22 +47,26 @@
4747

4848

4949
@pytest.mark.parametrize("name", ["interval", "box2d", "box3d"])
50-
def test_geometric_factors_regular_refinement(actx_factory, name):
50+
@pytest.mark.parametrize("tpe", [False, True])
51+
def test_geometric_factors_regular_refinement(actx_factory, name, tpe):
5152
from grudge.dt_utils import dt_geometric_factors
5253

5354
actx = actx_factory()
5455

5556
# {{{ cases
5657

58+
from meshmode.mesh import TensorProductElementGroup
59+
group_cls = TensorProductElementGroup if tpe else None
60+
5761
if name == "interval":
5862
from mesh_data import BoxMeshBuilder
59-
builder = BoxMeshBuilder(ambient_dim=1)
63+
builder = BoxMeshBuilder(ambient_dim=1, group_cls=group_cls)
6064
elif name == "box2d":
6165
from mesh_data import BoxMeshBuilder
62-
builder = BoxMeshBuilder(ambient_dim=2)
66+
builder = BoxMeshBuilder(ambient_dim=2, group_cls=group_cls)
6367
elif name == "box3d":
6468
from mesh_data import BoxMeshBuilder
65-
builder = BoxMeshBuilder(ambient_dim=3)
69+
builder = BoxMeshBuilder(ambient_dim=3, group_cls=group_cls)
6670
else:
6771
raise ValueError("unknown geometry name: %s" % name)
6872

@@ -71,7 +75,7 @@ def test_geometric_factors_regular_refinement(actx_factory, name):
7175
min_factors = []
7276
for resolution in builder.resolutions:
7377
mesh = builder.get_mesh(resolution, builder.mesh_order)
74-
dcoll = DiscretizationCollection(actx, mesh, order=builder.order)
78+
dcoll = make_discretization_collection(actx, mesh, order=builder.order)
7579
min_factors.append(
7680
actx.to_numpy(
7781
op.nodal_min(dcoll, "vol", actx.thaw(dt_geometric_factors(dcoll))))
@@ -85,7 +89,7 @@ def test_geometric_factors_regular_refinement(actx_factory, name):
8589

8690
# Make sure it works with empty meshes
8791
mesh = builder.get_mesh(0, builder.mesh_order)
88-
dcoll = DiscretizationCollection(actx, mesh, order=builder.order)
92+
dcoll = make_discretization_collection(actx, mesh, order=builder.order)
8993
factors = actx.thaw(dt_geometric_factors(dcoll)) # noqa: F841
9094

9195

@@ -115,7 +119,7 @@ def test_non_geometric_factors(actx_factory, name):
115119
degrees = list(range(1, 8))
116120
for degree in degrees:
117121
mesh = builder.get_mesh(1, degree)
118-
dcoll = DiscretizationCollection(actx, mesh, order=degree)
122+
dcoll = make_discretization_collection(actx, mesh, order=degree)
119123
factors.append(min(dt_non_geometric_factors(dcoll)))
120124

121125
# Crude estimate, factors should behave like 1/N**2
@@ -134,7 +138,7 @@ def test_build_jacobian(actx_factory):
134138
mesh = mgen.generate_regular_rect_mesh(a=[0], b=[1], nelements_per_axis=(3,))
135139
assert mesh.dim == 1
136140

137-
dcoll = DiscretizationCollection(actx, mesh, order=1)
141+
dcoll = make_discretization_collection(actx, mesh, order=1)
138142

139143
def rhs(x):
140144
return 3*x**2 + 2*x + 5
@@ -151,19 +155,27 @@ def rhs(x):
151155

152156
@pytest.mark.parametrize("dim", [1, 2])
153157
@pytest.mark.parametrize("degree", [2, 4])
154-
def test_wave_dt_estimate(actx_factory, dim, degree, visualize=False):
158+
@pytest.mark.parametrize("tpe", [False, True])
159+
def test_wave_dt_estimate(actx_factory, dim, degree, tpe, visualize=False):
155160
actx = actx_factory()
156161

162+
# {{{ cases
163+
164+
from meshmode.mesh import TensorProductElementGroup
165+
group_cls = TensorProductElementGroup if tpe else None
166+
157167
import meshmode.mesh.generation as mgen
158168

159169
a = [0, 0, 0]
160170
b = [1, 1, 1]
161171
mesh = mgen.generate_regular_rect_mesh(
162172
a=a[:dim], b=b[:dim],
163-
nelements_per_axis=(3,)*dim)
173+
nelements_per_axis=(3,)*dim,
174+
group_cls=group_cls)
175+
164176
assert mesh.dim == dim
165177

166-
dcoll = DiscretizationCollection(actx, mesh, order=degree)
178+
dcoll = make_discretization_collection(actx, mesh, order=degree)
167179

168180
from grudge.models.wave import WeakWaveOperator
169181
wave_op = WeakWaveOperator(dcoll, c=1)

0 commit comments

Comments
 (0)