Skip to content

Commit 68e69f5

Browse files
committed
Merge remote-tracking branch 'origin/generic-part-bdry' into generic-part-bdry
2 parents 5413209 + f7cfe04 commit 68e69f5

24 files changed

+362
-194
lines changed

doc/conf.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,16 @@
4141
"https://mpi4py.readthedocs.io/en/stable": None,
4242
"h5py": ("https://docs.h5py.org/en/stable", None),
4343
}
44+
45+
46+
# Some modules need to import things just so that sphinx can resolve symbols in
47+
# type annotations. Often, we do not want these imports (e.g. of PyOpenCL) when
48+
# in normal use (because they would introduce unintended side effects or hard
49+
# dependencies). This flag exists so that these imports only occur during doc
50+
# build. Since sphinx appears to resolve type hints lexically (as it should),
51+
# this needs to be cross-module (since, e.g. an inherited arraycontext
52+
# docstring can be read by sphinx when building meshmode, a dependent package),
53+
# this needs a setting of the same name across all packages involved, that's
54+
# why this name is as global-sounding as it is.
55+
import sys
56+
sys._BUILDING_SPHINX_DOCS = True

examples/moving-geometry.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from meshmode.array_context import PyOpenCLArrayContext
2727
from meshmode.transform_metadata import FirstAxisIsElementsTag
28-
from arraycontext import thaw
2928

3029
from pytools import keyed_memoize_in
3130
from pytools.obj_array import make_obj_array
@@ -186,15 +185,15 @@ def velocity_field(nodes, alpha=1.0):
186185

187186
def source(t, x):
188187
discr = reconstruct_discr_from_nodes(actx, discr0, x)
189-
u = velocity_field(thaw(discr.nodes(), actx))
188+
u = velocity_field(actx.thaw(discr.nodes()))
190189

191190
# {{{
192191

193192
# NOTE: these are just here because this was at some point used to
194193
# profile some more operators (turned out well!)
195194

196195
from meshmode.discretization import num_reference_derivative
197-
x = thaw(discr.nodes()[0], actx)
196+
x = actx.thaw(discr.nodes()[0])
198197
gradx = sum(
199198
num_reference_derivative(discr, (i,), x)
200199
for i in range(discr.dim))
@@ -214,7 +213,7 @@ def source(t, x):
214213
maxiter = int(tmax // timestep) + 1
215214
dt = tmax / maxiter + 1.0e-15
216215

217-
x = thaw(discr0.nodes(), actx)
216+
x = actx.thaw(discr0.nodes())
218217
t = 0.0
219218

220219
if visualize:

examples/parallel-vtkhdf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ def main(*, ambient_dim: int) -> None:
6363

6464
logger.info("[%4d] discretization: finished", mpirank)
6565

66-
from arraycontext import thaw
67-
vector_field = thaw(discr.nodes(), actx)
68-
scalar_field = actx.np.sin(thaw(discr.nodes()[0], actx))
66+
vector_field = actx.thaw(discr.nodes())
67+
scalar_field = actx.np.sin(vector_field[0])
6968
part_id = 1 + mpirank + discr.zeros(actx)
7069
logger.info("[%4d] fields: finished", mpirank)
7170

examples/plot-connectivity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pyopencl as cl
33

44
from meshmode.array_context import PyOpenCLArrayContext
5-
from arraycontext import thaw
65

76
order = 4
87

@@ -30,7 +29,7 @@ def main():
3029
vis = make_visualizer(actx, discr, order)
3130

3231
vis.write_vtk_file("geometry.vtu", [
33-
("f", thaw(discr.nodes()[0], actx)),
32+
("f", actx.thaw(discr.nodes()[0])),
3433
])
3534

3635
from meshmode.discretization.visualization import \

examples/simple-dg.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from meshmode.array_context import (PyOpenCLArrayContext,
3636
PytatoPyOpenCLArrayContext)
3737
from arraycontext import (
38-
freeze, thaw,
3938
ArrayContainer,
4039
map_array_container,
4140
with_container_arithmetic,
@@ -57,7 +56,7 @@
5756
# {{{ discretization
5857

5958
def parametrization_derivative(actx, discr):
60-
thawed_nodes = thaw(discr.nodes(), actx)
59+
thawed_nodes = actx.thaw(discr.nodes())
6160

6261
from meshmode.discretization import num_reference_derivative
6362
result = np.zeros((discr.ambient_dim, discr.dim), dtype=object)
@@ -175,17 +174,17 @@ def get_discr(self, where):
175174

176175
@memoize_method
177176
def parametrization_derivative(self):
178-
return freeze(
177+
return self._setup_actx.freeze(
179178
parametrization_derivative(self._setup_actx, self.volume_discr))
180179

181180
@memoize_method
182181
def vol_jacobian(self):
183-
[a, b], [c, d] = thaw(self.parametrization_derivative(), self._setup_actx)
184-
return freeze(a*d-b*c)
182+
[a, b], [c, d] = self._setup_actx.thaw(self.parametrization_derivative())
183+
return self._setup_actx.freeze(a*d - b*c)
185184

186185
@memoize_method
187186
def inverse_parametrization_derivative(self):
188-
[a, b], [c, d] = thaw(self.parametrization_derivative(), self._setup_actx)
187+
[a, b], [c, d] = self._setup_actx.thaw(self.parametrization_derivative())
189188

190189
result = np.zeros((2, 2), dtype=object)
191190
det = a*d-b*c
@@ -194,13 +193,13 @@ def inverse_parametrization_derivative(self):
194193
result[1, 0] = -c/det
195194
result[1, 1] = a/det
196195

197-
return freeze(result)
196+
return self._setup_actx.freeze(result)
198197

199198
def zeros(self, actx):
200199
return self.volume_discr.zeros(actx)
201200

202201
def grad(self, vec):
203-
ipder = thaw(self.inverse_parametrization_derivative(), vec.array_context)
202+
ipder = vec.array_context.thaw(self.inverse_parametrization_derivative())
204203

205204
from meshmode.discretization import num_reference_derivative
206205
dref = [
@@ -222,15 +221,15 @@ def normal(self, where):
222221
((a,), (b,)) = parametrization_derivative(self._setup_actx, bdry_discr)
223222

224223
nrm = 1/(a**2+b**2)**0.5
225-
return freeze(flat_obj_array(b*nrm, -a*nrm))
224+
return self._setup_actx.freeze(flat_obj_array(b*nrm, -a*nrm))
226225

227226
@memoize_method
228227
def face_jacobian(self, where):
229228
bdry_discr = self.get_discr(where)
230229

231230
((a,), (b,)) = parametrization_derivative(self._setup_actx, bdry_discr)
232231

233-
return freeze((a**2 + b**2)**0.5)
232+
return self._setup_actx.freeze((a**2 + b**2)**0.5)
234233

235234
@memoize_method
236235
def get_inverse_mass_matrix(self, grp, dtype):
@@ -261,7 +260,7 @@ def inverse_mass(self, vec):
261260
tagged=(FirstAxisIsElementsTag(),)
262261
) for grp, vec_i in zip(discr.groups, vec)
263262
)
264-
) / thaw(self.vol_jacobian(), actx)
263+
) / actx.thaw(self.vol_jacobian())
265264

