Skip to content

Commit 40b467a

Browse files
committed
broadcast ufuncs kwargs
1 parent c3b32c0 commit 40b467a

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

larray/core/array.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8541,6 +8541,18 @@ def raw_broadcastable(values, min_axes=None):
85418541
return raw, res_axes
85428542

85438543

8544+
def make_args_broadcastable(args, kwargs=None, min_axes=None):
8545+
"""
8546+
Make args and kwargs (NumPy) broadcastable between them.
8547+
"""
8548+
values = (args + tuple(kwargs.values())) if kwargs is not None else args
8549+
first_kw = len(args)
8550+
raw_bcast_values, res_axes = raw_broadcastable(values, min_axes=min_axes)
8551+
raw_bcast_args = raw_bcast_values[:first_kw]
8552+
raw_bcast_kwargs = dict(zip(kwargs.keys(), raw_bcast_values[first_kw:]))
8553+
return raw_bcast_args, raw_bcast_kwargs, res_axes
8554+
8555+
85448556
_default_float_error_handler = float_error_handler_factory(3)
85458557

85468558

larray/core/ufuncs.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33

44
import numpy as np
55

6-
from larray.core.array import LArray, raw_broadcastable
6+
from larray.core.array import LArray, make_args_broadcastable
77

88

99
def broadcastify(func):
1010
# intentionally not using functools.wraps, because it does not work for wrapping a function from another module
1111
def wrapper(*args, **kwargs):
12-
# TODO: normalize args/kwargs like in LIAM2 so that we can also broadcast if args are given via kwargs
13-
# (eg out=)
14-
raw_args, combined_axes = raw_broadcastable(args)
12+
raw_bcast_args, raw_bcast_kwargs, res_axes = make_args_broadcastable(args, kwargs)
1513

1614
# We pass only raw numpy arrays to the ufuncs even though numpy is normally meant to handle those cases itself
1715
# via __array_wrap__
@@ -25,9 +23,9 @@ def wrapper(*args, **kwargs):
2523
# It fails on "np.minimum(ndarray, LArray)" because it calls __array_wrap__(high, result) which cannot work if
2624
# there was broadcasting involved (high has potentially less labels than result).
2725
# it does this because numpy calls __array_wrap__ on the argument with the highest __array_priority__
28-
res_data = func(*raw_args, **kwargs)
29-
if combined_axes:
30-
return LArray(res_data, combined_axes)
26+
res_data = func(*raw_bcast_args, **raw_bcast_kwargs)
27+
if res_axes:
28+
return LArray(res_data, res_axes)
3129
else:
3230
return res_data
3331
# copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)

0 commit comments

Comments
 (0)