Skip to content

Commit f576e14

Browse files
committed
optimize pdf broadcasting
1 parent 793f90c commit f576e14

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

src/arrays.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ end
151151

152152
# dummy function
153153
# returns `x[i]` for `Array` inputs `x`
154-
# For non-Array inputs returns `zero(dtype)`
154+
# For non-Array inputs returns the input
155155
#This avoids using an if statement
156-
_getindex(x::Array, i, dtype)=x[i]
157-
_getindex(::Nothing, i, dtype) = zero(dtype)
156+
_getindex(x::Array, i) = x[i]
157+
_getindex(x, i) = x
158158

159159
# pdf.(u, cv)
160160
function Base.Broadcast.broadcasted(
@@ -164,15 +164,16 @@ function Base.Broadcast.broadcasted(
164164

165165
# we assume that we compare categorical values by their unwrapped value
166166
# and pick the index of this value from classes(u)
167-
cv_loc = findfirst(==(cv), classes(u))
168-
cv_loc == 0 && throw(err_missing_class(cv))
167+
_classes = classes(u)
168+
cv_loc = get(CategoricalArrays.pool(_classes), cv, zero(R))
169+
isequal(cv_loc, 0) && throw(err_missing_class(cv))
169170

170171
f() = zeros(P, size(u)) #default caller function
171172

172173
return Base.Broadcast.Broadcasted(
173174
identity,
174175
(get(f, u.prob_given_ref, cv_loc),)
175-
)
176+
)
176177
end
177178

178179
Base.Broadcast.broadcasted(
@@ -191,15 +192,23 @@ function Base.Broadcast.broadcasted(
191192
length(u) == length(v) ||throw(DimensionMismatch(
192193
"Arrays could not be broadcast to a common size; "*
193194
"got a dimension with lengths $(length(u)) and $(length(v))"))
195+
196+
_classes = classes(u)
197+
_classes_pool = CategoricalArrays.pool(_classes)
198+
T = eltype(v) >: Missing ? Missing : Union{}
199+
v_loc_flat = Vector{Union{Int, T}}(undef, length(v))
200+
201+
202+
for (i, x) in enumerate(v)
203+
v_loc_flat_i = ismissing(x) ? missing : get(_classes_pool, x, zero(R))
204+
isequal(v_loc_flat_i, 0) && throw(err_missing_class(x))
205+
v_loc_flat[i] = v_loc_flat_i
206+
end
194207

195-
v_loc_flat = [(ismissing(x) ? missing : findfirst(==(x), classes(u)), i)
196-
for (i, x) in enumerate(v)]
197-
any(isequal(0), v_loc_flat) && throw(err_missing_class(cv))
198-
199-
getter((cv_loc, i), dtype) =
200-
_getindex(get(u.prob_given_ref, cv_loc, nothing), i, dtype)
201-
getter(::Tuple{Missing,Any}, dtype) = missing
202-
ret_flat = getter.(v_loc_flat, P)
208+
getter(cv_loc) =
209+
_getindex(get(u.prob_given_ref, _classes[cv_loc], zero(P)), cv_loc)
210+
getter(::Missing) = missing
211+
ret_flat = getter.(v_loc_flat)
203212
return reshape(ret_flat, size(u))
204213
end
205214

test/arrays.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ _skip(v) = collect(skipmissing(v))
166166
@test _skip(broadcast(logpdf, u, unwrap.(v))) ==
167167
_skip([logpdf(u[i], v[i]) for i in 1:length(u)])
168168
end
169+
170+
## Check that the appropriate errors are thrown
171+
v1 = [v0[1:end-1];"strange_level"]
172+
v2 = [v0...;rand(rng, v0)] #length(u) !== length(v2)
173+
@test_throws DimensionMismatch broadcast(pdf, u, v2)
174+
@test_throws DomainError broadcast(pdf, u, v1)
175+
169176
end
170177

171178
@testset "broadcasting: check indexing in `getter((cv, i), dtype)` see PR#375" begin

0 commit comments

Comments
 (0)