Skip to content

Commit 30d3b77

Browse files
committed
use array context reductions in nodal reductions
1 parent bc521de commit 30d3b77

File tree

1 file changed

+30
-39
lines changed

1 file changed

+30
-39
lines changed

grudge/reductions.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from arraycontext import (
6363
make_loopy_program,
6464
map_array_container,
65-
serialize_container,
65+
get_container_context_recursively,
6666
DeviceScalar
6767
)
6868
from arraycontext.container import ArrayOrContainerT
@@ -94,7 +94,6 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar":
9494
if dd is None:
9595
dd = dof_desc.DD_VOLUME
9696

97-
from arraycontext import get_container_context_recursively
9897
actx = get_container_context_recursively(vec)
9998

10099
dd = dof_desc.as_dofdesc(dd)
@@ -128,7 +127,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
128127

129128
# NOTE: Don't move this
130129
from mpi4py import MPI
131-
actx = vec.array_context
130+
actx = get_container_context_recursively(vec)
132131

133132
return actx.from_numpy(
134133
comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))
@@ -143,15 +142,13 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
143142
:class:`~arraycontext.container.ArrayContainer` of them.
144143
:returns: a scalar denoting the rank-local nodal sum.
145144
"""
146-
if not isinstance(vec, DOFArray):
147-
return sum(
148-
nodal_sum_loc(dcoll, dd, comp)
149-
for _, comp in serialize_container(vec)
150-
)
151-
152-
actx = vec.array_context
153-
154-
return sum([actx.np.sum(grp_ary) for grp_ary in vec])
145+
actx = get_container_context_recursively(vec)
146+
result = actx.np.sum(vec)
147+
# Fix actx._force_device_scalars == False case
148+
if np.isscalar(result):
149+
return actx.from_numpy(result)
150+
else:
151+
return result
155152

156153

157154
def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
@@ -169,7 +166,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
169166

170167
# NOTE: Don't move this
171168
from mpi4py import MPI
172-
actx = vec.array_context
169+
actx = get_container_context_recursively(vec)
173170

174171
return actx.from_numpy(
175172
comm.allreduce(actx.to_numpy(nodal_min_loc(dcoll, dd, vec)), op=MPI.MIN))
@@ -185,17 +182,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
185182
:class:`~arraycontext.container.ArrayContainer` of them.
186183
:returns: a scalar denoting the rank-local nodal minimum.
187184
"""
188-
if not isinstance(vec, DOFArray):
189-
return min(
190-
nodal_min_loc(dcoll, dd, comp)
191-
for _, comp in serialize_container(vec)
192-
)
193-
194-
actx = vec.array_context
195-
196-
return reduce(
197-
lambda acc, grp_ary: actx.np.minimum(acc, actx.np.min(grp_ary)),
198-
vec, actx.from_numpy(np.array(np.inf)))
185+
actx = get_container_context_recursively(vec)
186+
result = actx.np.min(vec)
187+
# Fix actx._force_device_scalars == False case
188+
if np.isscalar(result):
189+
return actx.from_numpy(result)
190+
else:
191+
return result
199192

200193

201194
def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
@@ -213,7 +206,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
213206

214207
# NOTE: Don't move this
215208
from mpi4py import MPI
216-
actx = vec.array_context
209+
actx = get_container_context_recursively(vec)
217210

218211
return actx.from_numpy(
219212
comm.allreduce(actx.to_numpy(nodal_max_loc(dcoll, dd, vec)), op=MPI.MAX))
@@ -229,17 +222,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
229222
:class:`~arraycontext.container.ArrayContainer`.
230223
:returns: a scalar denoting the rank-local nodal maximum.
231224
"""
232-
if not isinstance(vec, DOFArray):
233-
return max(
234-
nodal_max_loc(dcoll, dd, comp)
235-
for _, comp in serialize_container(vec)
236-
)
237-
238-
actx = vec.array_context
239-
240-
return reduce(
241-
lambda acc, grp_ary: actx.np.maximum(acc, actx.np.max(grp_ary)),
242-
vec, actx.from_numpy(np.array(-np.inf)))
225+
actx = get_container_context_recursively(vec)
226+
result = actx.np.max(vec)
227+
# Fix actx._force_device_scalars == False case
228+
if np.isscalar(result):
229+
return actx.from_numpy(result)
230+
else:
231+
return result
243232

244233

245234
def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
@@ -253,9 +242,10 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
253242
"""
254243
from grudge.op import _apply_mass_operator
255244

245+
actx = get_container_context_recursively(vec)
256246
dd = dof_desc.as_dofdesc(dd)
257247

258-
ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0
248+
ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0
259249
return nodal_sum(
260250
dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones)
261251
)
@@ -295,7 +285,7 @@ def _apply_elementwise_reduction(
295285
partial(_apply_elementwise_reduction, op_name, dcoll, dd), vec
296286
)
297287

298-
actx = vec.array_context
288+
actx = get_container_context_recursively(vec)
299289

300290
if actx.supports_nonscalar_broadcasting:
301291
return DOFArray(
@@ -456,11 +446,12 @@ def elementwise_integral(
456446
else:
457447
raise TypeError("invalid number of arguments")
458448

449+
actx = get_container_context_recursively(vec)
459450
dd = dof_desc.as_dofdesc(dd)
460451

461452
from grudge.op import _apply_mass_operator
462453

463-
ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0
454+
ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0
464455
return elementwise_sum(
465456
dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones)
466457
)

0 commit comments

Comments
 (0)