Skip to content

Commit 69993e3

Browse files
committed
fixed arr[bool_axis[bool_key]] (closes #735)
1 parent 925e786 commit 69993e3

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

doc/source/changes/version_0_31.rst.inc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ Miscellaneous improvements
3333
Fixes
3434
^^^^^
3535

36-
* fixed something (closes :issue:`1`).
36+
* fixed taking a subset of an array with boolean labels for an axis if the user explicitly specifies the axis
37+
(closes :issue:`735`). When the user does not specify the axis, it currently fails but it is unclear what to do in
38+
that case (see :issue:`794`).

larray/core/axis.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -887,9 +887,6 @@ def index(self, key):
887887
# stop is inclusive in the input key and exclusive in the output !
888888
stop = mapping[key.stop] + 1 if key.stop is not None else None
889889
return slice(start, stop, key.step)
890-
# XXX: bool LArray do not pass through???
891-
elif isinstance(key, np.ndarray) and key.dtype.kind is 'b':
892-
return key
893890
elif isinstance(key, (tuple, list, OrderedSet)):
894891
# TODO: the result should be cached
895892
# Note that this is faster than array_lookup(np.array(key), mapping)
@@ -2730,7 +2727,7 @@ def _key_to_igroups(self, key):
27302727
if isinstance(axis_key, LArray) and np.issubdtype(axis_key.dtype, np.bool_):
27312728
extra_key_axes = axis_key.axes - self
27322729
if extra_key_axes:
2733-
raise ValueError("subset key contains more axes ({}) than array ({})"
2730+
raise ValueError("boolean subset key contains more axes ({}) than array ({})"
27342731
.format(axis_key.axes, self))
27352732
# nonzero (currently) returns a tuple of IGroups containing 1D LArrays (one IGroup per axis)
27362733
nonboolkey.extend(axis_key.nonzero())

larray/tests/test_array.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import, division, print_function
33

44
import os
5+
import re
56
import sys
67

78
import pytest
@@ -30,6 +31,7 @@
3031
# Test Value Strings #
3132
# ================== #
3233

34+
3335
def test_value_string_split():
3436
assert_array_equal(_to_ticks('M,F'), np.asarray(['M', 'F']))
3537
assert_array_equal(_to_ticks('M, F'), np.asarray(['M', 'F']))
@@ -721,7 +723,7 @@ def test_getitem_abstract_positional(array):
721723
array[X.age.i[2, 3], X.age.i[1, 5]]
722724

723725

724-
def test_getitem_bool_larray_key():
726+
def test_getitem_bool_larray_key_arr_whout_bool_axis():
725727
arr = ndtest((3, 2, 4))
726728
raw = arr.data
727729

@@ -732,8 +734,8 @@ def test_getitem_bool_larray_key():
732734
assert_array_equal(res, raw[raw < 5])
733735

734736
# missing dimension
735-
filt = arr['b1'] % 5 == 0
736-
res = arr[filt]
737+
filter_ = arr['b1'] % 5 == 0
738+
res = arr[filter_]
737739
assert isinstance(res, LArray)
738740
assert res.ndim == 2
739741
assert res.shape == (3, 2)
@@ -752,6 +754,26 @@ def test_getitem_bool_larray_key():
752754
assert_array_equal(res, raw[:, :2])
753755

754756

757+
def test_getitem_bool_larray_key_arr_wh_bool_axis():
758+
gender = Axis([False, True], 'gender')
759+
arr = LArray([0.1, 0.2], gender)
760+
id_axis = Axis('id=0..3')
761+
key = LArray([True, False, True, True], id_axis)
762+
expected = LArray([0.2, 0.1, 0.2, 0.2], id_axis)
763+
764+
# LGroup using the real axis
765+
assert_larray_equal(arr[gender[key]], expected)
766+
767+
# LGroup using an AxisReference
768+
assert_larray_equal(arr[X.gender[key]], expected)
769+
770+
# this test checks that the current behavior does not change unintentionally...
771+
# ... but I am unsure the current behavior is what we actually want
772+
msg = re.escape("boolean subset key contains more axes ({id}) than array ({gender})")
773+
with pytest.raises(ValueError, match=msg):
774+
arr[key]
775+
776+
755777
def test_getitem_bool_larray_and_group_key():
756778
arr = ndtest((3, 6, 4)).set_labels('b', '0..5')
757779

@@ -769,14 +791,34 @@ def test_getitem_bool_larray_and_group_key():
769791
assert_array_equal(res, expected)
770792

771793

772-
def test_getitem_bool_ndarray_key(array):
794+
def test_getitem_bool_ndarray_key_arr_whout_bool_axis(array):
773795
raw = array.data
774796
res = array[raw < 5]
775797
assert isinstance(res, LArray)
776798
assert res.ndim == 1
777799
assert_array_equal(res, raw[raw < 5])
778800

779801

802+
def test_getitem_bool_ndarray_key_arr_wh_bool_axis():
803+
gender = Axis([False, True], 'gender')
804+
arr = LArray([0.1, 0.2], gender)
805+
key = np.array([True, False, True, True])
806+
expected = arr.i[[1, 0, 1, 1]]
807+
808+
# LGroup using the real axis
809+
assert_larray_equal(arr[gender[key]], expected)
810+
811+
# LGroup using an AxisReference
812+
assert_larray_equal(arr[X.gender[key]], expected)
813+
814+
# raw key (no axis specified) => currently raises a ValueError
815+
# this test checks that the current behavior does not change unintentionally...
816+
# ... but I am unsure the current behavior is what we actually want
817+
msg = re.escape("boolean key with a different shape ((4,)) than array ((2,))")
818+
with pytest.raises(ValueError, match=msg):
819+
arr[key]
820+
821+
780822
def test_getitem_bool_anonymous_axes():
781823
a = ndtest([Axis(2), Axis(3), Axis(4), Axis(5)])
782824
mask = ones(a.axes[1, 3], dtype=bool)

0 commit comments

Comments
 (0)