Skip to content

Commit 8523ae0

Browse files
authored
Merge pull request #11 from OkonSamuel/bugfix
fixed bug, improve code efficiency and add more tests
2 parents 0b996da + 826a4c5 commit 8523ae0

File tree

3 files changed

+127
-29
lines changed

3 files changed

+127
-29
lines changed

src/models/discriminant_analysis.jl

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,24 @@ function MMI.fit(model::LDA, ::Int, X, y)
6767
return fitresult, cache, report
6868
end
6969

70+
"""
    _replace!(y::AbstractVector, z::AbstractVector, r::AbstractRange)

Internal method. Replace, in place, every element of `y` that is `isequal` to `z[j]` by the
corresponding `r[j]`, and return `y`. Essentially the same as

    Base.replace!(y, (z .=> r)...)

but without splatting a collection of pairs.

Each element of `y` is replaced at most once, using the first matching entry of `z`;
a freshly written value is never matched again against later entries of `z`.

Throws an `ArgumentError` if `z` and `r` do not have the same length.
"""
function _replace!(y::AbstractVector, z::AbstractVector, r::AbstractRange)
    length(r) == length(z) ||
        throw(ArgumentError("`z` and `r` has to be of the same length"))
    for i in eachindex(y)
        yi = y[i]
        # Walk `z` and `r` in lockstep so no assumption is made about their index sets.
        for (zj, rj) in zip(z, r)
            if isequal(zj, yi)
                y[i] = rj
                # Stop at the first match; otherwise the new value could be replaced again.
                break
            end
        end
    end
    return y
end
87+
7088
function _check_lda_data(model, X, y)
7189
class_list = MMI.classes(y[1]) # Class list containing entries in pool of y.
7290
nclasses = length(class_list)
@@ -79,7 +97,7 @@ function _check_lda_data(model, X, y)
7997
yplain = MMI.int(y) # Vector of n ints in {1,..., nclasses}.
8098
p, n = size(Xm_t)
8199
# Recode yplain to be in {1,..., nc}
82-
nc == nclasses || replace!(yplain, (integers_seen .=> 1:nc)...)
100+
nc == nclasses || _replace!(yplain, integers_seen, 1:nc)
83101
# Check to make sure we have more than one class in training sample.
84102
# This is to prevent Sb from being a zero matrix.
85103
if nc <= 1
@@ -90,7 +108,6 @@ function _check_lda_data(model, X, y)
90108
)
91109
)
92110
end
93-
94111
# Check to make sure we have more samples than classes.
95112
# This is to prevent Sw from being the zero matrix.
96113
if n <= nc
@@ -134,7 +151,7 @@ function MMI.predict(m::LDA, (core_res, classes_seen), Xnew)
134151
# compute the distances in the transformed space between pairs of rows
135152
# the probability matrix Pr is `n x nc` and normalised across rows
136153
Pr = pairwise(m.dist, XWt, centroids, dims=1)
137-
Pr .= Pr .* -1
154+
Pr .*= -1
138155
# apply a softmax transformation
139156
softmax!(Pr)
140157
return MMI.UnivariateFinite(classes_seen, Pr)
@@ -239,26 +256,33 @@ function _matrix_transpose(model::Union{LDA, BayesianLDA}, X)
239256
return MMI.matrix(X; transpose=true)
240257
end
241258

242-
function _check_lda_priors(priors, nc, nclasses, integers_seen)
259+
"""
    _check_lda_priors(priors, nc, nclasses, integers_seen)

Internal method. Validate the class prior probabilities `priors` and return the priors
to be used for the classes seen in training.

# Arguments
- `priors`: prior probabilities; must have exactly `nclasses` entries.
- `nc`: number of classes seen in the training target (compared against `nclasses`).
- `nclasses`: expected length of `priors`.
- `integers_seen`: indices used to select entries of `priors` when `nc != nclasses`.

# Returns
`priors` itself when `nc == nclasses`, otherwise the view `priors[integers_seen]`.

# Throws
`ArgumentError` if `length(priors) != nclasses`, if `sum(priors)` is not approximately
`1`, or if any entry of `priors` is negative.
"""
@inline function _check_lda_priors(priors, nc, nclasses, integers_seen)
    if length(priors) != nclasses
        throw(ArgumentError("Invalid size of `priors`."))
    end

    # `priors` is essentially always an instance of type `Vector{Float64}`.
    # The next two conditions implicitly check that
    # `0 .<= priors .<= 1` and `sum(priors) ≈ 1` are true.
    if !isapprox(sum(priors), 1)
        throw(ArgumentError("probabilities specified in `priors` must sum to 1"))
    end
    # Reject the priors as soon as a single entry is negative.
    if !all(>=(0), priors)
        throw(ArgumentError("probabilities specified in `priors` must be non-negative"))
    end
    # Select priors for unique classes in `y` (for resampling purposes).
    # Only the `priors` argument is used: no `model` is in scope in this function.
    priors_ = nc == nclasses ? priors : @view priors[integers_seen]
    return priors_
