Skip to content

Commit bf270d4

Browse files
committed
implemented zip_array_values and zip_array_items
1 parent b246d34 commit bf270d4

File tree

4 files changed

+182
-29
lines changed

4 files changed

+182
-29
lines changed

doc/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,8 @@ Miscellaneous
695695
eye
696696
ipfp
697697
wrap_elementwise_array_func
698+
zip_array_values
699+
zip_array_items
698700

699701
.. _api-session:
700702

doc/source/changes/version_0_30.rst.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ New features
147147

148148
* implemented :py:obj:`AxisCollection.rename()` to rename axes of an AxisCollection, independently of any array.
149149

150+
* implemented :py:obj:`zip_array_values()` and :py:obj:`zip_array_items()` to loop respectively on several arrays values
151+
or (key, value) pairs.
152+
150153
* implemented :py:obj:`AxisCollection.iter_labels()` to iterate over all (possible combinations of) labels of the axes
151154
of the collection.
152155

larray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from larray.core.array import (LArray, zeros, zeros_like, ones, ones_like, empty, empty_like, full,
99
full_like, sequence, labels_array, ndtest, aslarray, identity, diag,
1010
eye, all, any, sum, prod, cumsum, cumprod, min, max, mean, ptp, var,
11-
std, median, percentile, stack)
11+
std, median, percentile, stack, zip_array_values, zip_array_items)
1212
from larray.core.session import Session, local_arrays, global_arrays, arrays
1313
from larray.core.constants import nan, inf, pi, e, euler_gamma
1414
from larray.core.metadata import Metadata

larray/core/array.py

Lines changed: 176 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# * use larray "utils" in LIAM2 (to avoid duplicated code)
3030

3131
from collections import Iterable, Sequence, OrderedDict, abc
32-
from itertools import product, chain, groupby, islice
32+
from itertools import product, chain, groupby, islice, repeat
3333
import os
3434
import sys
3535
import functools
@@ -3300,7 +3300,7 @@ def keys(self, axes=None, ascending=True):
33003300
return self.axes.iter_labels(axes, ascending=ascending)
33013301

33023302
# TODO: implement values_by
3303-
def values(self, axes=None, ascending=True, expand=False):
3303+
def values(self, axes=None, ascending=True):
33043304
r"""Returns a view on the values of the array along axes.
33053305
33063306
Parameters
@@ -3310,9 +3310,6 @@ def values(self, axes=None, ascending=True, expand=False):
33103310
in the array).
33113311
ascending : bool, optional
33123312
Whether or not to iterate the axes in ascending order (from start to end). Defaults to True.
3313-
expand : bool, optional
3314-
Whether or not to expand array using axes. This allows one to iterate on axes which do not exist in
3315-
the array, which is useful when iterating on several arrays with different axes. Defaults to False.
33163313
33173314
Returns
33183315
-------
@@ -3367,15 +3364,6 @@ def values(self, axes=None, ascending=True, expand=False):
33673364
1 3
33683365
a a0 a1
33693366
0 2
3370-
>>> # iterate on the "c" axis, which does not exist in arr, that is return arr for each label along the "c" axis
3371-
... for value in arr.values('c=c0,c1', expand=True):
3372-
... print(value)
3373-
a\b b0 b1
3374-
a0 0 1
3375-
a1 2 3
3376-
a\b b0 b1
3377-
a0 0 1
3378-
a1 2 3
33793367
33803368
One can also access elements of the value sequence directly, instead of iterating over it. Say we want to
33813369
retrieve the first and last values of our array, we could write:
@@ -3391,30 +3379,19 @@ def values(self, axes=None, ascending=True, expand=False):
33913379
# combined[::-1] *is* indexable
33923380
return combined if ascending else combined[::-1]
33933381

3394-
if not isinstance(axes, (tuple, AxisCollection)):
3382+
if not isinstance(axes, (tuple, list, AxisCollection)):
33953383
axes = (axes,)
33963384

3397-
def get_axis(a):
3398-
if isinstance(a, basestring):
3399-
return Axis(a) if '=' in a else self.axes[a]
3400-
elif isinstance(a, int):
3401-
return self.axes[a]
3402-
else:
3403-
assert isinstance(a, Axis)
3404-
return a
3405-
axes = [get_axis(a) for a in axes]
3406-
array = self.expand(axes, readonly=True) if expand else self
3407-
axes = array.axes[axes]
3385+
axes = self.axes[axes]
34083386
# move axes in front
3409-
transposed = array.transpose(axes)
3387+
transposed = self.transpose(axes)
34103388
# combine axes if necessary
34113389
combined = transposed.combine_axes(axes, wildcard=True) if len(axes) > 1 else transposed
34123390
# trailing .i is to support the case where axis < self.axes (ie the elements of the result are arrays)
34133391
return combined.i if ascending else combined.i[::-1].i
34143392

34153393
# TODO: we currently return a tuple of groups even for 1D arrays, which can be both a bad or a good thing.
34163394
# if we returned an NDGroup in all cases, it would solve the problem
3417-
# TODO: implement expand=True
34183395
def items(self, axes=None, ascending=True):
34193396
r"""Returns a (label, value) view of the array along axes.
34203397
@@ -9249,6 +9226,177 @@ def make_args_broadcastable(args, kwargs=None, min_axes=None):
92499226
return raw_bcast_args, raw_bcast_kwargs, res_axes
92509227