266265
@memoize_method
267266
def get_local_face_mass_matrix(self, afgrp, volgrp, dtype):
@@ -300,7 +299,7 @@ def face_mass(self, vec):
300299
all_faces_discr = all_faces_conn.to_discr
301300
vol_discr = all_faces_conn.from_discr
302301

303-
fj = thaw(self.face_jacobian("all_faces"), vec.array_context)
302+
fj = vec.array_context.thaw(self.face_jacobian("all_faces"))
304303
vec = vec*fj
305304

306305
assert len(all_faces_discr.groups) == len(vol_discr.groups)
@@ -367,7 +366,7 @@ def wave_flux(actx, discr, c, q_tpair):
367366
u = q_tpair.u
368367
v = q_tpair.v
369368

370-
normal = thaw(discr.normal(q_tpair.where), actx)
369+
normal = actx.thaw(discr.normal(q_tpair.where))
371370

372371
flux_weak = WaveState(
373372
u=np.dot(v.avg, normal),
@@ -422,7 +421,7 @@ def bump(actx, discr, t=0):
422421
source_width = 0.05
423422
source_omega = 3
424423

425-
nodes = thaw(discr.volume_discr.nodes(), actx)
424+
nodes = actx.thaw(discr.volume_discr.nodes())
426425
center_dist = flat_obj_array([
427426
nodes[0] - source_center[0],
428427
nodes[1] - source_center[1],
@@ -492,8 +491,8 @@ def rhs(t, q):
492491
compiled_rhs = actx_rhs.compile(rhs)
493492

494493
def rhs_wrapper(t, q):
495-
r = compiled_rhs(t, thaw(freeze(q, actx_outer), actx_rhs))
496-
return thaw(freeze(r, actx_rhs), actx_outer)
494+
r = compiled_rhs(t, actx_rhs.thaw(actx_outer.freeze(q)))
495+
return actx_outer.thaw(actx_rhs.freeze(r))
497496

498497
t = np.float64(0)
499498
t_final = 3

examples/to_firedrake.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import pyopencl as cl
2626

2727
from meshmode.array_context import PyOpenCLArrayContext
28-
from arraycontext import thaw
2928

3029

3130
# Nb: Some of the initial setup was adapted from meshmode/examplse/simple-dg.py
@@ -75,7 +74,7 @@ def main():
7574
# = e^x cos(y)
7675
nodes = discr.nodes()
7776
for i in range(len(nodes)):
78-
nodes[i] = thaw(nodes[i], actx)
77+
nodes[i] = actx.thaw(nodes[i])
7978
# First index is dimension
8079
candidate_sol = actx.np.exp(nodes[0]) * actx.np.cos(nodes[1])
8180

meshmode/array_context.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def thaw(actx, ary):
4141
"meshmode.array_context.thaw will continue to work until 2022.",
4242
DeprecationWarning, stacklevel=2)
4343

44-
from arraycontext import thaw as _thaw
45-
# /!\ arg order flipped
46-
return _thaw(ary, actx)
44+
return actx.thaw(ary)
4745

4846

4947
# {{{ kernel transform function
@@ -235,6 +233,27 @@ def transform_loopy_program(self, t_unit):
235233
# {{{ pytato pyopencl array context subclass
236234

237235
class PytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContextBase):
236+
def transform_dag(self, dag):
237+
dag = super().transform_dag(dag)
238+
239+
# {{{ /!\ Remove tags from NamedArrays
240+
# See <https://www.github.com/inducer/pytato/issues/195>
241+
242+
import pytato as pt
243+
244+
def untag_loopy_call_results(expr):
245+
if isinstance(expr, pt.NamedArray):
246+
return expr.copy(tags=frozenset(),
247+
axes=(pt.Axis(frozenset()),)*expr.ndim)
248+
else:
249+
return expr
250+
251+
dag = pt.transform.map_and_copy(dag, untag_loopy_call_results)
252+
253+
# }}}
254+
255+
return dag
256+
238257
def transform_loopy_program(self, t_unit):
239258
# FIXME: Do not parallelize for now.
240259
return t_unit

meshmode/discretization/__init__.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
import numpy as np
3737

3838
import loopy as lp
39-
from arraycontext import ArrayContext, make_loopy_program
39+
from arraycontext import ArrayContext, make_loopy_program, tag_axes
4040
from pytools import memoize_in, memoize_method, keyed_memoize_in
4141
from pytools.obj_array import make_obj_array
4242
from meshmode.transform_metadata import (
43-
ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag)
43+
ConcurrentElementInameTag, ConcurrentDOFInameTag,
44+
FirstAxisIsElementsTag, DiscretizationElementAxisTag,
45+
DiscretizationDOFAxisTag)
4446

