Skip to content

Commit b246d34

Browse files
committed
improved stack
* generalized it to more than one dimension (closes #56) - works for both stack([(ndkey, value), ...], axis=axes) and stack({ndkey: value}, several_axes) - deprecated axis argument in favor of axes * changed the exception when using **kwargs without an axis on Python < 3.6 to a warning * allowed using a dict without axis (closes #581, closes #755) This will produce a warning on Python < 3.7 * added support for res_axes
1 parent 0fe2463 commit b246d34

File tree

5 files changed

+452
-167
lines changed

5 files changed

+452
-167
lines changed

doc/source/changes/version_0_30.rst.inc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ Syntax changes
66

77
* :py:obj:`LArray.as_table()` is deprecated. Please use :py:obj:`LArray.dump()` instead.
88

9+
* :py:obj:`stack()` ``axis`` argument was renamed to ``axes`` to reflect the fact that the function can now stack
10+
along multiple axes at once (see below).
11+
912

1013
Backward incompatible changes
1114
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -196,6 +199,30 @@ Miscellaneous improvements
196199
A0 0 1
197200
A1 2 3
198201

202+
* :py:obj:`stack()` can now stack along several axes at once (closes :issue:`56`).
203+
204+
>>> country = Axis('country=BE,FR,DE')
205+
>>> gender = Axis('gender=M,F')
206+
>>> stack({('BE', 'M'): 0,
207+
... ('BE', 'F'): 1,
208+
... ('FR', 'M'): 2,
209+
... ('FR', 'F'): 3,
210+
... ('DE', 'M'): 4,
211+
... ('DE', 'F'): 5},
212+
... (country, gender))
213+
country\gender M F
214+
BE 0 1
215+
FR 2 3
216+
DE 4 5
217+
218+
* :py:obj:`stack()` using a dictionary as elements can now use a simple axis name instead of requiring a full axis
219+
object. This will print a warning on Python < 3.7 though because the ordering of labels is not guaranteed in
220+
that case. Closes :issue:`755` and :issue:`581`.
221+
222+
* :py:obj:`stack()` using keyword arguments can now use a simple axis name instead of requiring a full axis
223+
object, even on Python < 3.6. This will print a warning though because the ordering of labels is not guaranteed in
224+
that case.
225+
199226
* added option ``exact`` to ``join`` argument of :py:obj:`Axis.align()` and :py:obj:`LArray.align()` methods.
200227
Instead of aligning, passing ``join='exact'`` to the ``align`` method will raise an error when axes are not equal.
201228
Closes :issue:`338`.

larray/core/array.py

Lines changed: 185 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
# ? implement multi group in one axis getitem: lipro['P01,P02;P05'] <=> (lipro['P01,P02'], lipro['P05'])
99

1010
# * we need an API to get to the "next" label. Sometimes, we want to use label+1, but that is problematic when labels
11-
# are not numeric, or have not a step of 1. X.agegroup[X.agegroup.after(25):]
12-
# X.agegroup[X.agegroup[25].next():]
11+
# are not numeric, or do not have a step of 1.
12+
# X.agegroup[X.agegroup.after(25):]
13+
# X.agegroup[X.agegroup[25].next():]
1314

1415
# * implement keepaxes=True for _group_aggregate instead of/in addition to group tuples
1516

@@ -25,10 +26,7 @@
2526

2627
# * test structured arrays
2728

28-
# ? move "utils" to its own project (so that it is not duplicated between larray and liam2)
29-
# OR
30-
# include utils only in larray project and make larray a dependency of liam2
31-
# (and potentially rename it to reflect the broader scope)
29+
# * use larray "utils" in LIAM2 (to avoid duplicated code)
3230

3331
from collections import Iterable, Sequence, OrderedDict, abc
3432
from itertools import product, chain, groupby, islice
@@ -8852,8 +8850,69 @@ def eye(rows, columns=None, k=0, title=None, dtype=None, meta=None):
88528850
# ('FR', 'M'): 2, ('FR', 'F'): 3,
88538851
# ('DE', 'M'): 4, ('DE', 'F'): 5})
88548852

8855-
8856-
def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs):
8853+
# for 2D, I think the best compromise is the nested dict (especially for python 3.7+):
8854+
8855+
# stack({'BE': {'M': 0, 'F': 1},
8856+
# 'FR': {'M': 2, 'F': 3},
8857+
# 'DE': {'M': 4, 'F': 5}}, axes=('nationality', 'sex'))
8858+
8859+
# we could make this valid too (combine pos and labels) but I don't think it is worth it unless it comes
8860+
# naturally from the implementation:
8861+
8862+
# stack({'BE': {'M,F': [0, 1]},
8863+
# 'FR': {'M,F': [2, 3]},
8864+
# 'DE': {'M,F': [4, 5]}}, axes=('nationality', 'sex'))
8865+
8866+
# It looks especially nice if the labels have been extracted to variables:
8867+
8868+
# BE, FR, DE = nat['BE,FR,DE']
8869+
# M, F = sex['M,F']
8870+
8871+
# stack({BE: {M: 0, F: 1},
8872+
# FR: {M: 2, F: 3},
8873+
# DE: {M: 4, F: 5}})
8874+
8875+
# for 3D:
8876+
8877+
# stack({'a0': {'b0': {'c0': 0, 'c1': 1},
8878+
# 'b1': {'c0': 2, 'c1': 3},
8879+
# 'b2': {'c0': 4, 'c1': 5}},
8880+
# 'a1': {'b0': {'c0': 6, 'c1': 7},
8881+
# 'b1': {'c0': 8, 'c1': 9},
8882+
# 'b2': {'c0': 10, 'c1': 11}}},
8883+
# axes=('a', 'b', 'c'))
8884+
8885+
# a0, a1 = a['a0,a1']
8886+
# b0, b1, b2 = b['b0,b1,b2']
8887+
# c0, c1 = c['c0,c1']
8888+
8889+
# stack({a0: {b0: {c0: 0, c1: 1},
8890+
# b1: {c0: 2, c1: 3},
8891+
# b2: {c0: 4, c1: 5}},
8892+
# a1: {b0: {c0: 6, c1: 7},
8893+
# b1: {c0: 8, c1: 9},
8894+
# b2: {c0: 10, c1: 11}}},
8895+
# axes=(a, b, c))
8896+
8897+
# if we implement:
8898+
# arr[key] = {'a0': 0, 'a1': 1}
8899+
# where key must not be related to the "a" axis
8900+
# it would make it relatively easy to implement the nested dict syntax I think:
8901+
# first do a pass at the structure to get axes (if not provided) then:
8902+
# for k, v in d.items():
8903+
# arr[k] = v
8904+
# but that syntax could be annoying if we want to have an array of dicts
8905+
8906+
# alternatives:
8907+
8908+
# arr['a0'] = 0; arr['a1'] = 1 # <-- this already works
8909+
# arr['a0,a1'] = [0, 1] # <-- unsure if this works, but we should make it work (it is annoying if we
8910+
# # have an array of lists)
8911+
# arr[:] = {'a0': 0, 'a1': 1}
8912+
# arr[:] = stack({'a0': 0, 'a1': 1}) # <-- not equivalent if a has more labels
8913+
8914+
@deprecate_kwarg('axis', 'axes')
8915+
def stack(elements=None, axes=None, title=None, meta=None, dtype=None, res_axes=None, **kwargs):
88578916
r"""
88588917
Combines several arrays or sessions along one or more axes.
88598918
@@ -8866,15 +8925,17 @@ def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs)
88668925
88678926
Stacking sessions will return a new session containing the arrays of all sessions stacked together. An array
88688927
missing in a session will be replaced by NaN.
8869-
axis : str or Axis or Group, optional
8870-
Axis to create. If None, defaults to a range() axis.
8928+
axes : str, Axis, Group or sequence of Axis, optional
8929+
Axes to create. If None, defaults to a range() axis.
88718930
title : str, optional
88728931
Deprecated. See 'meta' below.
88738932
meta : list of pairs or dict or OrderedDict or Metadata, optional
88748933
Metadata (title, description, author, creation_date, ...) associated with the array.
88758934
Keys must be strings. Values must be of type string, int, float, date, time or datetime.
88768935
dtype : type, optional
88778936
Output dtype. Defaults to None (inspect all output values to infer it automatically).
8937+
res_axes : AxisCollection, optional
8938+
Axes of the output. Defaults to None (union of axes of all values and the stacking axes).
88788939
88798940
Returns
88808941
-------
@@ -8894,17 +8955,12 @@ def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs)
88948955
sex M F
88958956
0.0 0.0
88968957
8897-
In the case the axis to create has already been defined in a variable (Axis or Group)
8958+
In case the axis to create has already been defined in a variable (Axis or Group)
88988959
88998960
>>> stack({'BE': arr1, 'FO': arr2}, nat)
89008961
sex\nat BE FO
89018962
M 1.0 0.0
89028963
F 1.0 0.0
8903-
>>> all_nat = Axis('nat=BE,DE,FR,NL,UK')
8904-
>>> stack({'BE': arr1, 'DE': arr2}, all_nat[:'DE'])
8905-
sex\nat BE DE
8906-
M 1.0 0.0
8907-
F 1.0 0.0
89088964
89098965
Otherwise (when one wants to create an axis from scratch), any of these syntaxes works:
89108966
@@ -8938,7 +8994,7 @@ def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs)
89388994
When labels are "simple" strings (ie no integers, no string starting with integers, etc.), using keyword
89398995
arguments can be an attractive alternative.
89408996
8941-
>>> stack(FO=arr2, BE=arr1, axis=nat)
8997+
>>> stack(FO=arr2, BE=arr1, axes=nat)
89428998
sex\nat BE FO
89438999
M 1.0 0.0
89449000
F 1.0 0.0
@@ -8947,11 +9003,25 @@ def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs)
89479003
or later because keyword arguments are NOT ordered on earlier Python versions.
89489004
89499005
>>> # use this only on Python 3.6 and later
8950-
>>> stack(BE=arr1, FO=arr2, axis='nat') # doctest: +SKIP
9006+
>>> stack(BE=arr1, FO=arr2, axes='nat') # doctest: +SKIP
89519007
sex\nat BE FO
89529008
M 1.0 0.0
89539009
F 1.0 0.0
89549010
9011+
One can also stack along several axes
9012+
9013+
>>> test = Axis('test=T1,T2')
9014+
>>> stack({('BE', 'T1'): arr1,
9015+
... ('BE', 'T2'): arr2,
9016+
... ('FO', 'T1'): arr2,
9017+
... ('FO', 'T2'): arr1},
9018+
... (nat, test))
9019+
sex nat\test T1 T2
9020+
M BE 1.0 0.0
9021+
M FO 0.0 1.0
9022+
F BE 1.0 0.0
9023+
F FO 0.0 1.0
9024+
89559025
To stack sessions, let us first create two test sessions. For example suppose we have a session storing the results
89569026
of a baseline simulation:
89579027
@@ -8976,83 +9046,127 @@ def stack(elements=None, axis=None, title=None, meta=None, dtype=None, **kwargs)
89769046
M 0.0 0.5
89779047
F 0.0 0.5
89789048
"""
9049+
from larray import Session
9050+
89799051
meta = _handle_meta(meta, title)
89809052

