Skip to content

Commit 5c62c4e

Browse files
authored
Merge pull request #48 from JuliaAI/indexing
Optimize broadcasting on the pdf function
2 parents c51956b + 1f35d81 commit 5c62c4e

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

src/arrays.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ end
151151

152152
# dummy function
153153
# returns `x[i]` for `Array` inputs `x`
154-
# For non-Array inputs returns `zero(dtype)`
154+
# For non-Array inputs returns the input
155155
#This avoids using an if statement
156-
_getindex(x::Array, i, dtype)=x[i]
157-
_getindex(::Nothing, i, dtype) = zero(dtype)
156+
_getindex(x::Array, i) = x[i]
157+
_getindex(x, i) = x
158158

159159
# pdf.(u, cv)
160160
function Base.Broadcast.broadcasted(
@@ -164,15 +164,16 @@ function Base.Broadcast.broadcasted(
164164

165165
# we assume that we compare categorical values by their unwrapped value
166166
# and pick the index of this value from classes(u)
167-
cv_loc = findfirst(==(cv), classes(u))
168-
cv_loc == 0 && throw(err_missing_class(cv))
167+
_classes = classes(u)
168+
cv_loc = get(CategoricalArrays.pool(_classes), cv, zero(R))
169+
isequal(cv_loc, 0) && throw(err_missing_class(cv))
169170

170171
f() = zeros(P, size(u)) #default caller function
171172

172173
return Base.Broadcast.Broadcasted(
173174
identity,
174175
(get(f, u.prob_given_ref, cv_loc),)
175-
)
176+
)
176177
end
177178

178179
Base.Broadcast.broadcasted(
@@ -191,15 +192,23 @@ function Base.Broadcast.broadcasted(
191192
length(u) == length(v) ||throw(DimensionMismatch(
192193
"Arrays could not be broadcast to a common size; "*
193194
"got a dimension with lengths $(length(u)) and $(length(v))"))
195+
196+
_classes = classes(u)
197+
_classes_pool = CategoricalArrays.pool(_classes)
198+
T = eltype(v) >: Missing ? Missing : Union{}
199+
v_loc_flat = Vector{Tuple{Union{R, T}, Int}}(undef, length(v))
200+
201+
202+
for (i, x) in enumerate(v)
203+
cv_ref = ismissing(x) ? missing : get(_classes_pool, x, zero(R))
204+
isequal(cv_ref, 0) && throw(err_missing_class(x))
205+
v_loc_flat[i] = (cv_ref, i)
206+
end
194207

195-
v_loc_flat = [(ismissing(x) ? missing : findfirst(==(x), classes(u)), i)
196-
for (i, x) in enumerate(v)]
197-
any(isequal(0), v_loc_flat) && throw(err_missing_class(cv))
198-
199-
getter((cv_loc, i), dtype) =
200-
_getindex(get(u.prob_given_ref, cv_loc, nothing), i, dtype)
201-
getter(::Tuple{Missing,Any}, dtype) = missing
202-
ret_flat = getter.(v_loc_flat, P)
208+
getter((cv_ref, i)) =
209+
_getindex(get(u.prob_given_ref, cv_ref, zero(P)), i)
210+
getter(::Tuple{Missing,Any}) = missing
211+
ret_flat = getter.(v_loc_flat)
203212
return reshape(ret_flat, size(u))
204213
end
205214

test/arrays.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ end
149149
u2 = UnivariateFinite(v[1:2], probs, augment=true)
150150
@test pdf.(u2, v[3]) == zeros(3)
151151
@test isequal(logpdf.(u2, v[3]), log.(zeros(3)))
152+
153+
## Check that the appropriate errors are thrown
154+
@test_throws DomainError pdf.(u,"strange_level")
152155
end
153156

154157
_skip(v) = collect(skipmissing(v))
@@ -166,18 +169,27 @@ _skip(v) = collect(skipmissing(v))
166169
@test _skip(broadcast(logpdf, u, unwrap.(v))) ==
167170
_skip([logpdf(u[i], v[i]) for i in 1:length(u)])
168171
end
172+
173+
## Check that the appropriate errors are thrown
174+
v1 = categorical([v0[1:end-1]...;"strange_level"])
175+
v2 = [v0...;rand(rng, v0)] #length(u) !== length(v2)
176+
v3 = categorical([vm[end:-1:begin+1]...;"strange_level"])
177+
@test_throws DimensionMismatch broadcast(pdf, u, v2)
178+
@test_throws DomainError broadcast(pdf, u, v1)
179+
@test_throws DomainError broadcast(pdf, u, v3)
180+
169181
end
170182

171-
@testset "broadcasting: check indexing in `getter((cv, i), dtype)` see PR#375" begin
183+
@testset "broadcasting: check indexing in `getter((cv_ref, i))` see PR#375 from MLJBase" begin
172184
c = categorical([0,1,1])
173185
d = UnivariateFinite(c[1:1], [1 1 1]')
174186
v = categorical([0,1,1,1])
175187
@test broadcast(pdf, d, v[2:end]) == [0,0,0]
176188
end
177189

178190
@testset "_getindex" begin
179-
@test CategoricalDistributions._getindex(collect(1:4), 2, Int64) == 2
180-
@test CategoricalDistributions._getindex(nothing, 2, Int64) == zero(Int64)
191+
@test CategoricalDistributions._getindex(collect(1:4), 2) == 2
192+
@test CategoricalDistributions._getindex(0, 2) === 0
181193
end
182194

183195
@testset "broadcasting mode" begin

0 commit comments

Comments
 (0)