Skip to content

Commit ae3a105

Browse files
committed
Streamline discr fetching through to discr_from_dd
1 parent 23dfa75 commit ae3a105

File tree

2 files changed

+42
-63
lines changed

2 files changed

+42
-63
lines changed

grudge/discretization.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
make_face_restriction
5353
)
5454
from meshmode.mesh import Mesh, BTAG_PARTITION
55+
from meshmode.discretization import Discretization
5556
from meshmode.discretization.connection import DiscretizationConnection
5657

5758
from typing import Dict
@@ -67,8 +68,6 @@ class DiscretizationCollection:
6768
(volume, interior facets, boundaries) and associated element
6869
groups.
6970
70-
.. automethod:: __init__
71-
7271
.. autoattribute:: dim
7372
.. autoattribute:: ambient_dim
7473
.. autoattribute:: mesh
@@ -91,12 +90,17 @@ class DiscretizationCollection:
9190

9291
# {{{ constructor
9392

94-
def __init__(self, array_context: ArrayContext, mesh: Mesh,
93+
def __init__(self,
94+
array_context: ArrayContext,
95+
mesh: Mesh,
96+
base_discr: Discretization,
9597
discr_tag_to_group_factory,
96-
discr_from_dd,
9798
dist_boundary_connections,
9899
mpi_communicator=None):
99100
"""
101+
:arg base_discr: A :class:`~meshmode.discretization.Discretization`
102+
corresponding to the :class:`grudge.dof_desc.DISCR_TAG_BASE`
103+
descriptor.
100104
:arg discr_tag_to_group_factory: A mapping from discretization tags
101105
(typically one of: :class:`grudge.dof_desc.DISCR_TAG_BASE`,
102106
:class:`grudge.dof_desc.DISCR_TAG_MODAL`, or
@@ -106,20 +110,15 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
106110
to be carried out, or *None* to indicate that operations with this
107111
discretization tag should be carried out with the standard volume
108112
discretization.
109-
:arg discr_from_dd: A mapping from discretization tags
110-
from a :class:`grudge.dof_desc.DOFDesc`
111-
to a :class:`meshmode.discretization.Discretization`.
112-
At minimum, this should include a base discretization given by
113-
the dof descriptor :class:`grudge.dof_desc.DD_VOLUME`.
114113
:arg dist_boundary_connections: A dictionary whose keys denote the
115114
partition group index and map to the appropriate face connections
116115
for distributed boundaries, if any.
117116
:arg mpi_communicator: An (optional) MPI communicator.
118117
"""
119118
self._setup_actx = array_context
120119
self._mesh = mesh
120+
self._base_discr = base_discr
121121
self.discr_tag_to_group_factory = discr_tag_to_group_factory
122-
self._discr_from_dd = discr_from_dd
123122

124123
# NOTE: Can be removed when symbolics are completely removed
125124
# {{{ management of discretization-scoped common subexpressions
@@ -203,22 +202,18 @@ def discr_from_dd(self, dd):
203202
convertible to one.
204203
"""
205204
dd = as_dofdesc(dd)
206-
207-
if dd in self._discr_from_dd:
208-
return self._discr_from_dd[dd]
209-
210205
discr_tag = dd.discretization_tag
211206

212207
if discr_tag is DISCR_TAG_MODAL:
213208
return self._modal_discr(dd.domain_tag)
214209

215210
if dd.is_volume():
216-
return self._discr_tag_volume_discr(discr_tag)
211+
if discr_tag is not DISCR_TAG_BASE:
212+
return self._discr_tag_volume_discr(discr_tag)
213+
return self._base_discr
217214

218215
if discr_tag is not DISCR_TAG_BASE:
219216
no_quad_discr = self.discr_from_dd(DOFDesc(dd.domain_tag))
220-
221-
from meshmode.discretization import Discretization
222217
return Discretization(
223218
self._setup_actx,
224219
no_quad_discr.mesh,
@@ -282,7 +277,6 @@ def geo_group_factory(megrp, index):
282277
else:
283278
return base_group_factory(megrp, index)
284279

285-
from meshmode.discretization import Discretization
286280
geo_deriv_discr = Discretization(
287281
self._setup_actx, base_discr.mesh,
288282
geo_group_factory)
@@ -405,7 +399,7 @@ def connection_from_dds(self, from_dd, to_dd):
405399
from meshmode.discretization.connection import \
406400
make_same_mesh_connection
407401
to_discr = self._discr_tag_volume_discr(to_discr_tag)
408-
from_discr = self._discr_from_dd[DD_VOLUME]
402+
from_discr = self._base_discr
409403
return make_same_mesh_connection(self._setup_actx, to_discr,
410404
from_discr)
411405

@@ -443,20 +437,17 @@ def group_factory_for_discretization_tag(self, discretization_tag):
443437
def _discr_tag_volume_discr(self, discretization_tag):
444438
assert discretization_tag is not None
445439

446-
# Refuse to re-make the volume discretization
440+
# Refuse to re-make the base volume discretization
447441
if discretization_tag is DISCR_TAG_BASE:
448-
return self.discr_from_dd("vol")
442+
return self._base_discr
449443

450-
from meshmode.discretization import Discretization
451444
return Discretization(
452445
self._setup_actx, self._mesh,
453446
self.group_factory_for_discretization_tag(discretization_tag)
454447
)
455448

456449
@memoize_method
457450
def _modal_discr(self, domain_tag):
458-
from meshmode.discretization import Discretization
459-
460451
discr_base = self.discr_from_dd(DOFDesc(domain_tag, DISCR_TAG_BASE))
461452
return Discretization(
462453
self._setup_actx, discr_base.mesh,
@@ -503,7 +494,7 @@ def _nodal_to_modal_connection(self, from_dd):
503494
def _boundary_connection(self, boundary_tag):
504495
return make_face_restriction(
505496
self._setup_actx,
506-
self._discr_from_dd[DD_VOLUME],
497+
self._base_discr,
507498
self.group_factory_for_discretization_tag(DISCR_TAG_BASE),
508499
boundary_tag=boundary_tag
509500
)
@@ -516,7 +507,7 @@ def _boundary_connection(self, boundary_tag):
516507
def _interior_faces_connection(self):
517508
return make_face_restriction(
518509
self._setup_actx,
519-
self._discr_from_dd[DD_VOLUME],
510+
self._base_discr,
520511
self.group_factory_for_discretization_tag(DISCR_TAG_BASE),
521512
FACE_RESTR_INTERIOR,
522513

@@ -547,7 +538,7 @@ def opposite_face_connection(self):
547538
def _all_faces_volume_connection(self):
548539
return make_face_restriction(
549540
self._setup_actx,
550-
self._discr_from_dd[DD_VOLUME],
541+
self._base_discr,
551542
self.group_factory_for_discretization_tag(DISCR_TAG_BASE),
552543
FACE_RESTR_ALL,
553544

@@ -562,22 +553,22 @@ def _all_faces_volume_connection(self):
562553
@property
563554
def dim(self):
564555
"""Return the topological dimension."""
565-
return self._discr_from_dd[DD_VOLUME].dim
556+
return self._base_discr.dim
566557

567558
@property
568559
def ambient_dim(self):
569560
"""Return the dimension of the ambient space."""
570-
return self._discr_from_dd[DD_VOLUME].ambient_dim
561+
return self._base_discr.ambient_dim
571562

572563
@property
573564
def real_dtype(self):
574565
"""Return the data type used for real-valued arithmetic."""
575-
return self._discr_from_dd[DD_VOLUME].real_dtype
566+
return self._base_discr.real_dtype
576567

577568
@property
578569
def complex_dtype(self):
579570
"""Return the data type used for complex-valued arithmetic."""
580-
return self._discr_from_dd[DD_VOLUME].complex_dtype
571+
return self._base_discr.complex_dtype
581572

582573
@property
583574
def mesh(self):
@@ -595,7 +586,7 @@ def empty(self, array_context: ArrayContext, dtype=None):
595586
vector of dtype :attr:`complex_dtype`. If
596587
*None* (the default), a real vector will be returned.
597588
"""
598-
return self._discr_from_dd[DD_VOLUME].empty(array_context, dtype)
589+
return self._base_discr.empty(array_context, dtype)
599590

600591
def zeros(self, array_context: ArrayContext, dtype=None):
601592
"""Return a zero-initialized :class:`~meshmode.dof_array.DOFArray`
@@ -606,7 +597,7 @@ def zeros(self, array_context: ArrayContext, dtype=None):
606597
vector of dtype :attr:`complex_dtype`. If
607598
*None* (the default), a real vector will be returned.
608599
"""
609-
return self._discr_from_dd[DD_VOLUME].zeros(array_context, dtype)
600+
return self._base_discr.zeros(array_context, dtype)
610601

611602
def is_volume_where(self, where):
612603
return where is None or as_dofdesc(where).is_volume()
@@ -620,7 +611,7 @@ def order(self):
620611

621612
from pytools import single_valued
622613
return single_valued(
623-
egrp.order for egrp in self._discr_from_dd[DD_VOLUME].groups
614+
egrp.order for egrp in self._base_discr.groups
624615
)
625616

626617
# {{{ Discretization-specific geometric properties
@@ -715,16 +706,12 @@ def make_discretization_collection(
715706
discr_tag_to_group_factory[DISCR_TAG_BASE]
716707
)
717708

718-
# Define the base and modal discretization
719-
from meshmode.discretization import Discretization
720-
709+
# Define the base discretization
721710
base_discr = Discretization(
722711
array_context, mesh,
723712
discr_tag_to_group_factory[DISCR_TAG_BASE]
724713
)
725714

726-
discr_from_dd = {DD_VOLUME: base_discr}
727-
728715
# Define boundary connections
729716
dist_boundary_connections = set_up_distributed_communication(
730717
array_context, mesh,
@@ -735,8 +722,8 @@ def make_discretization_collection(
735722
return DiscretizationCollection(
736723
array_context=array_context,
737724
mesh=mesh,
725+
base_discr=base_discr,
738726
discr_tag_to_group_factory=discr_tag_to_group_factory,
739-
discr_from_dd=discr_from_dd,
740727
dist_boundary_connections=dist_boundary_connections,
741728
mpi_communicator=mpi_communicator
742729
)

grudge/eager.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
6767
discr_tag_to_group_factory = quad_tag_to_group_factory
6868

6969
from meshmode.discretization.poly_element import \
70-
PolynomialWarpAndBlendGroupFactory
70+
default_simplex_group_factory
7171

7272
from grudge.dof_desc import DISCR_TAG_BASE, DISCR_TAG_MODAL
7373

@@ -77,9 +77,9 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
7777
"one of 'order' and 'discr_tag_to_group_factory' must be given"
7878
)
7979

80-
# Default choice: warp and blend simplex element group
8180
discr_tag_to_group_factory = {
82-
DISCR_TAG_BASE: PolynomialWarpAndBlendGroupFactory(order=order)
81+
DISCR_TAG_BASE: default_simplex_group_factory(base_dim=mesh.dim,
82+
order=order)
8383
}
8484
else:
8585
if order is not None:
@@ -91,47 +91,39 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
9191
)
9292

9393
discr_tag_to_group_factory[DISCR_TAG_BASE] = \
94-
PolynomialWarpAndBlendGroupFactory(order=order)
94+
default_simplex_group_factory(base_dim=mesh.dim, order=order)
9595

96-
# Modal discr should always comes from the base discretization
97-
from grudge.discretization import _generate_modal_group_factory
96+
# Supply modal group factory if not provided
97+
if DISCR_TAG_MODAL not in discr_tag_to_group_factory:
98+
from grudge.discretization import _generate_modal_group_factory
9899

99-
discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
100-
_generate_modal_group_factory(
101-
discr_tag_to_group_factory[DISCR_TAG_BASE]
102-
)
100+
discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
101+
_generate_modal_group_factory(
102+
discr_tag_to_group_factory[DISCR_TAG_BASE]
103+
)
103104

104-
# Define the base and modal discretization
105-
from grudge.dof_desc import DD_VOLUME, DD_VOLUME_MODAL
105+
# Define the base discretization
106106
from meshmode.discretization import Discretization
107107

108-
volume_discr = Discretization(
108+
base_discr = Discretization(
109109
array_context, mesh,
110110
discr_tag_to_group_factory[DISCR_TAG_BASE]
111111
)
112112

113-
modal_vol_discr = Discretization(
114-
array_context, mesh,
115-
discr_tag_to_group_factory[DISCR_TAG_MODAL]
116-
)
117-
118-
discr_from_dd = {DD_VOLUME: volume_discr,
119-
DD_VOLUME_MODAL: modal_vol_discr}
120-
121113
# Define boundary connections
122114
from grudge.discretization import set_up_distributed_communication
123115

124116
dist_boundary_connections = set_up_distributed_communication(
125117
array_context, mesh,
126-
volume_discr,
118+
base_discr,
127119
discr_tag_to_group_factory, comm=mpi_communicator
128120
)
129121

130122
super().__init__(
131123
array_context=array_context,
132124
mesh=mesh,
125+
base_discr=base_discr,
133126
discr_tag_to_group_factory=discr_tag_to_group_factory,
134-
discr_from_dd=discr_from_dd,
135127
dist_boundary_connections=dist_boundary_connections,
136128
mpi_communicator=mpi_communicator
137129
)

0 commit comments

Comments
 (0)