Skip to content

Commit 8084912

Browse files
authored
Merge pull request #32 from JuliaAI/dev
For a 0.1.8 release
2 parents 8592544 + 403ef67 commit 8084912

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -25,6 +25,7 @@ julia = "1.3"
2525
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2626
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2727
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
28+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2829

2930
[targets]
30-
test = ["Random", "StableRNGs", "Test"]
31+
test = ["Random", "StableRNGs", "Test", "FillArrays"]

src/methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Ordered list of classes associated with non-zero probabilities.
4949
5050
v = categorical(["yes", "maybe", "no", "yes"])
5151
d = UnivariateFinite(v[1:2], [0.3, 0.7])
52-
support(d) # CategoricalArray{String,1,UInt32}["maybe", "no"]
52+
support(d) # CategoricalArray{String,1,UInt32}["maybe", "yes"]
5353
5454
"""
5555
Dist.support(d::UnivariateFiniteUnion) =

src/types.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ _check_augmentable(support, probs) = _check_probs_01(probs) &&
242242

243243
## AUGMENTING ARRAYS TO MAKE THEM PROBABILITY ARRAYS
244244

245-
_unwrap(A::Array) = A
246-
_unwrap(A::Vector) = first(A)
245+
_unwrap(A::AbstractArray) = A
246+
_unwrap(A::AbstractVector) = first(A)
247247

248248
isbinary(support) = length(support) == 2
249249

@@ -276,6 +276,9 @@ function _augment_probs(::Val{true},
276276
return aug_probs
277277
end
278278

279+
_array_or_scalar(x::Array) = x
280+
_array_or_scalar(x::AbstractArray) = copyto!(similar(Array{eltype(x)}, axes(x)), x)
281+
_array_or_scalar(x) = x
279282

280283
## CONSTRUCTORS - FROM DICTIONARY
281284

@@ -306,7 +309,7 @@ function UnivariateFinite(
306309
issubset(_support, parent_classes) ||
307310
error("Categorical elements are not from the same pool. ")
308311

309-
pairs = [int(c) => prob_given_class[c]
312+
pairs = [int(c) => _array_or_scalar(prob_given_class[c])
310313
for c in _support]
311314

312315
probs1 = first(values(prob_given_class))
@@ -341,7 +344,7 @@ function UnivariateFinite(d::AbstractDict{V,<:Prob};
341344
_classes = classes(pool)
342345
issubset(raw_support, _classes) ||
343346
error("Specified support, $raw_support, not contained in "*
344-
"specified pool, $(levels(classes)). ")
347+
"specified pool, $(levels(_classes)). ")
345348
support = filter(_classes) do c
346349
c in raw_support
347350
end

test/types.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Test
44
using CategoricalDistributions
55
using CategoricalArrays
66
using StableRNGs
7+
using FillArrays
78
using ScientificTypes
89
import Random
910

@@ -55,6 +56,27 @@ end
5556
augment=true,
5657
pool=missing))
5758

59+
# Test construction from non `Array` `AbstractArray`
60+
v = categorical(['x', 'x', 'y', 'x', 'z', 'w'])
61+
probs_fillarray = FillArrays.Ones(100, 3)
62+
probs_array = ones(100, 3)
63+
64+
probs1_fillarray = FillArrays.Fill(0.2, 100, 2)
65+
probs1_array = fill(0.2, 100, 2)
66+
67+
u_from_array = UnivariateFinite(['x', 'y', 'z'], probs_array, pool=v)
68+
u_from_fillarray = UnivariateFinite(['x', 'y', 'z'], probs_fillarray, pool=v)
69+
70+
u1_from_array = UnivariateFinite(
71+
['x', 'y', 'z'], probs1_array, pool=v, augment=true
72+
)
73+
u1_from_fillarray = UnivariateFinite(
74+
['x', 'y', 'z'], probs1_fillarray, pool=v, augment=true
75+
)
76+
77+
@test u_from_array.prob_given_ref == u_from_fillarray.prob_given_ref
78+
@test u1_from_array.prob_given_ref == u1_from_fillarray.prob_given_ref
79+
5880
# autosupport:
5981
u = UnivariateFinite(probs, pool=missing, augment=true);
6082
probs = probs ./ sum(probs)

0 commit comments

Comments
 (0)