@@ -147,6 +147,10 @@ def _get_fake_numpy_linalg_namespace(self):
147147
148148 def __getattr__ (self , name ):
149149 def loopy_implemented_elwise_func (* args ):
150+ from numbers import Number
151+ if all (isinstance (ary , Number ) for ary in args ):
152+ return getattr (np , name )(* args )
153+
150154 actx = self ._array_context
151155 prg = _get_scalar_func_loopy_program (actx ,
152156 c_name , nargs = len (args ), naxes = len (args [0 ].shape ))
@@ -221,6 +225,26 @@ def _scalar_list_norm(ary, ord):
221225 raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
222226
223227
228+ def _reduce_norm (actx , arys , ord ):
229+ from numbers import Number
230+ from functools import reduce
231+
232+ if ord is None :
233+ ord = 2
234+
235+ # NOTE: these are ordered by an expected usage frequency
236+ if ord == 2 :
237+ return actx .np .sqrt (sum (subary * subary for subary in arys ))
238+ elif ord == np .inf :
239+ return reduce (actx .np .maximum , arys )
240+ elif ord == - np .inf :
241+ return reduce (actx .np .minimum , arys )
242+ elif isinstance (ord , Number ) and ord > 0 :
243+ return sum (subary ** ord for subary in arys )** (1 / ord )
244+ else :
245+ raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
246+
247+
224248class BaseFakeNumpyLinalgNamespace :
225249 def __init__ (self , array_context ):
226250 self ._array_context = array_context
@@ -250,7 +274,7 @@ def norm(self, ary, ord=None):
250274 return flat_norm (ary , ord = ord )
251275
252276 if is_array_container (ary ):
253- return _scalar_list_norm ( [
277+ return _reduce_norm ( actx , [
254278 self .norm (subary , ord = ord )
255279 for _ , subary in serialize_container (ary )
256280 ], ord = ord )
@@ -262,14 +286,16 @@ def norm(self, ary, ord=None):
262286 raise NotImplementedError ("only vector norms are implemented" )
263287
264288 if ary .size == 0 :
265- return 0
289+ return ary . dtype . type ( 0 )
266290
291+ if ord == 2 :
292+ return actx .np .sqrt (actx .np .sum (abs (ary )** 2 ))
267293 if ord == np .inf :
268- return self . _array_context .np .max (abs (ary ))
294+ return actx .np .max (abs (ary ))
269295 elif ord == - np .inf :
270- return self . _array_context .np .min (abs (ary ))
296+ return actx .np .min (abs (ary ))
271297 elif isinstance (ord , Number ) and ord > 0 :
272- return self . _array_context .np .sum (abs (ary )** ord )** (1 / ord )
298+ return actx .np .sum (abs (ary )** ord )** (1 / ord )
273299 else :
274300 raise NotImplementedError (f"unsupported value of 'ord': { ord } " )
275301# }}}
0 commit comments