end
256277

278+
# A `SubArray` of priors is copied into a fresh array.
function _get_priors(priors::SubArray)
    return copy(priors)
end

# Any other kind of priors is handed back unchanged.
function _get_priors(priors)
    return priors
end
280+
257281
# Collect the fitted parameters of a `BayesianLDA` fit result into a named tuple:
# the class means and projection reported by `MS` for `core_res`, plus the priors
# passed through `_get_priors`.
function MMI.fitted_params(::BayesianLDA, (core_res, classes_seen, priors, n))
    class_means = MS.classmeans(core_res)
    projection = MS.projection(core_res)
    reported_priors = _get_priors(priors)
    return (
        projected_class_means=class_means,
        projection_matrix=projection,
        priors=reported_priors
    )
end
264288

@@ -278,17 +302,17 @@ function MMI.predict(m::BayesianLDA, (core_res, classes_seen, priors, n), Xnew)
278302
# with (Pᵀxᵢ − Pᵀµₖ)ᵀ(Pᵀxᵢ − Pᵀµₖ) being the squared Euclidean distance between
279303
# pairs of rows in the transformed space
280304
Pr = pairwise(SqEuclidean(), XWt, centroids, dims=1)
281-
Pr .*= (-0.5*n)
282-
Pr .+= log.(priors)'
305+
Pr .*= (-n/2)
306+
Pr .+= log.(transpose(priors))
283307

284308
# apply a softmax transformation to convert Pr to a probability matrix
285309
softmax!(Pr)
286310
return MMI.UnivariateFinite(classes_seen, Pr)
287311
end
288312

289-
function MMI.transform(m::T, (core_res,), X) where T<:Union{LDA, BayesianLDA}
313+
# Project the table `X` with the projection matrix stored in the fit result and
# return the result as a table built with `X` as prototype.
function MMI.transform(m::T, (core_res, ), X) where T<:Union{LDA, BayesianLDA}
    # `core_res.proj` is the projection matrix
    projection = core_res.proj
    # projected data: one row per row of `X`, one column per output dimension
    projected = MMI.matrix(X) * projection
    return MMI.table(projected; prototype=X)
end
@@ -374,7 +398,7 @@ function MMI.predict(m::SubspaceLDA, (core_res, out_dim, classes_seen), Xnew)
374398
# compute the distances in the transformed space between pairs of rows
375399
# the probability matrix is `nt x nc` and normalised across rows
376400
Pr = pairwise(m.dist, XWt, centroids, dims=1)
377-
Pr .= Pr .* -1
401+
Pr .*= -1
378402
# apply a softmax transformation
379403
softmax!(Pr)
380404
return MMI.UnivariateFinite(classes_seen, Pr)
@@ -461,13 +485,13 @@ end
461485

462486
# For the subspace models, convert `X` with `MMI.matrix` and wrap the result in
# `transpose`.
function _matrix_transpose(model::Union{SubspaceLDA, BayesianSubspaceLDA}, X)
    Xmat = MMI.matrix(X)
    return transpose(Xmat)
end
489+
466490
# Collect the fitted parameters of a `BayesianSubspaceLDA` fit result into a named
# tuple: the class means and projection reported by `MS` for `core_res`, plus the
# priors passed through `_get_priors`.
function MMI.fitted_params(::BayesianSubspaceLDA, (core_res, _, _, priors,_))
    class_means = MS.classmeans(core_res)
    projection = MS.projection(core_res)
    reported_priors = _get_priors(priors)
    return (
        projected_class_means=class_means,
        projection_matrix=projection,
        priors=reported_priors
    )
end
473497

