Skip to content

Commit a90e8ba

Browse files
alexfiklinducer
authored andcommitted
do not force host transfers when computing norms
1 parent 133e8fa commit a90e8ba

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

arraycontext/fake_numpy.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ def _get_fake_numpy_linalg_namespace(self):
147147

148148
def __getattr__(self, name):
149149
def loopy_implemented_elwise_func(*args):
150+
from numbers import Number
151+
if all(isinstance(ary, Number) for ary in args):
152+
return getattr(np, name)(*args)
153+
150154
actx = self._array_context
151155
prg = _get_scalar_func_loopy_program(actx,
152156
c_name, nargs=len(args), naxes=len(args[0].shape))
@@ -221,6 +225,26 @@ def _scalar_list_norm(ary, ord):
221225
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
222226

223227

228+
def _reduce_norm(actx, arys, ord):
229+
from numbers import Number
230+
from functools import reduce
231+
232+
if ord is None:
233+
ord = 2
234+
235+
# NOTE: these are ordered by an expected usage frequency
236+
if ord == 2:
237+
return actx.np.sqrt(sum(subary*subary for subary in arys))
238+
elif ord == np.inf:
239+
return reduce(actx.np.maximum, arys)
240+
elif ord == -np.inf:
241+
return reduce(actx.np.minimum, arys)
242+
elif isinstance(ord, Number) and ord > 0:
243+
return sum(subary**ord for subary in arys)**(1/ord)
244+
else:
245+
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
246+
247+
224248
class BaseFakeNumpyLinalgNamespace:
225249
def __init__(self, array_context):
226250
self._array_context = array_context
@@ -250,7 +274,7 @@ def norm(self, ary, ord=None):
250274
return flat_norm(ary, ord=ord)
251275

252276
if is_array_container(ary):
253-
return _scalar_list_norm([
277+
return _reduce_norm(actx, [
254278
self.norm(subary, ord=ord)
255279
for _, subary in serialize_container(ary)
256280
], ord=ord)
@@ -262,14 +286,16 @@ def norm(self, ary, ord=None):
262286
raise NotImplementedError("only vector norms are implemented")
263287

264288
if ary.size == 0:
265-
return 0
289+
return ary.dtype.type(0)
266290

291+
if ord == 2:
292+
return actx.np.sqrt(actx.np.sum(abs(ary)**2))
267293
if ord == np.inf:
268-
return self._array_context.np.max(abs(ary))
294+
return actx.np.max(abs(ary))
269295
elif ord == -np.inf:
270-
return self._array_context.np.min(abs(ary))
296+
return actx.np.min(abs(ary))
271297
elif isinstance(ord, Number) and ord > 0:
272-
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
298+
return actx.np.sum(abs(ary)**ord)**(1/ord)
273299
else:
274300
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
275301
# }}}

0 commit comments

Comments
 (0)