8981-
from larray import Session
9053+
if elements is not None and kwargs:
9054+
raise TypeError("stack() accepts either keyword arguments OR a collection of elements, not both")
9055+
9056+
if isinstance(axes, basestring) and '=' in axes:
9057+
axes = Axis(axes)
9058+
elif isinstance(axes, Group):
9059+
axes = Axis(axes)
9060+
9061+
if axes is not None and not isinstance(axes, basestring):
9062+
axes = AxisCollection(axes)
89829063

8983-
if isinstance(axis, str) and '=' in axis:
8984-
axis = Axis(axis)
8985-
if isinstance(axis, Group):
8986-
axis = Axis(axis)
8987-
if elements is None:
8988-
if not isinstance(axis, Axis) and sys.version_info[:2] < (3, 6):
8989-
# XXX: this should probably be a warning, not an error
8990-
raise TypeError("axis argument should provide label order when using keyword arguments on Python < 3.6")
9064+
if kwargs:
9065+
if not isinstance(axes, AxisCollection) and sys.version_info[:2] < (3, 6):
9066+
warnings.warn("keyword arguments ordering is not guaranteed for Python < 3.6 so it is not "
9067+
"recommended to use them in stack() without providing labels order in the axes argument")
89919068
elements = kwargs.items()
8992-
elif kwargs:
8993-
raise TypeError("stack() accept either keyword arguments OR a collection of elements, not both")
89949069

