Skip to content

Commit 6fdaefb

Browse files
authored
Merge pull request #111 from invenia/logical-indexing
Support logical indexing
2 parents 6443de7 + 14efa74 commit 6fdaefb

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/indexing.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ This internal function determines the new set of axes that are constructed upon
4848
indexing with I.
4949
"""
5050
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
51+
function reaxis(A::AxisArray, I::AbstractArray{Bool})
52+
vecI = vec(I)
53+
_reaxis(make_axes_match(axes(A), (vecI,)), (vecI,))
54+
end
5155
# Ensure the number of axes matches the number of indexing dimensions
5256
@inline make_axes_match(axs, idxs) = _make_axes_match((), axs, Base.index_ndims(idxs...))
5357
# Move the axes into newaxes, until we run out of both simultaneously
@@ -106,6 +110,12 @@ using Base.AbstractCartesianIndex
106110
# Setindex is so much simpler. Just assign it to the data:
107111
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
108112

113+
# Logical indexing
114+
@propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool})
115+
AxisArray(A.data[idx], reaxis(A, idx))
116+
end
117+
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)
118+
109119
### Fancier indexing capabilities provided only by AxisArrays ###
110120
@propagate_inbounds Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
111121
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)

test/indexing.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ B = B2[1:2,:]
4848
@test B.axes[1].val == A.axes[1].val[1:2]
4949
@test B.axes[2].val == 1:Base.trailingsize(A,2)
5050

51+
# Logical indexing
52+
all_inds = collect(1:length(A))
53+
odd_inds = collect(1:2:length(A))
54+
@test @inferred(A[trues(A)]) == A[:] == A[all_inds]
55+
@test axes(A[trues(A)]) == axes(A[all_inds])
56+
@test @inferred(A[isodd.(A)]) == A[1:2:length(A)] == A[odd_inds]
57+
@test axes(A[isodd.(A)]) == axes(A[odd_inds])
58+
@test @inferred(A[vec(trues(A))]) == A[:] == A[all_inds]
59+
@test axes(A[vec(trues(A))]) == axes(A[all_inds])
60+
@test @inferred(A[vec(isodd.(A))]) == A[1:2:length(A)] == A[odd_inds]
61+
@test axes(A[vec(isodd.(A))]) == axes(A[odd_inds])
62+
63+
5164
B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
5265

5366
# Test indexing by Intervals

0 commit comments

Comments
 (0)