4547
# underscored because it shouldn't be imported from here.
4648
from meshmode.dof_array import DOFArray as _DOFArray
@@ -542,9 +544,14 @@ def _new_array(self, actx, creation_func, dtype=None):
542544
else:
543545
dtype = np.dtype(dtype)
544546

545-
return _DOFArray(actx, tuple(
546-
creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype)
547-
for grp in self.groups))
547+
return tag_axes(actx, {
548+
0: DiscretizationElementAxisTag(),
549+
1: DiscretizationDOFAxisTag()},
550+
_DOFArray(actx,
551+
tuple(creation_func(shape=(grp.nelements,
552+
grp.nunit_dofs),
553+
dtype=dtype)
554+
for grp in self.groups)))
548555

549556
def empty(self, actx: ArrayContext,
550557
dtype: Optional[np.dtype] = None) -> _DOFArray:
@@ -642,21 +649,32 @@ def nodes(self, cached: bool = True) -> np.ndarray:
642649
raise ElementGroupTypeError("Element groups must be nodal.")
643650

644651
def resample_mesh_nodes(grp, iaxis):
652+
name_hint = f"nodes{iaxis}_{self.ambient_dim}d"
645653
# TODO: would be nice to have the mesh use an array context already
646-
nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis])
654+
nodes = tag_axes(actx,
655+
{0: DiscretizationElementAxisTag(),
656+
1: DiscretizationDOFAxisTag()},
657+
actx.from_numpy(grp.mesh_el_group.nodes[iaxis]))
647658

648659
grp_unit_nodes = grp.unit_nodes.reshape(-1)
649660
meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1)
650661

662+
from arraycontext.metadata import NameHint
663+
651664
tol = 10 * np.finfo(grp_unit_nodes.dtype).eps
652665
if (grp_unit_nodes.shape == meg_unit_nodes.shape
653666
and np.linalg.norm(grp_unit_nodes - meg_unit_nodes) < tol):
654-
return nodes
667+
return actx.tag(NameHint(name_hint), nodes)
655668

656669
return actx.einsum("ij,ej->ei",
657-
actx.from_numpy(grp.from_mesh_interp_matrix()),
670+
actx.tag_axis(
671+
0,
672+
DiscretizationDOFAxisTag(),
673+
actx.from_numpy(grp.from_mesh_interp_matrix())),
658674
nodes,
659-
tagged=(FirstAxisIsElementsTag(),))
675+
tagged=(
676+
FirstAxisIsElementsTag(),
677+
NameHint(name_hint)))
660678

661679
result = make_obj_array([
662680
_DOFArray(None, tuple([
@@ -714,7 +732,9 @@ def get_mat(grp, gref_axes):
714732

715733
return _DOFArray(actx, tuple(
716734
actx.einsum("ij,ej->ei",
717-
get_mat(grp, ref_axes),
735+
actx.tag_axis(0,
736+
DiscretizationDOFAxisTag(),
737+
get_mat(grp, ref_axes)),
718738
vec[igrp],
719739
tagged=(FirstAxisIsElementsTag(),))
720740
for igrp, grp in enumerate(discr.groups)))

0 commit comments

Comments
 (0)