Skip to content

Commit 70aa8f4

Browse files
committed
add tests
1 parent 858871a commit 70aa8f4

File tree

3 files changed

+72
-3
lines changed

3 files changed

+72
-3
lines changed

src/arrays.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ for func in [:pdf, :logpdf]
102102
eval(quote
103103
function Distributions.$func(
104104
u::AbstractArray{UnivariateFinite{S,V,R,P},N},
105-
C::AbstractVector{<:Union{
106-
V,
107-
CategoricalValue{V,R}}}) where {S,V,R,P,N}
105+
C::AbstractVector) where {S,V,R,P,N}
108106

109107
ret = zeros(P, size(u)..., length(C))
110108
# note that we do not require C to use 1-base indexing

test/arrays.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,69 @@ end
290290

291291
end
292292

293+
function (x::T, y::T) where {T<:UnivariateFinite}
294+
return x.decoder == y.decoder &&
295+
x.prob_given_ref == y.prob_given_ref &&
296+
x.scitype == y.scitype
297+
end
298+
299+
@testset "CartesianIndex" begin
300+
v = categorical(["a", "b"], ordered=true)
301+
m = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
302+
@test m[1, 1] m[CartesianIndex(1, 1)] m[CartesianIndex(1, 1, 1)]
303+
@test_throws BoundsError m[CartesianIndex(1)]
304+
@test all(zip(Matrix(m), copy(m), m)) do (x, y, z)
305+
return x y z
306+
end
307+
@test Matrix(m) isa Matrix
308+
# TODO: probably it would be better for copy to keep it
309+
# UnivariateFiniteArray but it would be breaking
310+
@test copy(m) isa Matrix
311+
@test similar(m) isa Matrix
312+
end
313+
314+
@testset "broadcasted pdf" begin
315+
v = categorical(["a", "b"], ordered=true)
316+
v2 = categorical(["a", "b"], ordered=true, levels=["b", "a"])
317+
x = UnivariateFinite(v, rand(rng, 5), augment=true)
318+
@test pdf.(x, v[1]) == pdf.(x, v2[1]) == pdf.(x, "a")
319+
@test pdf.(x, v[2]) == pdf.(x, v2[2]) == pdf.(x, "b")
320+
321+
x = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
322+
@test size(pdf.(x, missing)) == (5, 2)
323+
324+
v3 = categorical(["a" "b"], ordered=true)
325+
v4 = categorical(["a" "b"], ordered=true, levels=["b", "a"])
326+
# note that v5 and v6 have the same shape and contents as v3 and v4
327+
# just they are Matrix{Any} not CategoricalMatrix
328+
v5 = Any[v3[1] v3[2]]
329+
v6 = Any[v4[1] v4[2]]
330+
x = UnivariateFinite(v, hcat([0.1, 0.2]), augment=true)
331+
332+
# these tests show that now we have corrected refpools
333+
# but still there is an inconsistency in behavior
334+
@test pdf.(x, v) == hcat([0.9, 0.2])
335+
@test pdf.(x, v2) == hcat([0.9, 0.2])
336+
@test pdf.(x, v3) == hcat([0.9, 0.2])
337+
@test pdf.(x, v4) == hcat([0.9, 0.2])
338+
@test pdf.(x, v5) == [0.9 0.1; 0.8 0.2]
339+
@test pdf.(x, v6) == [0.9 0.1; 0.8 0.2]
340+
end
341+
342+
@testset "pdf with various types" begin
343+
v = categorical(["a", "b"], ordered=true)
344+
a = view("a", 1:1) # quite common case when splitting strings
345+
b = view("b", 1:1)
346+
x = UnivariateFinite(v, [0.1, 0.2, 0.3], augment=true)
347+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1])
348+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1])
349+
@test pdf(x, [a, b]) == pdf(x, ["a", "b"]) == pdf(x, v)
350+
@test logpdf(x, [a, b]) == logpdf(x, ["a", "b"]) == logpdf(x, v)
351+
352+
x = UnivariateFinite(v, 0.1, augment=true)
353+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1]) == 0.9
354+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1]) == log(0.9)
355+
@test pdf(x, a) == pdf(x, "a") == pdf(x, v[1]) == 0.9
356+
@test logpdf(x, a) == logpdf(x, "a") == logpdf(x, v[1]) == log(0.9)
357+
end
293358
true

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ rng = StableRNGs.StableRNG(123)
99

1010
import CategoricalDistributions: classes, decoder, int
1111

12+
ambiguities_vec = Test.detect_ambiguities(CategoricalDistributions,
13+
recursive=true)
14+
if !isempty(ambiguities_vec)
15+
@warn "$(length(ambiguities_vec)) method ambiguities detected"
16+
end
17+
1218
@testset "utilities" begin
1319
@test include("utilities.jl")
1420
end

0 commit comments

Comments
 (0)