Skip to content

Commit 3a65ffd

Browse files
authored
Merge pull request #35 from JuliaAI/raw_support
fix issue #34
2 parents 403ef67 + b5d26d5 commit 3a65ffd

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/types.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,11 @@ function _UnivariateFinite(support,
451451
issubset(support, _classes) ||
452452
error("Specified support, $support, not contained in "*
453453
"specified pool, $(levels(classes)). ")
454-
_support = filter(_classes) do c
455-
c in support
456-
end
454+
idxs = getindex.(
455+
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
456+
support
457+
)
458+
_support = _classes[idxs]
457459
end
458460

459461
# calls core method:

test/types.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using StableRNGs
77
using FillArrays
88
using ScientificTypes
99
import Random
10+
import CategoricalDistributions: classes
1011

1112
# coverage of constructor testing is expanded in the other test files
1213

@@ -36,6 +37,20 @@ end
3637

3738
UnivariateFinite(supp, probs, pool=missing, augment=true);
3839

40+
# construction from pool and support does not
41+
# consist of categorical elements (See issue #34)
42+
v = categorical(["x", "x", "y", "z", "y", "z", "p"])
43+
probs1 = [0.1, 0.2, 0.7]
44+
probs2 = [0.1 0.2 0.7; 0.5 0.2 0.3; 0.8 0.1 0.1]
45+
unf1 = UnivariateFinite(["y", "x", "z"], probs1, pool=v)
46+
unf2 = UnivariateFinite(["y", "x", "z"], probs2, pool=v)
47+
@test CategoricalArrays.pool(classes(unf1)) == CategoricalArrays.pool(v)
48+
@test CategoricalArrays.pool(classes(unf2)) == CategoricalArrays.pool(v)
49+
@test pdf.(unf1, ["y", "x", "z"]) == probs1
50+
@test pdf.(unf2, "y") == probs2[:, 1]
51+
@test pdf.(unf2, "x") == probs2[:, 2]
52+
@test pdf.(unf2, "z") == probs2[:, 3]
53+
3954
# dimension mismatches:
4055
badprobs = rand(rng, 40, 3)
4156
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),

0 commit comments

Comments
 (0)