Skip to content

Commit 57fef5a

Browse files
committed
extracted the non-numpy specific bits of broadcastify into a public API function: wrap_elementwise_array_func
1 parent 3de198d commit 57fef5a

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

doc/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ Miscellaneous
673673
diag
674674
eye
675675
ipfp
676+
wrap_elementwise_array_func
676677

677678
.. _api-session:
678679

doc/source/changes/version_0_30.rst.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ New features
126126
* added :py:obj:`Session.update()` method to add and modify items from an existing session by passing
127127
either another session or a dict-like object or an iterable object with (key, value) pairs (closes :issue:`754`).
128128

129+
* implemented :py:obj:`wrap_elementwise_array_func()` function to make a function defined in another library work with
130+
LArray arguments instead of with numpy arrays.
131+
129132

130133
Miscellaneous improvements
131134
^^^^^^^^^^^^^^^^^^^^^^^^^^

larray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from larray.core.session import Session, local_arrays, global_arrays, arrays
1313
from larray.core.constants import nan, inf, pi, e, euler_gamma
1414
from larray.core.metadata import Metadata
15-
from larray.core.ufuncs import maximum, minimum, where
15+
from larray.core.ufuncs import wrap_elementwise_array_func, maximum, minimum, where
1616
from larray.core.npufuncs import (sin, cos, tan, arcsin, arccos, arctan, hypot, arctan2, degrees,
1717
radians, unwrap, sinh, cosh, tanh, arcsinh, arccosh, arctanh,
1818
angle, real, imag, conj,
@@ -62,6 +62,7 @@
6262
# metadata
6363
'Metadata',
6464
# ufuncs
65+
'wrap_elementwise_array_func',
6566
'maximum', 'minimum', 'where',
6667
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'hypot', 'arctan2', 'degrees', 'radians',
6768
'unwrap', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',

larray/core/ufuncs.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,49 @@
66
from larray.core.array import LArray, make_args_broadcastable
77

88

9-
def broadcastify(func):
10-
# intentionally not using functools.wraps, because it does not work for wrapping a function from another module
9+
def wrap_elementwise_array_func(func):
10+
r"""
11+
Wrap a function using numpy arrays to work with LArray arrays instead.
12+
13+
Parameters
14+
----------
15+
func : function
16+
A function taking numpy arrays as arguments and returning numpy arrays of the same shape. If the function
17+
takes several arguments, this wrapping code assumes the result will have the combination of all axes present.
18+
In numpy talk, arguments will be broadcasted to each other.
19+
20+
Returns
21+
-------
22+
function
23+
A function taking LArray arguments and returning LArrays.
24+
25+
Examples
26+
--------
27+
For example, if we want to apply the Hodrick-Prescott filter from statsmodels we can use this:
28+
29+
>>> from statsmodels.tsa.filters.hp_filter import hpfilter # doctest: +SKIP
30+
>>> hpfilter = wrap_elementwise_array_func(hpfilter) # doctest: +SKIP
31+
32+
hpfilter is now a function taking a one dimensional LArray as input and returning a one dimensional LArray as output
33+
34+
Now let us suppose we have a ND array such as:
35+
36+
>>> from larray.random import normal
37+
>>> arr = normal(axes="sex=M,F;year=2016..2018") # doctest: +SKIP
38+
>>> arr # doctest: +SKIP
39+
sex\year 2016 2017 2018
40+
M -1.15 0.56 -1.06
41+
F -0.48 -0.39 -0.98
42+
43+
We can apply an Hodrick-Prescott filter to it by using:
44+
45+
>>> # 6.25 is the recommended smoothing value for annual data
46+
>>> cycle, trend = arr.apply(hpfilter, 6.25, axes="year") # doctest: +SKIP
47+
>>> trend # doctest: +SKIP
48+
sex\year 2016 2017 2018
49+
M -0.61 -0.52 -0.52
50+
F -0.37 -0.61 -0.87
51+
"""
1152
def wrapper(*args, **kwargs):
1253
raw_bcast_args, raw_bcast_kwargs, res_axes = make_args_broadcastable(args, kwargs)
1354

@@ -25,11 +66,21 @@ def wrapper(*args, **kwargs):
2566
# it does this because numpy calls __array_wrap__ on the argument with the highest __array_priority__
2667
res_data = func(*raw_bcast_args, **raw_bcast_kwargs)
2768
if res_axes:
28-
return LArray(res_data, res_axes)
69+
if isinstance(res_data, tuple):
70+
return tuple(LArray(res_arr, res_axes) for res_arr in res_data)
71+
else:
72+
return LArray(res_data, res_axes)
2973
else:
3074
return res_data
31-
# copy meaningful attributes (numpy ufuncs do not have __annotations__ nor __qualname__)
75+
# copy function name. We are intentionally not using functools.wraps, because it does not work for wrapping a
76+
# function from another module
3277
wrapper.__name__ = func.__name__
78+
return wrapper
79+
80+
81+
# TODO: rename to wrap_numpy_func
82+
def broadcastify(func):
83+
wrapper = wrap_elementwise_array_func(func)
3384
# update documentation by inserting a warning message after the short description of the numpy function
3485
# (otherwise the description of ufuncs given in the corresponding API 'autosummary' tables will always
3586
# start with 'larray specific variant of ...' without giving a meaningful description of what does the ufunc)

0 commit comments

Comments
 (0)