8995-
if isinstance(axis, Axis) and all(isinstance(e, tuple) for e in elements):
8996-
assert all(len(e) == 2 for e in elements)
8997-
elements = {k: v for k, v in elements}
9070+
if isinstance(elements, dict):
9071+
if not isinstance(axes, AxisCollection) and sys.version_info[:2] < (3, 7):
9072+
# stacklevel=3 because of deprecate_kwarg
9073+
warnings.warn("dict ordering is not guaranteed for Python < 3.7 so it is not recommended to use "
9074+
"them in stack() without providing labels order in the axes argument", stacklevel=3)
9075+
9076+
elements = elements.items()
89989077

89999078
if isinstance(elements, LArray):
9000-
if axis is None:
9001-
axis = -1
9002-
axis = elements.axes[axis]
9003-
values = [elements[k] for k in axis]
9004-
elif isinstance(elements, dict):
9005-
# TODO: support having no Axis object for Python3.7 (without error or warning)
9006-
# XXX: we probably want to support this with a warning on Python < 3.7
9007-
assert isinstance(axis, Axis)
9008-
values = [elements[v] for v in axis.labels]
9079+
if axes is None:
9080+
axes = -1
9081+
axes = elements.axes[axes]
9082+
items = elements.items(axes)
90099083
elif isinstance(elements, Iterable):
90109084
if not isinstance(elements, Sequence):
90119085
elements = list(elements)
90129086

