5858"""
5959
6060
61+ from typing import Optional , Any
62+
6163import numpy as np
6264
6365from arraycontext import thaw , freeze , ArrayContext
6466from meshmode .dof_array import DOFArray
6567
68+ from grudge .tools import to_real_dtype
6669from grudge import DiscretizationCollection
6770import grudge .dof_desc as dof_desc
6871
@@ -111,7 +114,8 @@ def to_quad(vec):
111114def forward_metric_nth_derivative (
112115 actx : ArrayContext , dcoll : DiscretizationCollection ,
113116 xyz_axis , ref_axes , dd = None ,
114- * , _use_geoderiv_connection = False ) -> DOFArray :
117+ * , _use_geoderiv_connection = False ,
118+ dtype : Optional [np .dtype [Any ]] = None ) -> DOFArray :
115119 r"""Pointwise metric derivatives representing repeated derivatives of the
116120 physical coordinate enumerated by *xyz_axis*: :math:`x_{\mathrm{xyz\_axis}}`
117121 with respect to the coordiantes on the reference element :math:`\xi_i`:
@@ -169,7 +173,8 @@ def forward_metric_nth_derivative(
169173 vec = num_reference_derivative (
170174 dcoll .discr_from_dd (inner_dd ),
171175 flat_ref_axes ,
172- thaw (dcoll .discr_from_dd (inner_dd ).nodes (), actx )[xyz_axis ]
176+ thaw (dcoll .discr_from_dd (inner_dd ).nodes (
177+ dtype = to_real_dtype (dtype )), actx )[xyz_axis ]
173178 )
174179
175180 return _geometry_to_quad_if_requested (
@@ -178,7 +183,8 @@ def forward_metric_nth_derivative(
178183
179184def forward_metric_derivative_vector (
180185 actx : ArrayContext , dcoll : DiscretizationCollection , rst_axis , dd = None ,
181- * , _use_geoderiv_connection = False ) -> np .ndarray :
186+ * , _use_geoderiv_connection = False , dtype : Optional [np .dtype [Any ]] = None
187+ ) -> np .ndarray :
182188 r"""Computes an object array containing the forward metric derivatives
183189 of each physical coordinate.
184190
@@ -195,15 +201,17 @@ def forward_metric_derivative_vector(
195201 return make_obj_array ([
196202 forward_metric_nth_derivative (
197203 actx , dcoll , i , rst_axis , dd = dd ,
198- _use_geoderiv_connection = _use_geoderiv_connection )
204+ _use_geoderiv_connection = _use_geoderiv_connection ,
205+ dtype = dtype )
199206 for i in range (dcoll .ambient_dim )
200207 ]
201208 )
202209
203210
204211def forward_metric_derivative_mv (
205212 actx : ArrayContext , dcoll : DiscretizationCollection , rst_axis , dd = None ,
206- * , _use_geoderiv_connection = False ) -> MultiVector :
213+ * , _use_geoderiv_connection = False , dtype : Optional [np .dtype [Any ]] = None
214+ ) -> MultiVector :
207215 r"""Computes a :class:`pymbolic.geometric_algebra.MultiVector` containing
208216 the forward metric derivatives of each physical coordinate.
209217
@@ -220,13 +228,15 @@ def forward_metric_derivative_mv(
220228 return MultiVector (
221229 forward_metric_derivative_vector (
222230 actx , dcoll , rst_axis , dd = dd ,
223- _use_geoderiv_connection = _use_geoderiv_connection )
231+ _use_geoderiv_connection = _use_geoderiv_connection ,
232+ dtype = dtype )
224233 )
225234
226235
227236def forward_metric_derivative_mat (
228237 actx : ArrayContext , dcoll : DiscretizationCollection , dd = None ,
229- * , _use_geoderiv_connection = False ) -> np .ndarray :
238+ * , _use_geoderiv_connection = False ,
239+ dtype : Optional [np .dtype [Any ]] = None ) -> np .ndarray :
230240 r"""Computes the forward metric derivative matrix, also commonly
231241 called the Jacobian matrix, with entries defined as the
232242 forward metric derivatives:
@@ -260,13 +270,15 @@ def forward_metric_derivative_mat(
260270 for j in range (dim ):
261271 result [:, j ] = forward_metric_derivative_vector (
262272 actx , dcoll , j , dd = dd ,
263- _use_geoderiv_connection = _use_geoderiv_connection )
273+ _use_geoderiv_connection = _use_geoderiv_connection ,
274+ dtype = dtype )
264275
265276 return result
266277
267278
268279def first_fundamental_form (actx : ArrayContext , dcoll : DiscretizationCollection ,
269- dd = None , * , _use_geoderiv_connection = False ) -> np .ndarray :
280+ dd = None , * , _use_geoderiv_connection = False ,
281+ dtype : Optional [np .dtype [Any ]] = None ) -> np .ndarray :
270282 r"""Computes the first fundamental form using the Jacobian matrix:
271283
272284 .. math::
@@ -295,14 +307,16 @@ def first_fundamental_form(actx: ArrayContext, dcoll: DiscretizationCollection,
295307 dd = DD_VOLUME
296308
297309 mder = forward_metric_derivative_mat (
298- actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection )
310+ actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection ,
311+ dtype = dtype )
299312
300313 return mder .T .dot (mder )
301314
302315
303316def inverse_metric_derivative_mat (
304317 actx : ArrayContext , dcoll : DiscretizationCollection , dd = None ,
305- * , _use_geoderiv_connection = False ) -> np .ndarray :
318+ * , _use_geoderiv_connection = False , dtype : Optional [np .dtype [Any ]] = None
319+ ) -> np .ndarray :
306320 r"""Computes the inverse metric derivative matrix, which is
307321 the inverse of the Jacobian (forward metric derivative) matrix.
308322
@@ -324,15 +338,16 @@ def inverse_metric_derivative_mat(
324338 for j in range (ambient_dim ):
325339 result [i , j ] = inverse_metric_derivative (
326340 actx , dcoll , i , j , dd = dd ,
327- _use_geoderiv_connection = _use_geoderiv_connection
328- )
341+ _use_geoderiv_connection = _use_geoderiv_connection ,
342+ dtype = dtype )
329343
330344 return result
331345
332346
333347def inverse_first_fundamental_form (
334348 actx : ArrayContext , dcoll : DiscretizationCollection , dd = None ,
335- * , _use_geoderiv_connection = False ) -> np .ndarray :
349+ * , _use_geoderiv_connection = False , dtype : Optional [np .dtype [Any ]]
350+ ) -> np .ndarray :
336351 r"""Computes the inverse of the first fundamental form:
337352
338353 .. math::
@@ -361,11 +376,13 @@ def inverse_first_fundamental_form(
361376
362377 if dcoll .ambient_dim == dim :
363378 inv_mder = inverse_metric_derivative_mat (
364- actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection )
379+ actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection ,
380+ dtype = dtype )
365381 inv_form1 = inv_mder .dot (inv_mder .T )
366382 else :
367383 form1 = first_fundamental_form (
368- actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection )
384+ actx , dcoll , dd = dd , _use_geoderiv_connection = _use_geoderiv_connection ,
385+ dtype = dtype )
369386
370387 if dim == 1 :
371388 inv_form1 = 1.0 / form1
@@ -383,7 +400,7 @@ def inverse_first_fundamental_form(
383400
384401def inverse_metric_derivative (
385402 actx : ArrayContext , dcoll : DiscretizationCollection , rst_axis , xyz_axis , dd ,
386- * , _use_geoderiv_connection = False
403+ * , _use_geoderiv_connection = False , dtype : Optional [ np . dtype [ Any ]] = None
387404 ) -> DOFArray :
388405 r"""Computes the inverse metric derivative of the physical
389406 coordinate enumerated by *xyz_axis* with respect to the
@@ -409,7 +426,8 @@ def inverse_metric_derivative(
409426 par_vecs = [
410427 forward_metric_derivative_mv (
411428 actx , dcoll , rst , dd ,
412- _use_geoderiv_connection = _use_geoderiv_connection )
429+ _use_geoderiv_connection = _use_geoderiv_connection ,
430+ dtype = dtype )
413431 for rst in range (dim )]
414432
415433 # Yay Cramer's rule!
@@ -442,7 +460,8 @@ def outprod_with_unit(i, at):
442460def inverse_surface_metric_derivative (
443461 actx : ArrayContext , dcoll : DiscretizationCollection ,
444462 rst_axis , xyz_axis , dd = None ,
445- * , _use_geoderiv_connection = False ):
463+ * , _use_geoderiv_connection = False ,
464+ dtype : Optional [np .dtype [Any ]]):
446465 r"""Computes the inverse surface metric derivative of the physical
447466 coordinate enumerated by *xyz_axis* with respect to the
448467 reference axis *rst_axis*. These geometric terms are used in the
@@ -467,24 +486,24 @@ def inverse_surface_metric_derivative(
467486 dd = dof_desc .as_dofdesc (dd )
468487
469488 if ambient_dim == dim :
470- result = inverse_metric_derivative (
489+ return inverse_metric_derivative (
471490 actx , dcoll , rst_axis , xyz_axis , dd = dd ,
472- _use_geoderiv_connection = _use_geoderiv_connection
473- )
491+ _use_geoderiv_connection = _use_geoderiv_connection ,
492+ dtype = dtype )
474493 else :
475494 inv_form1 = inverse_first_fundamental_form (actx , dcoll , dd = dd )
476- result = sum (
495+ return sum (
477496 inv_form1 [rst_axis , d ]* forward_metric_nth_derivative (
478497 actx , dcoll , xyz_axis , d , dd = dd ,
479- _use_geoderiv_connection = _use_geoderiv_connection
498+ _use_geoderiv_connection = _use_geoderiv_connection ,
499+ dtype = dtype ,
480500 ) for d in range (dim ))
481501
482- return result
483-
484502
485503def inverse_surface_metric_derivative_mat (
486504 actx : ArrayContext , dcoll : DiscretizationCollection , dd = None ,
487- * , times_area_element = False , _use_geoderiv_connection = False ):
505+ * , times_area_element = False , _use_geoderiv_connection = False ,
506+ dtype : Optional [np .dtype [Any ]] = None ):
488507 r"""Computes the matrix of inverse surface metric derivatives, indexed by
489508 ``(xyz_axis, rst_axis)``. It returns all values of
490509 :func:`inverse_surface_metric_derivative_mat` in cached matrix form.
@@ -505,7 +524,7 @@ def inverse_surface_metric_derivative_mat(
505524
506525 @memoize_in (dcoll , (inverse_surface_metric_derivative_mat , dd ,
507526 times_area_element , _use_geoderiv_connection ))
508- def _inv_surf_metric_deriv ():
527+ def _inv_surf_metric_deriv (dtype ):
509528 if times_area_element :
510529 multiplier = area_element (actx , dcoll , dd = dd ,
511530 _use_geoderiv_connection = _use_geoderiv_connection )
@@ -517,13 +536,17 @@ def _inv_surf_metric_deriv():
517536 multiplier
518537 * inverse_surface_metric_derivative (actx , dcoll ,
519538 rst_axis , xyz_axis , dd = dd ,
520- _use_geoderiv_connection = _use_geoderiv_connection )
539+ _use_geoderiv_connection = _use_geoderiv_connection ,
540+ dtype = dtype )
521541 for rst_axis in range (dcoll .dim )])
522542 for xyz_axis in range (dcoll .ambient_dim )])
523543
524544 return freeze (mat , actx )
525545
526- return thaw (_inv_surf_metric_deriv (), actx )
546+ if dtype is not None :
547+ dtype = to_real_dtype (dtype )
548+
549+ return thaw (_inv_surf_metric_deriv (dtype ), actx )
527550
528551
529552def _signed_face_ones (
@@ -557,7 +580,8 @@ def _signed_face_ones(
557580
558581def parametrization_derivative (
559582 actx : ArrayContext , dcoll : DiscretizationCollection , dd ,
560- * , _use_geoderiv_connection = False ) -> MultiVector :
583+ * , _use_geoderiv_connection = False ,
584+ dtype : Optional [np .dtype [Any ]]) -> MultiVector :
561585 r"""Computes the product of forward metric derivatives spanning the
562586 tangent space with topological dimension *dim*.
563587
@@ -585,13 +609,15 @@ def parametrization_derivative(
585609 return product (
586610 forward_metric_derivative_mv (
587611 actx , dcoll , rst_axis , dd ,
588- _use_geoderiv_connection = _use_geoderiv_connection )
612+ _use_geoderiv_connection = _use_geoderiv_connection ,
613+ dtype = dtype )
589614 for rst_axis in range (dim )
590615 )
591616
592617
593618def pseudoscalar (actx : ArrayContext , dcoll : DiscretizationCollection ,
594- dd = None , * , _use_geoderiv_connection = False ) -> MultiVector :
619+ dd = None , * , _use_geoderiv_connection = False ,
620+ dtype : Optional [np .dtype [Any ]]) -> MultiVector :
595621 r"""Computes the field of pseudoscalars for the domain/discretization
596622 identified by *dd*.
597623
@@ -607,12 +633,14 @@ def pseudoscalar(actx: ArrayContext, dcoll: DiscretizationCollection,
607633
608634 return parametrization_derivative (
609635 actx , dcoll , dd ,
610- _use_geoderiv_connection = _use_geoderiv_connection ).project_max_grade ()
636+ _use_geoderiv_connection = _use_geoderiv_connection ,
637+ dtype = dtype ).project_max_grade ()
611638
612639
613640def area_element (
614641 actx : ArrayContext , dcoll : DiscretizationCollection , dd = None ,
615- * , _use_geoderiv_connection = False
642+ * , _use_geoderiv_connection = False ,
643+ dtype : Optional [np .dtype [Any ]] = None
616644 ) -> DOFArray :
617645 r"""Computes the scale factor used to transform integrals from reference
618646 to global space.
@@ -623,22 +651,28 @@ def area_element(
623651 Defaults to the base volume discretization.
624652 :arg _use_geoderiv_connection: For internal use. See
625653 :func:`forward_metric_nth_derivative` for an explanation.
654+ :arg dtype: the :class:`numpy.dtype` with which to return the area element
655+ data.
626656 :returns: a :class:`~meshmode.dof_array.DOFArray` containing the transformed
627657 volumes for each element.
628658 """
629659 if dd is None :
630660 dd = DD_VOLUME
631661
632662 @memoize_in (dcoll , (area_element , dd , _use_geoderiv_connection ))
633- def _area_elements ():
663+ def _area_elements (dtype : np . dtype [ Any ] ):
634664 result = actx .np .sqrt (
635665 pseudoscalar (
636666 actx , dcoll , dd = dd ,
637- _use_geoderiv_connection = _use_geoderiv_connection ).norm_squared ())
667+ _use_geoderiv_connection = _use_geoderiv_connection ,
668+ dtype = dtype ).norm_squared ())
638669
639670 return freeze (result , actx )
640671
641- return thaw (_area_elements (), actx )
672+ if dtype is not None :
673+ dtype = to_real_dtype (dtype )
674+
675+ return thaw (_area_elements (dtype ), actx )
642676
643677# }}}
644678
0 commit comments