Skip to content

Commit 4d8548e

Browse files
committed
fix indexing bug #43
1 parent d8a0b4b commit 4d8548e

File tree

2 files changed

+60
-14
lines changed

2 files changed

+60
-14
lines changed

src/arrays.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,49 @@ const UniFinArr = UnivariateFiniteArray
44

55
Base.size(u::UniFinArr, args...) =
66
size(first(values(u.prob_given_ref)), args...)
7-
8-
function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
9-
i::Integer...) where {R,P,N}
10-
prob_given_ref = LittleDict{R,P}()
11-
for ref in keys(u.prob_given_ref)
12-
prob_given_ref[ref] = getindex(u.prob_given_ref[ref], i...)
7+
8+
function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
9+
# It's faster to generate `Array`s of `refs` and indexed `ref_probs`
10+
# and pass them to the `LittleDict` constructor.
11+
# The first element of `u.prob_given_ref` is used to get the dimensions
12+
# for allocating these arrays.
13+
u_dict = u.prob_given_ref
14+
a, rest = Iterators.peel(u_dict)
15+
# `a` is of the form `key => value`.
16+
a_ref, a_prob = first(a), getindex(last(a), i...)
17+
18+
# Preallocate Arrays using the key and value of the first
19+
# element (i.e `a`) of `u_dict`.
20+
n_refs = length(u_dict)
21+
refs = Vector{R}(undef, n_refs)
22+
if a_prob isa AbstractArray
23+
ref_probs = Vector{Array{P, ndims(a_prob)}}(undef, n_refs)
24+
unf_constructor = UniFinArr
25+
else
26+
ref_probs = Vector{P}(undef, n_refs)
27+
unf_constructor = UnivariateFinite
1328
end
14-
return UnivariateFinite(u.scitype, u.decoder, prob_given_ref)
15-
end
1629

17-
function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
18-
I...) where {R,P,N}
19-
prob_given_ref = LittleDict{R,Array{P,N}}()
20-
for ref in keys(u.prob_given_ref)
21-
prob_given_ref[ref] = getindex(u.prob_given_ref[ref], I...)
30+
# Fill in the first elements
31+
# Both `refs` and `ref_probs` are both of type `Vector` and hence support
32+
# linear indexing with index starting at `1`
33+
refs[1] = a_ref
34+
ref_probs[1] = a_prob
35+
36+
# Fill in the rest
37+
iter = 2
38+
for (ref, ref_prob) in rest
39+
refs[iter] = ref
40+
ref_probs[iter] = getindex(ref_prob, i...)
41+
iter += 1
2242
end
23-
return UniFinArr(u.scitype, u.decoder, prob_given_ref)
43+
44+
# `keytype(prob_given_ref)` is always same as `keytype(u_dict)`.
45+
# But `ndims(valtype(prob_given_ref))` might not be the same
46+
# as `ndims(valtype(u_dict))`.
47+
prob_given_ref = LittleDict{R, eltype(ref_probs)}(refs, ref_probs)
48+
49+
return unf_constructor(u.scitype, u.decoder, prob_given_ref)
2450
end
2551

2652
function Base.setindex!(u::UniFinArr{S,V,R,P,N},

test/arrays.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,26 @@ end
288288
classes(u2[1:1]))
289289
end
290290

291+
function (x::T, y::T) where {T<:UnivariateFinite}
292+
return x.decoder == y.decoder &&
293+
x.prob_given_ref == y.prob_given_ref &&
294+
x.scitype == y.scitype
295+
end
296+
297+
function (x::AbstractArray, y::AbstractArray)
298+
return all(().(x, y))
299+
end
300+
301+
@testset "indexing of UnivariateFininiteArray (see issue #43)" begin
302+
u = UnivariateFinite(['x', 'z'], rand(2, 3, 2), pool=missing, ordered=true)
303+
v = u[1:2]
304+
@test v isa UnivariateFiniteArray
305+
@test v u[1:2, 1] u[[1,2], 1]
306+
w = u[2]
307+
@test w isa UnivariateFinite
308+
@test w v[2]
309+
end
310+
291311
end
292312

293313
true

0 commit comments

Comments
 (0)