Skip to content

Commit 95af562

Browse files
committed
adapt to changes in levels behaviour in CategoricalDistributions
1 parent 1c0b061 commit 95af562

File tree

3 files changed

+25
-20
lines changed

3 files changed

+25
-20
lines changed

src/models/discriminant_analysis.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ const ERR_LONE_TARGET_CLASS = ArgumentError(
2020
)
2121

2222
function _check_lda_data(model, X, y)
23-
pool = MMI.classes(y[1]) # Class list containing entries in pool of `y`.
23+
pool = CategoricalDistributions.levels(y[1]) # Class list containing entries in pool
24+
# of `y`.
2425
classes_seen = unique(y) # Class list of entries actually seen in `y`.
2526
nc = length(classes_seen) # Number of actual classes seen in `y`.
2627

@@ -109,7 +110,7 @@ function MMI.predict(m::LDA, (core_res, classes_seen, pool), Xnew)
109110
Pr .*= -1
110111
# apply a softmax transformation
111112
softmax!(Pr)
112-
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
113+
return MMI.UnivariateFinite(classes_seen, Pr)
113114
end
114115

115116
metadata_model(
@@ -160,7 +161,7 @@ function _check_prob01(priors)
160161
end
161162

162163
@inline function _check_lda_priors(priors::UnivariateFinite, classes_seen, pool)
163-
if MMI.classes(priors) != pool
164+
if CategoricalDistributions.levels(priors) != pool
164165
throw(
165166
ArgumentError(
166167
"UnivariateFinite `priors` must have common pool with training target."
@@ -236,7 +237,7 @@ function MMI.fitted_params(::BayesianLDA, (core_res, classes_seen, pool, priors
236237
return (
237238
classes = classes_seen,
238239
projection_matrix=MS.projection(core_res),
239-
priors=MMI.UnivariateFinite(classes_seen, priors, pool=pool)
240+
priors=MMI.UnivariateFinite(classes_seen, priors)
240241
)
241242
end
242243

@@ -261,7 +262,7 @@ function MMI.predict(m::BayesianLDA, (core_res, classes_seen, pool, priors, n),
261262

262263
# apply a softmax transformation to convert Pr to a probability matrix
263264
softmax!(Pr)
264-
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
265+
return MMI.UnivariateFinite(classes_seen, Pr)
265266
end
266267

267268
function MMI.transform(m::T, (core_res, ), X) where T<:Union{LDA, BayesianLDA}
@@ -353,7 +354,7 @@ function MMI.predict(m::SubspaceLDA, (core_res, outdim, classes_seen, pool), Xne
353354
Pr .*= -1
354355
# apply a softmax transformation
355356
softmax!(Pr)
356-
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
357+
return MMI.UnivariateFinite(classes_seen, Pr)
357358
end
358359

359360
metadata_model(
@@ -430,7 +431,7 @@ function MMI.fitted_params(
430431
return (
431432
classes = classes_seen,
432433
projection_matrix=core_res.projw * view(core_res.projLDA, :, 1:outdim),
433-
priors=MMI.UnivariateFinite(classes_seen, priors, pool=pool)
434+
priors=MMI.UnivariateFinite(classes_seen, priors)
434435
)
435436
end
436437

@@ -470,7 +471,7 @@ function MMI.predict(
470471

471472
# apply a softmax transformation to convert Pr to a probability matrix
472473
softmax!(Pr)
473-
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
474+
return MMI.UnivariateFinite(classes_seen, Pr)
474475
end
475476

476477
function MMI.transform(
@@ -724,7 +725,8 @@ The fields of `fitted_params(mach)` are:
724725
section below).
725726
726727
- `priors`: The class priors for classification. As inferred from training target `y`, if
727-
not user-specified. A `UnivariateFinite` object with levels consistent with `levels(y)`.
728+
not user-specified. A `UnivariateFinite` object with levels (classes) consistent with
729+
`levels(y)`.
728730
729731
# Report
730732
@@ -954,7 +956,8 @@ The fields of `fitted_params(mach)` are:
954956
section below).
955957
956958
- `priors`: The class priors for classification. As inferred from training target `y`, if
957-
not user-specified. A `UnivariateFinite` object with levels consistent with `levels(y)`.
959+
not user-specified. A `UnivariateFinite` object with levels (classes) consistent with
960+
`levels(y)`.
958961
959962
# Report
960963

test/models/discriminant_analysis.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ end
7373
ytest = selectrows(y, test)
7474

7575
BLDA_model = BayesianLDA(regcoef=0)
76-
76+
7777
## Check model `fit`
7878
fitresult, cache, report = fit(BLDA_model, 1, Xtrain, ytrain)
7979
classes_seen, projection_matrix, priors = fitted_params(BLDA_model, fitresult)
80-
@test classes(priors) == classes(y)
80+
@test levels(priors) == levels(y)
8181
@test pdf.(priors, support(priors)) == [491/998, 507/998]
8282
@test classes_seen == ["Up", "Down"]
8383
@test round.((report.class_means)', sigdigits = 3) == [-0.0395 -0.0313; 0.0428 0.0339] #[0.0428 0.0339; -0.0395 -0.0313]
84-
84+
8585
## Check model `predict`
8686
preds = predict(BLDA_model, fitresult, Xtest)
8787
mce = cross_entropy(preds, ytest) |> mean
@@ -94,7 +94,7 @@ end
9494
fitresult1, cache1, report1 = fit(BLDA_model1, 1, Xtrain, ytrain)
9595
classes_seen1, projection_matrix1, priors1 = fitted_params(BLDA_model1, fitresult1)
9696
BLDA_model2 = BayesianLDA(
97-
regcoef=0, priors=UnivariateFinite(classes(ytrain), [491/998, 507/998])
97+
regcoef=0, priors=UnivariateFinite(levels(ytrain), [491/998, 507/998])
9898
)
9999
fitresult2, cache2, report2 = fit(BLDA_model2, 1, Xtrain, ytrain)
100100
classes_seen2, projection_matrix2, priors2 = fitted_params(BLDA_model2, fitresult2)
@@ -156,7 +156,7 @@ end
156156
LDA_model1, fitresult1
157157
)
158158
LDA_model2 = BayesianSubspaceLDA(
159-
priors=UnivariateFinite(classes(y), [1/3, 1/3, 1/3])
159+
priors=UnivariateFinite(levels(y), [1/3, 1/3, 1/3])
160160
)
161161
fitresult2, cache2, report2 = fit(LDA_model2, 1, X, y)
162162
classes_seen2, projection_matrix2, priors2 = fitted_params(
@@ -231,24 +231,24 @@ end
231231
y2 = y[[1,2,1,2]]
232232
@test_throws ArgumentError fit(model, 1, X, y2)
233233

234-
## Check to make sure error is thrown if UnivariateFinite `priors` doesn't have
234+
## Check to make sure error is thrown if UnivariateFinite `priors` doesn't have
235235
## common pool with target vector used in training.
236236
model = BayesianLDA(priors=UnivariateFinite([0.1, 0.5, 0.4], pool=missing))
237237
@test_throws ArgumentError fit(model, 1, X, y)
238238

239-
## Check to make sure error is thrown if keys used in `priors` dictionary are in pool
239+
## Check to make sure error is thrown if keys used in `priors` dictionary aren't in pool
240240
## of training target used in training.
241241
model = BayesianLDA(priors=Dict("apples" => 0.1, "oranges"=>0.5, "bannana"=>0.4))
242242
@test_throws ArgumentError fit(model, 1, X, y)
243243

244244
## Check to make sure error is thrown if sum(`priors`) isn't approximately equal to 1.
245-
model = BayesianLDA(priors=UnivariateFinite(classes(y), [0.1, 0.5, 0.4, 0.2]))
245+
model = BayesianLDA(priors=UnivariateFinite(levels(y), [0.1, 0.5, 0.4, 0.2]))
246246
@test_throws ArgumentError fit(model, 1, X, y)
247247

248248
## Check to make sure error is thrown if `priors .< 0` or `priors .> 1`.
249-
model = BayesianLDA(priors=Dict(classes(y) .=> [-0.1, 0.0, 1.0, 0.1]))
249+
model = BayesianLDA(priors=Dict(levels(y) .=> [-0.1, 0.0, 1.0, 0.1]))
250250
@test_throws ArgumentError fit(model, 1, X, y)
251-
model = BayesianLDA(priors=Dict(classes(y) .=> [1.1, 0.0, 0.0, -0.1]))
251+
model = BayesianLDA(priors=Dict(levels(y) .=> [1.1, 0.0, 0.0, -0.1]))
252252
@test_throws ArgumentError fit(model, 1, X, y)
253253

254254
X2 = (x=rand(5),)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ import Dates
22
import MLJMultivariateStatsInterface: _replace!
33
import MultivariateStats
44
import Random
5+
import CategoricalDistributions.levels
56

67
using LinearAlgebra
78
using MLJBase
9+
using StatisticalMeasures
810
using MLJMultivariateStatsInterface
911
using StableRNGs
1012
using Test

0 commit comments

Comments
 (0)