Skip to content

Commit 14efa74

Browse files
committed
Support logical indexing
1 parent 83b4fde commit 14efa74

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/indexing.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ This internal function determines the new set of axes that are constructed upon
3737
indexing with I.
3838
"""
3939
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
40+
function reaxis(A::AxisArray, I::AbstractArray{Bool})
41+
vecI = vec(I)
42+
_reaxis(make_axes_match(axes(A), (vecI,)), (vecI,))
43+
end
4044
# Ensure the number of axes matches the number of indexing dimensions
4145
@inline make_axes_match(axs, idxs) = _make_axes_match((), axs, Base.index_ndims(idxs...))
4246
# Move the axes into newaxes, until we run out of both simultaneously
@@ -95,6 +99,12 @@ using Base.AbstractCartesianIndex
9599
# Setindex is so much simpler. Just assign it to the data:
96100
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
97101

102+
# Logical indexing
103+
@propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool})
104+
AxisArray(A.data[idx], reaxis(A, idx))
105+
end
106+
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)
107+
98108
### Fancier indexing capabilities provided only by AxisArrays ###
99109
@propagate_inbounds Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
100110
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)

test/indexing.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ B = B2[1:2,:]
4848
@test B.axes[1].val == A.axes[1].val[1:2]
4949
@test B.axes[2].val == 1:Base.trailingsize(A,2)
5050

51+
# Logical indexing
52+
all_inds = collect(1:length(A))
53+
odd_inds = collect(1:2:length(A))
54+
@test @inferred(A[trues(A)]) == A[:] == A[all_inds]
55+
@test axes(A[trues(A)]) == axes(A[all_inds])
56+
@test @inferred(A[isodd.(A)]) == A[1:2:length(A)] == A[odd_inds]
57+
@test axes(A[isodd.(A)]) == axes(A[odd_inds])
58+
@test @inferred(A[vec(trues(A))]) == A[:] == A[all_inds]
59+
@test axes(A[vec(trues(A))]) == axes(A[all_inds])
60+
@test @inferred(A[vec(isodd.(A))]) == A[1:2:length(A)] == A[odd_inds]
61+
@test axes(A[vec(isodd.(A))]) == axes(A[odd_inds])
62+
63+
5164
B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
5265

5366
# Test indexing by Intervals

0 commit comments

Comments
 (0)