90139087
if all(isinstance(e, tuple) for e in elements):
90149088
assert all(len(e) == 2 for e in elements)
9015-
keys = [k for k, v in elements]
9016-
values = [v for k, v in elements]
9017-
assert all(np.isscalar(k) for k in keys)
9018-
# this case should already be handled
9019-
assert not isinstance(axis, Axis)
9020-
# axis should be None or str
9021-
axis = Axis(keys, axis)
9089+
if axes is None or isinstance(axes, basestring):
9090+
keys = [k for k, v in elements]
9091+
values = [v for k, v in elements]
9092+
# assert that all keys are indexers
9093+
assert all(np.isscalar(k) or isinstance(k, (Group, tuple)) for k in keys)
9094+
# TODO: add support for more than one axis here
9095+
axes = AxisCollection(Axis(keys, axes))
9096+
items = list(zip(axes[0], values))
9097+
else:
9098+
def translate_and_sort_key(key, axes):
9099+
dict_of_igroups = {k.axis: k for k in axes._key_to_igroups(key)}
9100+
return tuple(dict_of_igroups[axis] for axis in axes)
9101+
9102+
# passing only via _key_to_igroup should be enough if we allow for partial axes
9103+
dict_elements = {translate_and_sort_key(key, axes): value for key, value in elements}
9104+
items = [(k, dict_elements[k]) for k in axes.iter_labels()]
90229105
else:
9023-
values = elements
9024-
if axis is None or isinstance(axis, basestring):
9025-
axis = Axis(len(elements), axis)
9106+
if axes is None or isinstance(axes, basestring):
9107+
axes = AxisCollection(Axis(len(elements), axes))
90269108
else:
9027-
assert len(axis) == len(elements)
9109+
# TODO: add support for more than one axis here
9110+
assert axes.ndim == 1 and len(axes[0]) == len(elements)
9111+
items = list(zip(axes[0], elements))
90289112
else:
90299113
raise TypeError('unsupported type for arrays: %s' % type(elements).__name__)
90309114

