Skip to content

Commit c8ff6f6

Browse files
committed
Allow atvalue on more types, introduce TolValue and ExactValue
1 parent 83b4fde commit c8ff6f6

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/indexing.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@ const Idx = Union{Real,Colon,AbstractArray{Int}}
22

33
using Base: ViewIndex, @propagate_inbounds, tail
44

5-
struct Value{T}
5+
abstract type Value{T} end
6+
7+
struct TolValue{T} <: Value{T}
68
val::T
79
tol::T
810
end
9-
Value(x, tol=Base.rtoldefault(typeof(x))*abs(x)) = Value(promote(x,tol)...)
10-
atvalue(x; rtol=Base.rtoldefault(typeof(x)), atol=zero(x)) = Value(x, atol+rtol*abs(x))
11+
12+
TolValue(x, tol=Base.rtoldefault(typeof(x))*abs(x)) = TolValue(promote(x,tol)...)
13+
14+
struct ExactValue{T} <: Value{T}
15+
val::T
16+
end
17+
18+
atvalue(x::Number; rtol=Base.rtoldefault(typeof(x)), atol=zero(x)) = TolValue(x, atol+rtol*abs(x))
19+
atvalue(x) = ExactValue(x)
20+
1121
const Values = AbstractArray{<:Value}
1222

1323
# For throwing a BoundsError with a Value index, we need to define the following
@@ -17,8 +27,9 @@ Base.next(x::Value, state) = (x, true)
1727
Base.done(x::Value, state) = state
1828

1929
# How to show Value objects (e.g. in a BoundsError)
20-
Base.show(io::IO, v::Value) =
21-
print(io, string("Value(", v.val, ", tol=", v.tol, ")"))
30+
Base.show(io::IO, v::TolValue) =
31+
print(io, string("TolValue(", v.val, ", tol=", v.tol, ")"))
32+
Base.show(io::IO, v::ExactValue) = print(io, string("ExactValue(", v.val, ")"))
2233

2334
# Defer IndexStyle to the wrapped array
2435
Base.IndexStyle(::Type{AxisArray{T,N,D,Ax}}) where {T,N,D,Ax} = IndexStyle(D)
@@ -168,7 +179,7 @@ function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx)
168179
idxs[1]
169180
end
170181
# Dimensional axes may always be indexed by value if in a Value type wrapper.
171-
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::Value)
182+
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::TolValue)
172183
idxs = searchsorted(ax, ClosedInterval(idx.val,idx.val))
173184
length(idxs) > 1 && error("more than one datapoint lies on axis value $idx; use an interval to return all values")
174185
if length(idxs) == 1
@@ -179,6 +190,15 @@ function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::Value)
179190
throw(BoundsError(ax, idx))
180191
end
181192
end
193+
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::ExactValue)
194+
idxs = searchsorted(ax, ClosedInterval(idx.val,idx.val))
195+
length(idxs) > 1 && error("more than one datapoint lies on axis value $idx; use an interval to return all values")
196+
if length(idxs) == 1
197+
idxs[1]
198+
else # it's zero
199+
throw(BoundsError(ax, idx))
200+
end
201+
end
182202

183203
# Dimensional axes may be indexed by intervals to select a range
184204
axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::ClosedInterval) = searchsorted(ax, idx)

test/indexing.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,16 @@ A = AxisArray(OffsetArrays.OffsetArray([1 2; 3 4], 0:1, 1:2),
214214
@test_throws ArgumentError A[1.0f0]
215215
@test_throws ArgumentError A[:,6.1]
216216

217+
# Indexing with `atvalue` on Categorical axes
218+
A = AxisArray([1 2; 3 4], Axis{:x}([:a, :b]), Axis{:y}(["c", "d"]))
219+
@test @inferred(A[atvalue(:a)]) == @inferred(A[atvalue(:a), :]) == [1,2]
220+
@test @inferred(A[atvalue(:b)]) == @inferred(A[atvalue(:b), :]) == [3,4]
221+
@test_throws ArgumentError A[atvalue(:c)]
222+
@test @inferred(A[atvalue(:a), atvalue("c")]) == 1
223+
@test @inferred(A[:, atvalue("c")]) == [1,3]
224+
@test @inferred(A[Axis{:x}(atvalue(:b))]) == [3,4]
225+
@test @inferred(A[Axis{:y}(atvalue("d"))]) == [2,4]
226+
217227
# Test using dates
218228
using Base.Dates: Day, Month
219229
A = AxisArray(1:365, Date(2017,1,1):Date(2017,12,31))

0 commit comments

Comments
 (0)