Skip to content

Commit 77cfaff

Browse files
committed
Add explicit dtypes to more operations
1 parent 62af257 commit 77cfaff

File tree

4 files changed

+152
-77
lines changed

4 files changed

+152
-77
lines changed

grudge/discretization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
THE SOFTWARE.
3030
"""
3131

32+
from typing import Any, Optional
3233
from pytools import memoize_method
3334

3435
from grudge.dof_desc import (
@@ -698,7 +699,7 @@ def order(self):
698699

699700
# {{{ Discretization-specific geometric properties
700701

701-
def nodes(self, dd=None):
702+
def nodes(self, dd=None, dtype: Optional[np.dtype[Any]] = None):
702703
r"""Return the nodes of a discretization specified by *dd*.
703704
704705
:arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one.
@@ -707,7 +708,7 @@ def nodes(self, dd=None):
707708
"""
708709
if dd is None:
709710
dd = DD_VOLUME
710-
return self.discr_from_dd(dd).nodes()
711+
return self.discr_from_dd(dd).nodes(dtype)
711712

712713
def normal(self, dd):
713714
r"""Get the unit normal to the specified surface discretization, *dd*.

grudge/geometry/metrics.py

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,14 @@
5858
"""
5959

6060

61+
from typing import Optional, Any
62+
6163
import numpy as np
6264

6365
from arraycontext import thaw, freeze, ArrayContext
6466
from meshmode.dof_array import DOFArray
6567

68+
from grudge.tools import to_real_dtype
6669
from grudge import DiscretizationCollection
6770
import grudge.dof_desc as dof_desc
6871

@@ -111,7 +114,8 @@ def to_quad(vec):
111114
def forward_metric_nth_derivative(
112115
actx: ArrayContext, dcoll: DiscretizationCollection,
113116
xyz_axis, ref_axes, dd=None,
114-
*, _use_geoderiv_connection=False) -> DOFArray:
117+
*, _use_geoderiv_connection=False,
118+
dtype: Optional[np.dtype[Any]] = None) -> DOFArray:
115119
r"""Pointwise metric derivatives representing repeated derivatives of the
116120
physical coordinate enumerated by *xyz_axis*: :math:`x_{\mathrm{xyz\_axis}}`
117121
with respect to the coordiantes on the reference element :math:`\xi_i`:
@@ -169,7 +173,8 @@ def forward_metric_nth_derivative(
169173
vec = num_reference_derivative(
170174
dcoll.discr_from_dd(inner_dd),
171175
flat_ref_axes,
172-
thaw(dcoll.discr_from_dd(inner_dd).nodes(), actx)[xyz_axis]
176+
thaw(dcoll.discr_from_dd(inner_dd).nodes(
177+
dtype=to_real_dtype(dtype)), actx)[xyz_axis]
173178
)
174179

175180
return _geometry_to_quad_if_requested(
@@ -178,7 +183,8 @@ def forward_metric_nth_derivative(
178183

179184
def forward_metric_derivative_vector(
180185
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, dd=None,
181-
*, _use_geoderiv_connection=False) -> np.ndarray:
186+
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
187+
) -> np.ndarray:
182188
r"""Computes an object array containing the forward metric derivatives
183189
of each physical coordinate.
184190
@@ -195,15 +201,17 @@ def forward_metric_derivative_vector(
195201
return make_obj_array([
196202
forward_metric_nth_derivative(
197203
actx, dcoll, i, rst_axis, dd=dd,
198-
_use_geoderiv_connection=_use_geoderiv_connection)
204+
_use_geoderiv_connection=_use_geoderiv_connection,
205+
dtype=dtype)
199206
for i in range(dcoll.ambient_dim)
200207
]
201208
)
202209

203210

204211
def forward_metric_derivative_mv(
205212
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, dd=None,
206-
*, _use_geoderiv_connection=False) -> MultiVector:
213+
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
214+
) -> MultiVector:
207215
r"""Computes a :class:`pymbolic.geometric_algebra.MultiVector` containing
208216
the forward metric derivatives of each physical coordinate.
209217
@@ -220,13 +228,15 @@ def forward_metric_derivative_mv(
220228
return MultiVector(
221229
forward_metric_derivative_vector(
222230
actx, dcoll, rst_axis, dd=dd,
223-
_use_geoderiv_connection=_use_geoderiv_connection)
231+
_use_geoderiv_connection=_use_geoderiv_connection,
232+
dtype=dtype)
224233
)
225234

226235

227236
def forward_metric_derivative_mat(
228237
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
229-
*, _use_geoderiv_connection=False) -> np.ndarray:
238+
*, _use_geoderiv_connection=False,
239+
dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
230240
r"""Computes the forward metric derivative matrix, also commonly
231241
called the Jacobian matrix, with entries defined as the
232242
forward metric derivatives:
@@ -260,13 +270,15 @@ def forward_metric_derivative_mat(
260270
for j in range(dim):
261271
result[:, j] = forward_metric_derivative_vector(
262272
actx, dcoll, j, dd=dd,
263-
_use_geoderiv_connection=_use_geoderiv_connection)
273+
_use_geoderiv_connection=_use_geoderiv_connection,
274+
dtype=dtype)
264275

265276
return result
266277

267278

268279
def first_fundamental_form(actx: ArrayContext, dcoll: DiscretizationCollection,
269-
dd=None, *, _use_geoderiv_connection=False) -> np.ndarray:
280+
dd=None, *, _use_geoderiv_connection=False,
281+
dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
270282
r"""Computes the first fundamental form using the Jacobian matrix:
271283
272284
.. math::
@@ -295,14 +307,16 @@ def first_fundamental_form(actx: ArrayContext, dcoll: DiscretizationCollection,
295307
dd = DD_VOLUME
296308

297309
mder = forward_metric_derivative_mat(
298-
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
310+
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
311+
dtype=dtype)
299312

300313
return mder.T.dot(mder)
301314

302315

303316
def inverse_metric_derivative_mat(
304317
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
305-
*, _use_geoderiv_connection=False) -> np.ndarray:
318+
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
319+
) -> np.ndarray:
306320
r"""Computes the inverse metric derivative matrix, which is
307321
the inverse of the Jacobian (forward metric derivative) matrix.
308322
@@ -324,15 +338,16 @@ def inverse_metric_derivative_mat(
324338
for j in range(ambient_dim):
325339
result[i, j] = inverse_metric_derivative(
326340
actx, dcoll, i, j, dd=dd,
327-
_use_geoderiv_connection=_use_geoderiv_connection
328-
)
341+
_use_geoderiv_connection=_use_geoderiv_connection,
342+
dtype=dtype)
329343

330344
return result
331345

332346

333347
def inverse_first_fundamental_form(
334348
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
335-
*, _use_geoderiv_connection=False) -> np.ndarray:
349+
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]]
350+
) -> np.ndarray:
336351
r"""Computes the inverse of the first fundamental form:
337352
338353
.. math::
@@ -361,11 +376,13 @@ def inverse_first_fundamental_form(
361376

362377
if dcoll.ambient_dim == dim:
363378
inv_mder = inverse_metric_derivative_mat(
364-
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
379+
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
380+
dtype=dtype)
365381
inv_form1 = inv_mder.dot(inv_mder.T)
366382
else:
367383
form1 = first_fundamental_form(
368-
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection)
384+
actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection,
385+
dtype=dtype)
369386

370387
if dim == 1:
371388
inv_form1 = 1.0 / form1
@@ -383,7 +400,7 @@ def inverse_first_fundamental_form(
383400

384401
def inverse_metric_derivative(
385402
actx: ArrayContext, dcoll: DiscretizationCollection, rst_axis, xyz_axis, dd,
386-
*, _use_geoderiv_connection=False
403+
*, _use_geoderiv_connection=False, dtype: Optional[np.dtype[Any]] = None
387404
) -> DOFArray:
388405
r"""Computes the inverse metric derivative of the physical
389406
coordinate enumerated by *xyz_axis* with respect to the
@@ -409,7 +426,8 @@ def inverse_metric_derivative(
409426
par_vecs = [
410427
forward_metric_derivative_mv(
411428
actx, dcoll, rst, dd,
412-
_use_geoderiv_connection=_use_geoderiv_connection)
429+
_use_geoderiv_connection=_use_geoderiv_connection,
430+
dtype=dtype)
413431
for rst in range(dim)]
414432

415433
# Yay Cramer's rule!
@@ -442,7 +460,8 @@ def outprod_with_unit(i, at):
442460
def inverse_surface_metric_derivative(
443461
actx: ArrayContext, dcoll: DiscretizationCollection,
444462
rst_axis, xyz_axis, dd=None,
445-
*, _use_geoderiv_connection=False):
463+
*, _use_geoderiv_connection=False,
464+
dtype: Optional[np.dtype[Any]]):
446465
r"""Computes the inverse surface metric derivative of the physical
447466
coordinate enumerated by *xyz_axis* with respect to the
448467
reference axis *rst_axis*. These geometric terms are used in the
@@ -467,24 +486,24 @@ def inverse_surface_metric_derivative(
467486
dd = dof_desc.as_dofdesc(dd)
468487

469488
if ambient_dim == dim:
470-
result = inverse_metric_derivative(
489+
return inverse_metric_derivative(
471490
actx, dcoll, rst_axis, xyz_axis, dd=dd,
472-
_use_geoderiv_connection=_use_geoderiv_connection
473-
)
491+
_use_geoderiv_connection=_use_geoderiv_connection,
492+
dtype=dtype)
474493
else:
475494
inv_form1 = inverse_first_fundamental_form(actx, dcoll, dd=dd)
476-
result = sum(
495+
return sum(
477496
inv_form1[rst_axis, d]*forward_metric_nth_derivative(
478497
actx, dcoll, xyz_axis, d, dd=dd,
479-
_use_geoderiv_connection=_use_geoderiv_connection
498+
_use_geoderiv_connection=_use_geoderiv_connection,
499+
dtype=dtype,
480500
) for d in range(dim))
481501

482-
return result
483-
484502

485503
def inverse_surface_metric_derivative_mat(
486504
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
487-
*, times_area_element=False, _use_geoderiv_connection=False):
505+
*, times_area_element=False, _use_geoderiv_connection=False,
506+
dtype: Optional[np.dtype[Any]] = None):
488507
r"""Computes the matrix of inverse surface metric derivatives, indexed by
489508
``(xyz_axis, rst_axis)``. It returns all values of
490509
:func:`inverse_surface_metric_derivative_mat` in cached matrix form.
@@ -505,7 +524,7 @@ def inverse_surface_metric_derivative_mat(
505524

506525
@memoize_in(dcoll, (inverse_surface_metric_derivative_mat, dd,
507526
times_area_element, _use_geoderiv_connection))
508-
def _inv_surf_metric_deriv():
527+
def _inv_surf_metric_deriv(dtype):
509528
if times_area_element:
510529
multiplier = area_element(actx, dcoll, dd=dd,
511530
_use_geoderiv_connection=_use_geoderiv_connection)
@@ -517,13 +536,17 @@ def _inv_surf_metric_deriv():
517536
multiplier
518537
* inverse_surface_metric_derivative(actx, dcoll,
519538
rst_axis, xyz_axis, dd=dd,
520-
_use_geoderiv_connection=_use_geoderiv_connection)
539+
_use_geoderiv_connection=_use_geoderiv_connection,
540+
dtype=dtype)
521541
for rst_axis in range(dcoll.dim)])
522542
for xyz_axis in range(dcoll.ambient_dim)])
523543

524544
return freeze(mat, actx)
525545

526-
return thaw(_inv_surf_metric_deriv(), actx)
546+
if dtype is not None:
547+
dtype = to_real_dtype(dtype)
548+
549+
return thaw(_inv_surf_metric_deriv(dtype), actx)
527550

528551

529552
def _signed_face_ones(
@@ -557,7 +580,8 @@ def _signed_face_ones(
557580

558581
def parametrization_derivative(
559582
actx: ArrayContext, dcoll: DiscretizationCollection, dd,
560-
*, _use_geoderiv_connection=False) -> MultiVector:
583+
*, _use_geoderiv_connection=False,
584+
dtype: Optional[np.dtype[Any]]) -> MultiVector:
561585
r"""Computes the product of forward metric derivatives spanning the
562586
tangent space with topological dimension *dim*.
563587
@@ -585,13 +609,15 @@ def parametrization_derivative(
585609
return product(
586610
forward_metric_derivative_mv(
587611
actx, dcoll, rst_axis, dd,
588-
_use_geoderiv_connection=_use_geoderiv_connection)
612+
_use_geoderiv_connection=_use_geoderiv_connection,
613+
dtype=dtype)
589614
for rst_axis in range(dim)
590615
)
591616

592617

593618
def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,
594-
dd=None, *, _use_geoderiv_connection=False) -> MultiVector:
619+
dd=None, *, _use_geoderiv_connection=False,
620+
dtype: Optional[np.dtype[Any]]) -> MultiVector:
595621
r"""Computes the field of pseudoscalars for the domain/discretization
596622
identified by *dd*.
597623
@@ -607,12 +633,14 @@ def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,
607633

608634
return parametrization_derivative(
609635
actx, dcoll, dd,
610-
_use_geoderiv_connection=_use_geoderiv_connection).project_max_grade()
636+
_use_geoderiv_connection=_use_geoderiv_connection,
637+
dtype=dtype).project_max_grade()
611638

612639

613640
def area_element(
614641
actx: ArrayContext, dcoll: DiscretizationCollection, dd=None,
615-
*, _use_geoderiv_connection=False
642+
*, _use_geoderiv_connection=False,
643+
dtype: Optional[np.dtype[Any]] = None
616644
) -> DOFArray:
617645
r"""Computes the scale factor used to transform integrals from reference
618646
to global space.
@@ -623,22 +651,28 @@ def area_element(
623651
Defaults to the base volume discretization.
624652
:arg _use_geoderiv_connection: For internal use. See
625653
:func:`forward_metric_nth_derivative` for an explanation.
654+
:arg dtype: the :class:`numpy.dtype` with which to return the area element
655+
data.
626656
:returns: a :class:`~meshmode.dof_array.DOFArray` containing the transformed
627657
volumes for each element.
628658
"""
629659
if dd is None:
630660
dd = DD_VOLUME
631661

632662
@memoize_in(dcoll, (area_element, dd, _use_geoderiv_connection))
633-
def _area_elements():
663+
def _area_elements(dtype: np.dtype[Any]):
634664
result = actx.np.sqrt(
635665
pseudoscalar(
636666
actx, dcoll, dd=dd,
637-
_use_geoderiv_connection=_use_geoderiv_connection).norm_squared())
667+
_use_geoderiv_connection=_use_geoderiv_connection,
668+
dtype=dtype).norm_squared())
638669

639670
return freeze(result, actx)
640671

641-
return thaw(_area_elements(), actx)
672+
if dtype is not None:
673+
dtype = to_real_dtype(dtype)
674+
675+
return thaw(_area_elements(dtype), actx)
642676

643677
# }}}
644678

0 commit comments

Comments
 (0)