2424
2525
2626import numpy as np
27- from arraycontext .container import is_array_container , serialize_container
27+ from arraycontext .container import is_array_container_type , serialize_container
2828from arraycontext .container .traversal import (
2929 rec_map_array_container , multimapped_over_array_containers )
3030from pytools import memoize_in
@@ -182,7 +182,7 @@ def _new_like(self, ary, alloc_like):
182182 # e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
183183 # FIXME: what about object arrays nested in an ArrayContainer?
184184 raise NotImplementedError ("operation not implemented for object arrays" )
185- elif is_array_container (ary ):
185+ elif is_array_container_type (ary . __class__ ):
186186 return rec_map_array_container (alloc_like , ary )
187187 elif isinstance (ary , Number ):
188188 # NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
@@ -225,6 +225,26 @@ def _scalar_list_norm(ary, ord):
225225 raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
226226
227227
228+ def _reduce_norm (actx , arys , ord ):
229+ from numbers import Number
230+ from functools import reduce
231+
232+ if ord is None :
233+ ord = 2
234+
235+ # NOTE: these are ordered by an expected usage frequency
236+ if ord == 2 :
237+ return actx .np .sqrt (sum (subary * subary for subary in arys ))
238+ elif ord == np .inf :
239+ return reduce (actx .np .maximum , arys )
240+ elif ord == - np .inf :
241+ return reduce (actx .np .minimum , arys )
242+ elif isinstance (ord , Number ) and ord > 0 :
243+ return sum (subary ** ord for subary in arys )** (1 / ord )
244+ else :
245+ raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
246+
247+
228248class BaseFakeNumpyLinalgNamespace :
229249 def __init__ (self , array_context ):
230250 self ._array_context = array_context
@@ -253,8 +273,8 @@ def norm(self, ary, ord=None):
253273
254274 return flat_norm (ary , ord = ord )
255275
256- if is_array_container (ary ):
257- return _scalar_list_norm ( [
276+ if is_array_container_type (ary . __class__ ):
277+ return _reduce_norm ( actx , [
258278 self .norm (subary , ord = ord )
259279 for _ , subary in serialize_container (ary )
260280 ], ord = ord )
@@ -266,14 +286,16 @@ def norm(self, ary, ord=None):
266286 raise NotImplementedError ("only vector norms are implemented" )
267287
268288 if ary .size == 0 :
269- return 0
289+ return ary . dtype . type ( 0 )
270290
291+ if ord == 2 :
292+ return actx .np .sqrt (actx .np .sum (abs (ary )** 2 ))
271293 if ord == np .inf :
272- return self . _array_context .np .max (abs (ary ))
294+ return actx .np .max (abs (ary ))
273295 elif ord == - np .inf :
274- return self . _array_context .np .min (abs (ary ))
296+ return actx .np .min (abs (ary ))
275297 elif isinstance (ord , Number ) and ord > 0 :
276- return self . _array_context .np .sum (abs (ary )** ord )** (1 / ord )
298+ return actx .np .sum (abs (ary )** ord )** (1 / ord )
277299 else :
278300 raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
279301# }}}
0 commit comments