Skip to content

Commit 90091a2

Browse files
committed
tests pick list fusion
1 parent ca19c54 commit 90091a2

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

test/test_pytato_transforms.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import numpy as np # noqa: F401
2+
import pyopencl as cl
3+
from typing import Union
4+
from meshmode.mesh import BTAG_ALL
5+
from meshmode.mesh.generation import generate_regular_rect_mesh
6+
from arraycontext.metadata import NameHint
7+
from meshmode.array_context import (PytatoPyOpenCLArrayContext,
8+
PyOpenCLArrayContext)
9+
from pytato.transform import CombineMapper
10+
from pytato.array import (Placeholder, DataWrapper, SizeParam, IndexBase,
11+
Array, DictOfNamedArrays)
12+
from meshmode.discretization.connection import (FACE_RESTR_INTERIOR,
13+
FACE_RESTR_ALL)
14+
from pytools.obj_array import make_obj_array
15+
from pyopencl.tools import ( # noqa
16+
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
17+
import grudge
18+
import grudge.op as op
19+
20+
21+
# {{{ utilities for test_push_indirections_*
22+
23+
class _IndexeeArraysMaterializedChecker(CombineMapper[bool]):
24+
def combine(self, *args: bool) -> bool:
25+
return all(args)
26+
27+
def map_placeholder(self, expr: Placeholder) -> bool:
28+
return True
29+
30+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
31+
return True
32+
33+
def map_size_param(self, expr: SizeParam) -> bool:
34+
return True
35+
36+
def _map_index_base(self, expr: IndexBase) -> bool:
37+
from grudge.pytato_transforms.pytato_indirection_transforms import (
38+
_is_materialized)
39+
return self.combine(
40+
_is_materialized(expr.array) or isinstance(expr.array, IndexBase),
41+
self.rec(expr.array)
42+
)
43+
44+
45+
def are_all_indexees_materialized_nodes(
46+
expr: Union[Array, DictOfNamedArrays]) -> bool:
47+
"""
48+
Returns *True* only if all indexee arrays are either materialized nodes,
49+
OR, other indexing nodes that have materialized indexees.
50+
"""
51+
return _IndexeeArraysMaterializedChecker()(expr)
52+
53+
54+
class _IndexerArrayDatawrapperChecker(CombineMapper[bool]):
55+
def combine(self, *args: bool) -> bool:
56+
return all(args)
57+
58+
def map_placeholder(self, expr: Placeholder) -> bool:
59+
return True
60+
61+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
62+
return True
63+
64+
def map_size_param(self, expr: SizeParam) -> bool:
65+
return True
66+
67+
def _map_index_base(self, expr: IndexBase) -> bool:
68+
return self.combine(
69+
*[isinstance(idx, DataWrapper)
70+
for idx in expr.indices
71+
if isinstance(idx, Array)],
72+
super()._map_index_base(expr),
73+
)
74+
75+
76+
def are_all_indexer_arrays_datawrappers(
77+
expr: Union[Array, DictOfNamedArrays]) -> bool:
78+
"""
79+
Returns *True* only if all indexer arrays are instances of
80+
:class:`~pytato.array.DataWrapper`.
81+
"""
82+
return _IndexerArrayDatawrapperChecker()(expr)
83+
84+
# }}}
85+
86+
87+
def _evaluate_dict_of_named_arrays(actx, dict_of_named_arrays):
88+
container = make_obj_array([dict_of_named_arrays._data[name]
89+
for name in sorted(dict_of_named_arrays.keys())])
90+
91+
evaluated_container = actx.thaw(actx.freeze(container))
92+
93+
return {name: evaluated_container[i]
94+
for i, name in enumerate(sorted(dict_of_named_arrays.keys()))}
95+
96+
97+
class FluxOptimizerActx(PytatoPyOpenCLArrayContext):
98+
def __init__(self, *args, **kwargs):
99+
super().__init__(*args, **kwargs)
100+
self.check_completed = False
101+
102+
def transform_dag(self, dag):
103+
from grudge.pytato_transforms.pytato_indirection_transforms import (
104+
fuse_dof_pick_lists, fold_constant_indirections)
105+
from pytato.tags import PrefixNamed
106+
107+
if (
108+
len(dag) == 1
109+
and PrefixNamed("flux_container") in list(dag._data.values())[0].tags
110+
):
111+
assert not are_all_indexer_arrays_datawrappers(dag)
112+
self.check_completed = True
113+
114+
dag = fuse_dof_pick_lists(dag)
115+
dag = fold_constant_indirections(
116+
dag, lambda x: _evaluate_dict_of_named_arrays(self, x))
117+
118+
if (
119+
len(dag) == 1
120+
and PrefixNamed("flux_container") in list(dag._data.values())[0].tags
121+
):
122+
assert are_all_indexer_arrays_datawrappers(dag)
123+
self.check_completed = True
124+
125+
return dag
126+
127+
128+
# {{{ test_resampling_indirections_are_fused_0
129+
130+
def _compute_flux_0(dcoll, actx, u):
131+
u_interior_tpair, = op.interior_trace_pairs(dcoll, u)
132+
flux_on_interior_faces = u_interior_tpair.avg
133+
flux_on_all_faces = op.project(
134+
dcoll, FACE_RESTR_INTERIOR, FACE_RESTR_ALL, flux_on_interior_faces)
135+
136+
flux_on_all_faces = actx.tag(NameHint("flux_container"), flux_on_all_faces)
137+
return flux_on_all_faces
138+
139+
140+
def test_resampling_indirections_are_fused_0(ctx_factory):
141+
cl_ctx = ctx_factory()
142+
cq = cl.CommandQueue(cl_ctx)
143+
144+
ref_actx = PyOpenCLArrayContext(cq)
145+
actx = FluxOptimizerActx(cq)
146+
147+
dim = 3
148+
nel_1d = 4
149+
mesh = generate_regular_rect_mesh(
150+
a=(-0.5,)*dim,
151+
b=(0.5,)*dim,
152+
nelements_per_axis=(nel_1d,)*dim,
153+
boundary_tag_to_face={"bdry": ["-x", "+x",
154+
"-y", "+y",
155+
"-z", "+z"]}
156+
)
157+
dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2)
158+
159+
x, _, _ = dcoll.nodes()
160+
compiled_flux_0 = actx.compile(lambda ary: _compute_flux_0(dcoll, actx, ary))
161+
162+
ref_output = ref_actx.to_numpy(
163+
_compute_flux_0(dcoll, ref_actx, ref_actx.thaw(x)))
164+
output = actx.to_numpy(
165+
compiled_flux_0(actx.thaw(x)))
166+
167+
np.testing.assert_allclose(ref_output[0], output[0])
168+
assert actx.check_completed
169+
170+
# }}}
171+
172+
173+
# {{{ test_resampling_indirections_are_fused_1
174+
175+
def _compute_flux_1(dcoll, actx, u):
176+
u_interior_tpair, = op.interior_trace_pairs(dcoll, u)
177+
flux_on_interior_faces = u_interior_tpair.avg
178+
flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u)
179+
flux_on_all_faces = (
180+
op.project(dcoll,
181+
FACE_RESTR_INTERIOR,
182+
FACE_RESTR_ALL,
183+
flux_on_interior_faces)
184+
+ op.project(dcoll, BTAG_ALL, FACE_RESTR_ALL, flux_on_bdry)
185+
)
186+
187+
result = op.inverse_mass(dcoll, op.face_mass(dcoll, flux_on_all_faces))
188+
189+
result = actx.tag(NameHint("flux_container"), result)
190+
return result
191+
192+
193+
def test_resampling_indirections_are_fused_1(ctx_factory):
194+
cl_ctx = ctx_factory()
195+
cq = cl.CommandQueue(cl_ctx)
196+
197+
ref_actx = PyOpenCLArrayContext(cq)
198+
actx = FluxOptimizerActx(cq)
199+
200+
dim = 3
201+
nel_1d = 4
202+
mesh = generate_regular_rect_mesh(
203+
a=(-0.5,)*dim,
204+
b=(0.5,)*dim,
205+
nelements_per_axis=(nel_1d,)*dim,
206+
boundary_tag_to_face={"bdry": ["-x", "+x",
207+
"-y", "+y",
208+
"-z", "+z"]}
209+
)
210+
dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2)
211+
212+
x, _, _ = dcoll.nodes()
213+
compiled_flux_1 = actx.compile(lambda ary: _compute_flux_1(dcoll, actx, ary))
214+
215+
ref_output = ref_actx.to_numpy(
216+
_compute_flux_1(dcoll, ref_actx, ref_actx.thaw(x)))
217+
output = actx.to_numpy(
218+
compiled_flux_1(actx.thaw(x)))
219+
220+
np.testing.assert_allclose(ref_output[0], output[0])
221+
assert actx.check_completed
222+
223+
# }}}
224+
225+
# vim: fdm=marker

0 commit comments

Comments
 (0)