9031-
if any(isinstance(v, Session) for v in values):
9032-
sessions = values
9033-
if not all(isinstance(s, Session) for s in sessions):
9115+
if any(isinstance(v, Session) for k, v in items):
9116+
if not all(isinstance(v, Session) for k, v in items):
90349117
raise TypeError("stack() only supports stacking Session with other Session objects")
90359118

9036-
all_keys = unique_multi(s.keys() for s in sessions)
9037-
res = []
9038-
for name in all_keys:
9119+
array_names = unique_multi(sess.keys() for sess_name, sess in items)
9120+
9121+
def stack_one(array_name):
90399122
try:
9040-
stacked = stack([s.get(name, nan) for s in sessions], axis=axis)
9123+
return stack([(sess_name, sess.get(array_name, nan))
9124+
for sess_name, sess in items], axes=axes)
90419125
# TypeError for str arrays, ValueError for incompatible axes, ...
90429126
except Exception:
9043-
stacked = nan
9044-
res.append((name, stacked))
9045-
return Session(res, meta=meta)
9127+
return nan
9128+
9129+
return Session([(array_name, stack_one(array_name)) for array_name in array_names], meta=meta)
90469130
else:
9047-
# XXX : use concat?
9048-
values = [aslarray(v) if not np.isscalar(v) else v
9049-
for v in values]
9050-
result_axes = AxisCollection.union(*[get_axes(v) for v in values])
9051-
result_axes.append(axis)
9052-
if dtype is None:
9053-
dtype = common_type(values)
9054-
result = empty(result_axes, dtype=dtype, meta=meta)
9055-
for k, v in zip(axis, values):
9131+
if res_axes is None or dtype is None:
9132+
values = [aslarray(v) if not np.isscalar(v) else v
9133+
for k, v in items]
9134+
9135+
if res_axes is None:
9136+
# we need a kludge to support stacking along an anonymous axis because AxisCollection.extend
9137+
# (and thus AxisCollection.union) support for anonymous axes is kinda messy.
9138+
if axes[0].name is None:
9139+
axes = axes.rename(0, '__anonymous__')
9140+
kludge = True
9141+
else:
9142+
kludge = False
9143+
9144+
# XXX: with the current semantics of stack, we need to compute the union of axes for values but axis
9145+
# needs to be added unconditionally. We *might* want to change the semantics to mean either stack
9146+
# or concat depending on whether or not the axis already exists.
9147+
# this would be more convenient for users I think, but would mean one class of error we cannot
9148+
# detect anymore: if a user unintentionally stacks an array with the axis already present.
9149+
# (this is very similar to the debate about combining LArray.append and LArray.extend)
9150+
all_axes = [get_axes(v) for v in values] + [axes]
9151+
res_axes = AxisCollection.union(*all_axes)
9152+
if kludge:
9153+
res_axes = res_axes.rename(axes[0], None)
9154+
elif not isinstance(res_axes, AxisCollection):
9155+
res_axes = AxisCollection(res_axes)
9156+
9157+
if dtype is None:
9158+
# dtype = common_type(values + [fill_value])
9159+
dtype = common_type(values)
9160+
9161+
# if needs_fill:
9162+
# result = full(res_axes, fill_value, dtype=dtype, meta=meta)
9163+
# else:
9164+
result = empty(res_axes, dtype=dtype, meta=meta)
9165+
9166+
# FIXME: this is *much* faster but it only works for scalars and not for stacking arrays
9167+
# keys = tuple(zip(*[k for k, v in items]))
9168+
# result.points[keys] = values
9169+
for k, v in items:
90569170
result[k] = v
90579171
return result
90589172

0 commit comments

Comments
 (0)