3
3
4
4
import numpy as np
5
5
6
- from larray .core .array import LArray , raw_broadcastable
6
+ from larray .core .array import LArray , make_args_broadcastable
7
7
8
8
9
9
def broadcastify (func ):
10
10
# intentionally not using functools.wraps, because it does not work for wrapping a function from another module
11
11
def wrapper (* args , ** kwargs ):
12
- # TODO: normalize args/kwargs like in LIAM2 so that we can also broadcast if args are given via kwargs
13
- # (eg out=)
14
- raw_args , combined_axes = raw_broadcastable (args )
12
+ raw_bcast_args , raw_bcast_kwargs , res_axes = make_args_broadcastable (args , kwargs )
15
13
16
14
# We pass only raw numpy arrays to the ufuncs even though numpy is normally meant to handle those cases itself
17
15
# via __array_wrap__
@@ -25,9 +23,9 @@ def wrapper(*args, **kwargs):
25
23
# It fails on "np.minimum(ndarray, LArray)" because it calls __array_wrap__(high, result) which cannot work if
26
24
# there was broadcasting involved (high has potentially less labels than result).
27
25
# it does this because numpy calls __array_wrap__ on the argument with the highest __array_priority__
28
- res_data = func (* raw_args , ** kwargs )
29
- if combined_axes :
30
- return LArray (res_data , combined_axes )
26
+ res_data = func (* raw_bcast_args , ** raw_bcast_kwargs )
27
+ if res_axes :
28
+ return LArray (res_data , res_axes )
31
29
else :
32
30
return res_data
33
31
# copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)
0 commit comments