Skip to content

Commit 9ac9198

Browse files
authored
Merge pull request #19 from JuliaAI/classes
Make classes(v) work for views of categorical arrays
2 parents c5ca9a2 + 37898bc commit 9ac9198

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

src/utilities.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ always true.
2121
2222
Not to be confused with `levels(x.pool)`. See the example below.
2323
24-
Also, overloaded for `x` a `CategoricalArray` or `CategoricalPool`.
24+
Also, overloaded for `x` a `CategoricalArray`, `CategoricalPool`, and
25+
for views of `CategoricalArray`.
2526
2627
**Private method.*
2728
@@ -57,7 +58,7 @@ Also, overloaded for `x` a `CategoricalArray` or `CategoricalPool`.
5758
classes(p::CategoricalPool) = [p[i] for i in 1:length(p)]
5859
classes(x::CategoricalValue) = classes(CategoricalArrays.pool(x))
5960
classes(v::CategoricalArray) = classes(CategoricalArrays.pool(v))
60-
61+
classes(v::SubArray{<:Any, <:Any, <:CategoricalArray}) = classes(parent(v))
6162

6263
# # CATEGORICAL VALUES TO INTEGERS
6364

test/types.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ end
3636
UnivariateFinite(supp, probs, pool=missing, augment=true);
3737

3838
# dimension mismatches:
39-
badprobs = rand(40, 3)
39+
badprobs = rand(rng, 40, 3)
4040
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),
4141
UnivariateFinite(supp, badprobs, pool=missing))
4242

@@ -60,7 +60,7 @@ end
6060
probs = probs ./ sum(probs)
6161
u = UnivariateFinite(probs, pool=missing);
6262
@test u isa UnivariateFinite
63-
probs = rand(10, 2)
63+
probs = rand(rng, 10, 2)
6464
probs = probs ./ sum(probs, dims=2)
6565
u = UnivariateFinite(probs, pool=missing);
6666
@test u.scitype == Multiclass{2}
@@ -73,9 +73,11 @@ end
7373

7474
v = categorical(1:3)
7575
@test_logs((:warn, r"Ignoring"),
76-
UnivariateFinite(v[1:2], rand(3), augment=true, pool=missing));
76+
UnivariateFinite(v[1:2], rand(rng, 3),
77+
augment=true, pool=missing));
7778
@test_logs((:warn, r"Ignoring"),
78-
UnivariateFinite(v[1:2], rand(3), augment=true, ordered=true));
79+
UnivariateFinite(v[1:2], rand(rng, 3),
80+
augment=true, ordered=true));
7981

8082
# using `UnivariateFiniteArray` as a constructor just falls back
8183
# to `UnivariateFinite` constructor:

test/utilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import CategoricalDistributions: classes, transform, decoder, int
1717
levels!(v, reverse(levels(v)))
1818
@test classes(v[1]) == levels(v)
1919
@test classes(v) == levels(v)
20+
vsub = view(v, 1:2)
21+
@test classes(vsub) == classes(v)
2022
end
2123

2224
@testset "int, classes, decoder" begin

0 commit comments

Comments
 (0)