Skip to content

Commit 87a8da1

Browse files
committed
fix code and add more tests
1 parent f576e14 commit 87a8da1

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

src/arrays.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,18 +196,18 @@ function Base.Broadcast.broadcasted(
196196
_classes = classes(u)
197197
_classes_pool = CategoricalArrays.pool(_classes)
198198
T = eltype(v) >: Missing ? Missing : Union{}
199-
v_loc_flat = Vector{Union{Int, T}}(undef, length(v))
199+
v_loc_flat = Vector{Tuple{Union{R, T}, Int}}(undef, length(v))
200200

201201

202202
for (i, x) in enumerate(v)
203-
v_loc_flat_i = ismissing(x) ? missing : get(_classes_pool, x, zero(R))
204-
isequal(v_loc_flat_i, 0) && throw(err_missing_class(x))
205-
v_loc_flat[i] = v_loc_flat_i
203+
cv_ref = ismissing(x) ? missing : get(_classes_pool, x, zero(R))
204+
isequal(cv_ref, 0) && throw(err_missing_class(x))
205+
v_loc_flat[i] = (cv_ref, i)
206206
end
207207

208-
getter(cv_loc) =
209-
_getindex(get(u.prob_given_ref, _classes[cv_loc], zero(P)), cv_loc)
210-
getter(::Missing) = missing
208+
getter((cv_ref, i)) =
209+
_getindex(get(u.prob_given_ref, cv_ref, zero(P)), i)
210+
getter(::Tuple{Missing,Any}) = missing
211211
ret_flat = getter.(v_loc_flat)
212212
return reshape(ret_flat, size(u))
213213
end

test/arrays.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ end
149149
u2 = UnivariateFinite(v[1:2], probs, augment=true)
150150
@test pdf.(u2, v[3]) == zeros(3)
151151
@test isequal(logpdf.(u2, v[3]), log.(zeros(3)))
152+
153+
## Check that the appropriate errors are thrown
154+
@test_throws DomainError pdf.(u,"strange_level")
152155
end
153156

154157
_skip(v) = collect(skipmissing(v))
@@ -168,23 +171,23 @@ _skip(v) = collect(skipmissing(v))
168171
end
169172

170173
## Check that the appropriate errors are thrown
171-
v1 = [v0[1:end-1];"strange_level"]
174+
v1 = categorical([v0[1:end-1];"strange_level"])
172175
v2 = [v0...;rand(rng, v0)] #length(u) !== length(v2)
173176
@test_throws DimensionMismatch broadcast(pdf, u, v2)
174177
@test_throws DomainError broadcast(pdf, u, v1)
175178

176179
end
177180

178-
@testset "broadcasting: check indexing in `getter((cv, i), dtype)` see PR#375" begin
181+
@testset "broadcasting: check indexing in `getter((cv_ref, i))` see PR#375 from MLJBase" begin
179182
c = categorical([0,1,1])
180183
d = UnivariateFinite(c[1:1], [1 1 1]')
181184
v = categorical([0,1,1,1])
182185
@test broadcast(pdf, d, v[2:end]) == [0,0,0]
183186
end
184187

185188
@testset "_getindex" begin
186-
@test CategoricalDistributions._getindex(collect(1:4), 2, Int64) == 2
187-
@test CategoricalDistributions._getindex(nothing, 2, Int64) == zero(Int64)
189+
@test CategoricalDistributions._getindex(collect(1:4), 2) == 2
190+
@test CategoricalDistributions._getindex(0, 2) === 0
188191
end
189192

190193
@testset "broadcasting mode" begin

0 commit comments

Comments
 (0)