Skip to content

Commit e22b022

Browse files
committed
implemented LArray.apply
1 parent bf270d4 commit e22b022

File tree

3 files changed

+165
-1
lines changed

3 files changed

+165
-1
lines changed

doc/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ Modifying/Selecting
319319
LArray.drop
320320
LArray.ignore_labels
321321
LArray.filter
322+
LArray.apply
322323

323324
.. _la_axes_labels:
324325

doc/source/changes/version_0_30.rst.inc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.. py:currentmodule:: larray
1+
.. py:currentmodule:: larray
22

33

44
Syntax changes
@@ -129,6 +129,10 @@ New features
129129
* implemented :py:obj:`LArray.unique()` method to compute unique values (or sub-arrays) for an array,
130130
optionally along axes.
131131

132+
* implemented :py:obj:`LArray.apply()` method to apply a python function to all values of an array or to all sub-arrays
133+
along some axes of an array and return the result. This is an extremely versatile method as it can be used both with
134+
aggregating functions or element-wise functions.
135+
132136
* implemented :py:obj:`Axis.apply()` method to transform an axis labels by a function and return a new Axis.
133137

134138
>>> sex = Axis('sex=MALE,FEMALE')

larray/core/array.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7819,6 +7819,165 @@ def reverse(self, axes=None):
78197819
reversed_axes = tuple(axis[::-1] for axis in axes)
78207820
return self[reversed_axes]
78217821