@@ -496,8 +520,8 @@ function MMI.predict(
496520
# (Pᵀxᵢ − Pᵀµₖ)ᵀ(Pᵀxᵢ − Pᵀµₖ) is the squared Euclidean distance in the
497521
# transformed space
498522
Pr = pairwise(SqEuclidean(), XWt, centroids, dims=1)
499-
Pr .*= (-0.5 * (n-nc)/mult)
500-
Pr .+= log.(priors)'
523+
Pr .*= (-(n-nc)/2mult)
524+
Pr .+= log.(transpose(priors))
501525

502526
# apply a softmax transformation to convert Pr to a probability matrix
503527
softmax!(Pr)

test/models/discriminant_analysis.jl

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@testset "MulticlassLDA" begin
2+
## Data
23
Xfull, y = @load_smarket
34
X = selectcols(Xfull, [:Lag1,:Lag2])
45
train = selectcols(Xfull, :Year) .< Dates.Date(2005)
@@ -8,20 +9,33 @@
89
Xtest = selectrows(X, test)
910
ytest = selectrows(y, test)
1011

11-
LDA_model = LDA()
12-
fitresult, = fit(LDA_model, 1, Xtrain, ytrain)
13-
class_means, projection_matrix = fitted_params(LDA_model, fitresult)
14-
preds = predict(LDA_model, fitresult, Xtest)
12+
lda_model = LDA()
13+
14+
## Check model `fit`
15+
fitresult, = fit(lda_model, 1, Xtrain, ytrain)
16+
class_means, projection_matrix = fitted_params(lda_model, fitresult)
17+
@test round.(class_means', sigdigits = 3) == [0.0428 0.0339; -0.0395 -0.0313]
18+
## Check model `predict`
19+
preds = predict(lda_model, fitresult, Xtest)
1520
mce = cross_entropy(preds, ytest) |> mean
1621
@test 0.685 ≤ mce ≤ 0.695
17-
@test round.(class_means', sigdigits = 3) == [0.0428 0.0339; -0.0395 -0.0313]
22+
## Check model `transform`
23+
# MultivariateStats Linear Discriminant Analysis transform
24+
proj = fitresult[1].proj
25+
XWt = matrix(X) * proj
26+
tlda_ms = table(XWt, prototype=X)
27+
# MLJ Linear Discriminant Analysis transform
28+
tlda_mlj = transform(lda_model, fitresult, X)
29+
@test tlda_mlj == tlda_ms
30+
## Check model traits
1831
d = info_dict(LDA)
1932
@test d[:input_scitype] == Table(Continuous)
2033
@test d[:target_scitype] == AbstractVector{<:Finite}
2134
@test d[:name] == "LDA"
2235
end
2336
2437
@testset "MLDA-2" begin
38+
## Data
2539
Random.seed!(1125)
2640
X1 = -2 .+ randn(100, 2)
2741
X2 = randn(100, 2)
@@ -41,14 +55,17 @@ end
4155
ytrain = selectrows(y, train)
4256
Xtest = selectrows(X, test)
4357
ytest = selectrows(y, test)
58+
4459
lda_model = LDA()
60+
## Check model `fit`/`predict`
4561
fitresult, = fit(lda_model, 1, Xtrain, ytrain)
4662
preds = predict_mode(lda_model, fitresult, Xtest)
4763
mcr = misclassification_rate(preds, ytest)
4864
@test mcr ≤ 0.15
4965
end
5066
5167
@testset "BayesianMulticlassLDA" begin
68+
## Data
5269
Xfull, y = @load_smarket
5370
X = selectcols(Xfull, [:Lag1,:Lag2])
5471
train = selectcols(Xfull, :Year) .< Dates.Date(2005)
@@ -57,30 +74,32 @@ end
5774
ytrain = selectrows(y, train)
5875
Xtest = selectrows(X, test)
5976
ytest = selectrows(y, test)
77+
6078
BLDA_model = BayesianLDA()
79+
## Check model `fit`
6180
fitresult, = fit(BLDA_model, 1, Xtrain, ytrain)
6281
class_means, projection_matrix, priors = fitted_params(BLDA_model, fitresult)
82+
@test round.(class_means', sigdigits = 3) == [0.0428 0.0339; -0.0395 -0.0313]
83+
## Check model `predict`
6384
preds = predict(BLDA_model, fitresult, Xtest)
6485
mce = cross_entropy(preds, ytest) |> mean
6586
@test 0.685 ≤ mce ≤ 0.695
66-
@test round.(class_means', sigdigits = 3) == [0.0428 0.0339; -0.0395 -0.0313]
87+
## Check model traits
6788
d = info_dict(BayesianLDA)
6889
@test d[:input_scitype] == Table(Continuous)
6990
@test d[:target_scitype] == AbstractVector{<:Finite}
7091
@test d[:name] == "BayesianLDA"
7192
end
7293

7394
@testset "BayesianSubspaceLDA" begin
95+
## Data
7496
X, y = @load_iris
7597
LDA_model = BayesianSubspaceLDA()
98+
## Check model `fit`
7699
fitresult, _, report = fit(LDA_model, 1, X, y)
77100
class_means, projection_matrix, prior_probabilities = fitted_params(
78101
LDA_model, fitresult
79102
)
80-
preds=predict(LDA_model, fitresult, X)
81-
predicted_class = predict_mode(LDA_model, fitresult, X)
82-
mcr = misclassification_rate(predicted_class, y)
83-
mce = cross_entropy(preds, y) |> mean
84103
@test mean(
85104
abs.(
86105
class_means' - [
@@ -101,16 +120,24 @@ end
101120
)
102121
) < 0.05
103122
@test round.(prior_probabilities, sigdigits=7) == [0.3333333, 0.3333333, 0.3333333]
104-
@test round.(mcr, sigdigits=1) == 0.02
105123
@test round.(report.explained_variance_ratio, digits=4) == [0.9915, 0.0085]
124+
125+
## Check model `predict`
126+
preds=predict(LDA_model, fitresult, X)
127+
predicted_class = predict_mode(LDA_model, fitresult, X)
128+
mcr = misclassification_rate(predicted_class, y)
129+
mce = cross_entropy(preds, y) |> mean
130+
@test round.(mcr, sigdigits=1) == 0.02
106131
@test 0.04 ≤ mce ≤ 0.045
132+
## Check model traits
107133
d = info_dict(BayesianSubspaceLDA)
108134
@test d[:input_scitype] == Table(Continuous)
109135
@test d[:target_scitype] == AbstractVector{<:Finite}
110136
@test d[:name] == "BayesianSubspaceLDA"
111137
end
112138
113139
@testset "SubspaceLDA" begin
140+
## Data
114141
Random.seed!(1125)
115142
X1 = -2 .+ randn(100, 2)
116143
X2 = randn(100, 2)
@@ -130,7 +157,9 @@ end
130157
ytrain = selectrows(y, train)
131158
Xtest = selectrows(X, test)
132159
ytest = selectrows(y, test)
160+
133161
lda_model = SubspaceLDA()
162+
## Check model `fit`/ `transform`
134163
fitresult, = fit(lda_model, 1, Xtrain, ytrain)
135164
preds = predict_mode(lda_model, fitresult, Xtest)
136165
mcr = misclassification_rate(preds, ytest)
@@ -144,8 +173,51 @@ end
144173
# MLJ Linear Discriminant Analysis transform
145174
tlda_mlj = transform(lda_model, fitresult, X)
146175
@test tlda_mlj == tlda_ms
176+
## Check model traits
147177
d = info_dict(SubspaceLDA)
148178
@test d[:input_scitype] == Table(Continuous)
149179
@test d[:target_scitype] == AbstractVector{<:Finite}
150180
@test d[:name] == "SubspaceLDA"
151-
end
181+
end
182+
183+
@testset "discriminant models checks" begin
184+
## Data to be used for tests
185+
y = categorical(["apples", "oranges", "carrots", "mango"])
186+
X = (x1 =rand(4), x2 = collect(1:4))
187+
188+
## Note: The following tests depend on the order in which they are written.
189+
## Hence do not change the ordering of the tests.
190+
191+
## Check to make sure error is thrown if we only have a single
192+
## unique class during training.
193+
model = LDA()
194+
# categorical array with same pool as y but only containing "apples"
195+
y1 = y[[1,1,1,1]]
196+
@test_throws ArgumentError fit(model, 1, X, y1)
197+
198+
## Check to make sure error is thrown if we don't have more samples
199+
## than unique classes during training.
200+
@test_throws ArgumentError fit(model, 1, X, y)
201+
202+
## Check to make sure error is thrown if `out_dim` exceeds the number of features in
203+
## sample matrix used in training.
204+
model = LDA(out_dim=3)
205+
# categorical array with same pool as y but only containing "apples" & "oranges"
206+
y2 = y[[1,2,1,2]]
207+
@test_throws ArgumentError fit(model, 1, X, y2)
208+
209+
## Check to make sure error is thrown if length(`priors`) != number of classes
210+
## in common pool of target vector used in training.
211+
model = BayesianLDA(priors=[0.1, 0.5, 0.4])
212+
@test_throws ArgumentError fit(model, 1, X, y)
213+
214+
## Check to make sure error is thrown if sum(`priors`) isn't approximately equal to 1.
215+
model = BayesianLDA(priors=[0.1, 0.5, 0.4, 0.2])
216+
@test_throws ArgumentError fit(model, 1, X, y)
217+
218+
## Check to make sure error is thrown if `priors .< 0` or `priors .> 1`.
219+
model = BayesianLDA(priors=[-0.1, 0.0, 1.0, 0.1])
220+
@test_throws ArgumentError fit(model, 1, X, y)
221+
model = BayesianLDA(priors=[1.1, 0.0, 0.0, -0.1])
222+
@test_throws ArgumentError fit(model, 1, X, y)
223+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using MLJMultivariateStatsInterface
88
using StableRNGs
99
using Test
1010

11+
const MS = MultivariateStats
12+
1113
include("testutils.jl")
1214
println("\nutils"); include("utils.jl")
1315
println("\ncomponent_analysis"); include("models/decomposition_models.jl")

0 commit comments

Comments
 (0)