92519228

9229+
def zip_array_values(values, axes=None, ascending=True):
9230+
r"""Returns a sequence as if simultaneously iterating on several arrays.
9231+
9232+
Parameters
9233+
----------
9234+
axes : int, str or Axis or tuple of them, optional
9235+
Axis or axes along which to iterate and in which order. Defaults to None (union of all axes present in
9236+
all arrays, in the order they are found).
9237+
ascending : bool, optional
9238+
Whether or not to iterate the axes in ascending order (from start to end). Defaults to True.
9239+
9240+
Returns
9241+
-------
9242+
Sequence
9243+
9244+
Examples
9245+
--------
9246+
>>> arr1 = ndtest('a=a0,a1;b=b1,b2')
9247+
>>> arr2 = ndtest('a=a0,a1;c=c1,c2')
9248+
>>> arr1
9249+
a\b b1 b2
9250+
a0 0 1
9251+
a1 2 3
9252+
>>> arr2
9253+
a\c c1 c2
9254+
a0 0 1
9255+
a1 2 3
9256+
>>> for a1, a2 in zip_array_values((arr1, arr2), 'a'):
9257+
... print("==")
9258+
... print(a1)
9259+
... print(a2)
9260+
==
9261+
b b1 b2
9262+
0 1
9263+
c c1 c2
9264+
0 1
9265+
==
9266+
b b1 b2
9267+
2 3
9268+
c c1 c2
9269+
2 3
9270+
>>> for a1, a2 in zip_array_values((arr1, arr2), arr2.c):
9271+
... print("==")
9272+
... print(a1)
9273+
... print(a2)
9274+
==
9275+
a\b b1 b2
9276+
a0 0 1
9277+
a1 2 3
9278+
a a0 a1
9279+
0 2
9280+
==
9281+
a\b b1 b2
9282+
a0 0 1
9283+
a1 2 3
9284+
a a0 a1
9285+
1 3
9286+
>>> for a1, a2 in zip_array_values((arr1, arr2)):
9287+
... print("arr1: {}, arr2: {}".format(a1, a2))
9288+
arr1: 0, arr2: 0
9289+
arr1: 0, arr2: 1
9290+
arr1: 1, arr2: 0
9291+
arr1: 1, arr2: 1
9292+
arr1: 2, arr2: 2
9293+
arr1: 2, arr2: 3
9294+
arr1: 3, arr2: 2
9295+
arr1: 3, arr2: 3
9296+
"""
9297+
def values_with_expand(value, axes, readonly=True, ascending=True):
9298+
if isinstance(value, LArray):
9299+
# an Axis axis is not necessarily in array.axes
9300+
expanded = value.expand(axes, readonly=readonly)
9301+
return expanded.values(axes, ascending=ascending)
9302+
else:
9303+
size = axes.size if axes.ndim else 0
9304+
return Repeater(value, size)
9305+
9306+
all_axes = AxisCollection.union(*[get_axes(v) for v in values])
9307+
if axes is None:
9308+
axes = all_axes
9309+
else:
9310+
if not isinstance(axes, (tuple, list, AxisCollection)):
9311+
axes = (axes,)
9312+
# transform string axes definitions to objects
9313+
axes = [Axis(axis) if isinstance(axis, basestring) and '=' in axis else axis
9314+
for axis in axes]
9315+
# transform string axes references to objects
9316+
axes = AxisCollection([axis if isinstance(axis, Axis) else all_axes[axis]
9317+
for axis in axes])
9318+
9319+
# sequence of tuples (of scalar or arrays)
9320+
return SequenceZip([values_with_expand(v, axes, ascending=ascending) for v in values])
9321+
9322+
9323+
def zip_array_items(values, axes=None, ascending=True):
9324+
r"""Returns a sequence as if simultaneously iterating on several arrays as well as the current iteration "key".
9325+
9326+
Broadcasts all values against each other. Scalars are simply repeated.
9327+
9328+
Parameters
9329+
----------
9330+
values : Iterable
9331+
arrays to iterate on.
9332+
axes : int, str or Axis or tuple of them, optional
9333+
Axis or axes along which to iterate and in which order. Defaults to None (union of all axes present in
9334+
all arrays, in the order they are found).
9335+
ascending : bool, optional
9336+
Whether or not to iterate the axes in ascending order (from start to end). Defaults to True.
9337+
9338+
Returns
9339+
-------
9340+
Sequence
9341+
9342+
Examples
9343+
--------
9344+
>>> arr1 = ndtest('a=a0,a1;b=b0,b1')
9345+
>>> arr2 = ndtest('a=a0,a1;c=c0,c1')
9346+
>>> arr1
9347+
a\b b0 b1
9348+
a0 0 1
9349+
a1 2 3
9350+
>>> arr2
9351+
a\c c0 c1
9352+
a0 0 1
9353+
a1 2 3
9354+
>>> for k, (a1, a2) in zip_array_items((arr1, arr2), 'a'):
9355+
... print("==", k[0], "==")
9356+
... print(a1)
9357+
... print(a2)
9358+
== a0 ==
9359+
b b0 b1
9360+
0 1
9361+
c c0 c1
9362+
0 1
9363+
== a1 ==
9364+
b b0 b1
9365+
2 3
9366+
c c0 c1
9367+
2 3
9368+
>>> for k, (a1, a2) in zip_array_items((arr1, arr2), arr2.c):
9369+
... print("==", k[0], "==")
9370+
... print(a1)
9371+
... print(a2)
9372+
== c0 ==
9373+
a\b b0 b1
9374+
a0 0 1
9375+
a1 2 3
9376+
a a0 a1
9377+
0 2
9378+
== c1 ==
9379+
a\b b0 b1
9380+
a0 0 1
9381+
a1 2 3
9382+
a a0 a1
9383+
1 3
9384+
>>> for k, (a1, a2) in zip_array_items((arr1, arr2)):
9385+
... print(k, "arr1: {}, arr2: {}".format(a1, a2))
9386+
(a.i[0], b.i[0], c.i[0]) arr1: 0, arr2: 0
9387+
(a.i[0], b.i[0], c.i[1]) arr1: 0, arr2: 1
9388+
(a.i[0], b.i[1], c.i[0]) arr1: 1, arr2: 0
9389+
(a.i[0], b.i[1], c.i[1]) arr1: 1, arr2: 1
9390+
(a.i[1], b.i[0], c.i[0]) arr1: 2, arr2: 2
9391+
(a.i[1], b.i[0], c.i[1]) arr1: 2, arr2: 3
9392+
(a.i[1], b.i[1], c.i[0]) arr1: 3, arr2: 2
9393+
(a.i[1], b.i[1], c.i[1]) arr1: 3, arr2: 3
9394+
"""
9395+
res_axes = AxisCollection.union(*[get_axes(v) for v in values])
9396+
return SequenceZip((res_axes.iter_labels(axes, ascending=ascending),
9397+
zip_array_values(values, axes=axes, ascending=ascending)))
9398+
9399+
92529400
_default_float_error_handler = float_error_handler_factory(3)
92539401

92549402

0 commit comments

Comments
 (0)