7822+
# TODO: add excluded argument (to pass to vectorize but we must also compute res_axes / broadcasted arguments
7823+
# accordingly and handle it when axes is not None)
7824+
# excluded : set, optional
7825+
# Set of strings or integers representing the positional or keyword arguments for which the function
7826+
# will not be vectorized. These will be passed directly to the `transform` function unmodified.
7827+
def apply(self, transform, *args, **kwargs):
7828+
r"""
7829+
Apply a transformation function to array elements.
7830+
7831+
Parameters
7832+
----------
7833+
transform : function
7834+
Function to apply. This function will be called in turn with each element of the array as the first
7835+
argument and must return an LArray, scalar or tuple.
7836+
If returning arrays the axes of those arrays must be the same for all calls to the function.
7837+
*args
7838+
Extra arguments to pass to the function.
7839+
by : str, int or Axis or tuple/list/AxisCollection of the them, optional
7840+
Axis or axes along which to iterate. The function will thus be called with arrays having all axes not
7841+
mentioned. Defaults to None (all axes). Mutually exclusive with the `axes` argument.
7842+
axes : str, int or Axis or tuple/list/AxisCollection of the them, optional
7843+
Axis or axes the arrays passed to the function will have. Defaults to None (the function is given
7844+
scalars). Mutually exclusive with the `by` argument.
7845+
dtype : type or list of types, optional
7846+
Output(s) data type(s). Defaults to None (inspect all output values to infer it automatically).
7847+
ascending : bool, optional
7848+
Whether or not to iterate the axes in ascending order (from start to end). Defaults to True.
7849+
**kwargs
7850+
Extra keyword arguments are passed to the function (as keyword arguments).
7851+
7852+
Returns
7853+
-------
7854+
LArray or scalar, or tuple of them
7855+
Axes will be the union of those in axis and those of values returned by the function.
7856+
7857+
Examples
7858+
--------
7859+
First let us define a test array
7860+
7861+
>>> arr = LArray([[0, 2, 1],
7862+
... [3, 1, 5]], 'a=a0,a1;b=b0..b2')
7863+
>>> arr
7864+
a\b b0 b1 b2
7865+
a0 0 2 1
7866+
a1 3 1 5
7867+
7868+
Here is a simple function we would like to apply to each element of the array.
7869+
Note that this particular example should rather be written as: arr ** 2
7870+
as it is both more concise and much faster.
7871+
7872+
>>> def square(x):
7873+
... return x ** 2
7874+
>>> arr.apply(square)
7875+
a\b b0 b1 b2
7876+
a0 0 4 1
7877+
a1 9 1 25
7878+
7879+
Functions can also be applied along some axes:
7880+
7881+
>>> # this is equivalent to (but much slower than): arr.sum('a')
7882+
... arr.apply(sum, axes='a')
7883+
b b0 b1 b2
7884+
3 3 6
7885+
>>> # this is equivalent to (but much slower than): arr.sum_by('a')
7886+
... arr.apply(sum, by='a')
7887+
a a0 a1
7888+
3 9
7889+
7890+
Applying the function along some axes will return an array with the
7891+
union of those axes and the axes of the returned values. For example,
7892+
let us define a function which returns the k highest values of an array.
7893+
7894+
>>> def topk(a, k=2):
7895+
... return a.sort_values(ascending=False).ignore_labels().i[:k]
7896+
>>> arr.apply(topk, by='a')
7897+
a\b* 0 1
7898+
a0 2 1
7899+
a1 5 3
7900+
7901+
Other arguments can be passed to the function:
7902+
7903+
>>> arr.apply(topk, 3, by='a')
7904+
a\b* 0 1 2
7905+
a0 2 1 0
7906+
a1 5 3 1
7907+
7908+
or by using keyword arguments:
7909+
7910+
>>> arr.apply(topk, by='a', k=3)
7911+
a\b* 0 1 2
7912+
a0 2 1 0
7913+
a1 5 3 1
7914+
7915+
If the function returns several values (as a tuple), the result will be a tuple of arrays. For example,
7916+
let use define a function which decompose an array in its mean and the difference to that mean :
7917+
7918+
>>> def mean_decompose(a):
7919+
... mean = a.mean()
7920+
... return mean, a - mean
7921+
>>> mean_by_a, diff_to_mean = arr.apply(mean_decompose, by='a')
7922+
>>> mean_by_a
7923+
a a0 a1
7924+
1.0 3.0
7925+
>>> diff_to_mean
7926+
a\b b0 b1 b2
7927+
a0 -1.0 1.0 0.0
7928+
a1 0.0 -2.0 2.0
7929+
"""
7930+
# keyword only arguments
7931+
by = kwargs.pop('by', None)
7932+
axes = kwargs.pop('axes', None)
7933+
dtype = kwargs.pop('dtype', None)
7934+
ascending = kwargs.pop('ascending', True)
7935+
# excluded = kwargs.pop('excluded', None)
7936+
7937+
if axes is not None:
7938+
if by is not None:
7939+
raise ValueError("cannot specify both `by` and `axes` arguments in LArray.apply")
7940+
by = self.axes - axes
7941+
7942+
# XXX: we could go one step further than vectorize and support a array of callables which would be broadcasted
7943+
# with the other arguments. I don't know whether that would actually help because I think it always
7944+
# possible to emulate that with a single callable with an extra argument (eg type) which dispatches to
7945+
# potentially different callables. It might be more practical & efficient though.
7946+
if by is None:
7947+
otypes = [dtype] if isinstance(dtype, type) else dtype
7948+
vfunc = np.vectorize(transform, otypes=otypes)
7949+
# XXX: we should probably handle excluded here
7950+
# raw_bcast_args, raw_bcast_kwargs, res_axes = make_args_broadcastable((self,) + args, kwargs)
7951+
raw_bcast_args, raw_bcast_kwargs, res_axes = ((self,) + args, kwargs, self.axes)
7952+
res_data = vfunc(*raw_bcast_args, **raw_bcast_kwargs)
7953+
if isinstance(res_data, tuple):
7954+
return tuple(LArray(res_arr, res_axes) for res_arr in res_data)
7955+
else:
7956+
return LArray(res_data, res_axes)
7957+
else:
7958+
by = self.axes[by]
7959+
7960+
values = (self,) + args + tuple(kwargs.values())
7961+
first_kw = 1 + len(args)
7962+
kwnames = tuple(kwargs.keys())
7963+
key_values = [(k, transform(*a_and_kwa[:first_kw], **dict(zip(kwnames, a_and_kwa[first_kw:]))))
7964+
for k, a_and_kwa in zip_array_items(values, by, ascending)]
7965+
first_key, first_value = key_values[0]
7966+
if isinstance(first_value, tuple):
7967+
# assume all other values are the same shape
7968+
tuple_length = len(first_value)
7969+
# TODO: compute res_axes (potentially different for each return value) in this case too
7970+
res_arrays = [stack([(key, value[i]) for key, value in key_values], axes=by, dtype=dtype)
7971+
for i in range(tuple_length)]
7972+
# transpose back axis where it was
7973+
return tuple(res_arr.transpose(self.axes & res_arr.axes) for res_arr in res_arrays)
7974+
else:
7975+
res_axes = get_axes(first_value).union(by)
7976+
res_arr = stack(key_values, axes=by, dtype=dtype, res_axes=res_axes)
7977+
7978+
# transpose back axis where it was
7979+
return res_arr.transpose(self.axes & res_arr.axes)
7980+
78227981

78237982
def larray_equal(a1, a2):
78247983
import warnings

0 commit comments

Comments
 (0)