Skip to content

Commit 5f2660c

Browse files
authored
Merge pull request #36 from JuliaAI/dev
For a 0.1.9 release
2 parents 8084912 + 3265ec6 commit 5f2660c

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/types.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,11 @@ function _UnivariateFinite(support,
451451
issubset(support, _classes) ||
452452
error("Specified support, $support, not contained in "*
453453
"specified pool, $(levels(classes)). ")
454-
_support = filter(_classes) do c
455-
c in support
456-
end
454+
idxs = getindex.(
455+
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
456+
support
457+
)
458+
_support = _classes[idxs]
457459
end
458460

459461
# calls core method:

test/types.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using StableRNGs
77
using FillArrays
88
using ScientificTypes
99
import Random
10+
import CategoricalDistributions: classes
1011

1112
# coverage of constructor testing is expanded in the other test files
1213

@@ -36,6 +37,20 @@ end
3637

3738
UnivariateFinite(supp, probs, pool=missing, augment=true);
3839

40+
# construction from pool and support does not
41+
# consist of categorical elements (See issue #34)
42+
v = categorical(["x", "x", "y", "z", "y", "z", "p"])
43+
probs1 = [0.1, 0.2, 0.7]
44+
probs2 = [0.1 0.2 0.7; 0.5 0.2 0.3; 0.8 0.1 0.1]
45+
unf1 = UnivariateFinite(["y", "x", "z"], probs1, pool=v)
46+
unf2 = UnivariateFinite(["y", "x", "z"], probs2, pool=v)
47+
@test CategoricalArrays.pool(classes(unf1)) == CategoricalArrays.pool(v)
48+
@test CategoricalArrays.pool(classes(unf2)) == CategoricalArrays.pool(v)
49+
@test pdf.(unf1, ["y", "x", "z"]) == probs1
50+
@test pdf.(unf2, "y") == probs2[:, 1]
51+
@test pdf.(unf2, "x") == probs2[:, 2]
52+
@test pdf.(unf2, "z") == probs2[:, 3]
53+
3954
# dimension mismatches:
4055
badprobs = rand(rng, 40, 3)
4156
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),

0 commit comments

Comments
 (0)