6262from arraycontext import (
6363 make_loopy_program ,
6464 map_array_container ,
65- serialize_container ,
65+ get_container_context_recursively ,
6666 DeviceScalar
6767)
6868from arraycontext .container import ArrayOrContainerT
@@ -94,7 +94,6 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar":
9494 if dd is None :
9595 dd = dof_desc .DD_VOLUME
9696
97- from arraycontext import get_container_context_recursively
9897 actx = get_container_context_recursively (vec )
9998
10099 dd = dof_desc .as_dofdesc (dd )
@@ -128,7 +127,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
128127
129128 # NOTE: Don't move this
130129 from mpi4py import MPI
131- actx = vec . array_context
130+ actx = get_container_context_recursively ( vec )
132131
133132 return actx .from_numpy (
134133 comm .allreduce (actx .to_numpy (nodal_sum_loc (dcoll , dd , vec )), op = MPI .SUM ))
@@ -143,15 +142,13 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
143142 :class:`~arraycontext.container.ArrayContainer` of them.
144143 :returns: a scalar denoting the rank-local nodal sum.
145144 """
146- if not isinstance (vec , DOFArray ):
147- return sum (
148- nodal_sum_loc (dcoll , dd , comp )
149- for _ , comp in serialize_container (vec )
150- )
151-
152- actx = vec .array_context
153-
154- return sum ([actx .np .sum (grp_ary ) for grp_ary in vec ])
145+ actx = get_container_context_recursively (vec )
146+ result = actx .np .sum (vec )
147+ # Fix actx._force_device_scalars == False case
148+ if np .isscalar (result ):
149+ return actx .from_numpy (result )
150+ else :
151+ return result
155152
156153
157154def nodal_min (dcoll : DiscretizationCollection , dd , vec ) -> "DeviceScalar" :
@@ -169,7 +166,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
169166
170167 # NOTE: Don't move this
171168 from mpi4py import MPI
172- actx = vec . array_context
169+ actx = get_container_context_recursively ( vec )
173170
174171 return actx .from_numpy (
175172 comm .allreduce (actx .to_numpy (nodal_min_loc (dcoll , dd , vec )), op = MPI .MIN ))
@@ -185,17 +182,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
185182 :class:`~arraycontext.container.ArrayContainer` of them.
186183 :returns: a scalar denoting the rank-local nodal minimum.
187184 """
188- if not isinstance (vec , DOFArray ):
189- return min (
190- nodal_min_loc (dcoll , dd , comp )
191- for _ , comp in serialize_container (vec )
192- )
193-
194- actx = vec .array_context
195-
196- return reduce (
197- lambda acc , grp_ary : actx .np .minimum (acc , actx .np .min (grp_ary )),
198- vec , actx .from_numpy (np .array (np .inf )))
185+ actx = get_container_context_recursively (vec )
186+ result = actx .np .min (vec )
187+ # Fix actx._force_device_scalars == False case
188+ if np .isscalar (result ):
189+ return actx .from_numpy (result )
190+ else :
191+ return result
199192
200193
201194def nodal_max (dcoll : DiscretizationCollection , dd , vec ) -> "DeviceScalar" :
@@ -213,7 +206,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
213206
214207 # NOTE: Don't move this
215208 from mpi4py import MPI
216- actx = vec . array_context
209+ actx = get_container_context_recursively ( vec )
217210
218211 return actx .from_numpy (
219212 comm .allreduce (actx .to_numpy (nodal_max_loc (dcoll , dd , vec )), op = MPI .MAX ))
@@ -229,17 +222,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
229222 :class:`~arraycontext.container.ArrayContainer`.
230223 :returns: a scalar denoting the rank-local nodal maximum.
231224 """
232- if not isinstance (vec , DOFArray ):
233- return max (
234- nodal_max_loc (dcoll , dd , comp )
235- for _ , comp in serialize_container (vec )
236- )
237-
238- actx = vec .array_context
239-
240- return reduce (
241- lambda acc , grp_ary : actx .np .maximum (acc , actx .np .max (grp_ary )),
242- vec , actx .from_numpy (np .array (- np .inf )))
225+ actx = get_container_context_recursively (vec )
226+ result = actx .np .max (vec )
227+ # Fix actx._force_device_scalars == False case
228+ if np .isscalar (result ):
229+ return actx .from_numpy (result )
230+ else :
231+ return result
243232
244233
245234def integral (dcoll : DiscretizationCollection , dd , vec ) -> "DeviceScalar" :
@@ -253,9 +242,10 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
253242 """
254243 from grudge .op import _apply_mass_operator
255244
245+ actx = get_container_context_recursively (vec )
256246 dd = dof_desc .as_dofdesc (dd )
257247
258- ones = dcoll .discr_from_dd (dd ).zeros (vec . array_context ) + 1.0
248+ ones = dcoll .discr_from_dd (dd ).zeros (actx ) + 1.0
259249 return nodal_sum (
260250 dcoll , dd , vec * _apply_mass_operator (dcoll , dd , dd , ones )
261251 )
@@ -295,7 +285,7 @@ def _apply_elementwise_reduction(
295285 partial (_apply_elementwise_reduction , op_name , dcoll , dd ), vec
296286 )
297287
298- actx = vec . array_context
288+ actx = get_container_context_recursively ( vec )
299289
300290 if actx .supports_nonscalar_broadcasting :
301291 return DOFArray (
@@ -456,11 +446,12 @@ def elementwise_integral(
456446 else :
457447 raise TypeError ("invalid number of arguments" )
458448
449+ actx = get_container_context_recursively (vec )
459450 dd = dof_desc .as_dofdesc (dd )
460451
461452 from grudge .op import _apply_mass_operator
462453
463- ones = dcoll .discr_from_dd (dd ).zeros (vec . array_context ) + 1.0
454+ ones = dcoll .discr_from_dd (dd ).zeros (actx ) + 1.0
464455 return elementwise_sum (
465456 dcoll , dd , vec * _apply_mass_operator (dcoll , dd , dd , ones )
466457 )
0 commit comments