@testset "MulticlassLDA" begin
+    ## Data
    Xfull, y = @load_smarket
    X = selectcols(Xfull, [:Lag1,:Lag2])
    train = selectcols(Xfull, :Year) .< Dates.Date(2005)
    Xtest = selectrows(X, test)
    ytest = selectrows(y, test)

-    LDA_model = LDA()
-    fitresult, = fit(LDA_model, 1, Xtrain, ytrain)
-    class_means, projection_matrix = fitted_params(LDA_model, fitresult)
-    preds = predict(LDA_model, fitresult, Xtest)
+    lda_model = LDA()
+
+    ## Check model `fit`
+    fitresult, = fit(lda_model, 1, Xtrain, ytrain)
+    class_means, projection_matrix = fitted_params(lda_model, fitresult)
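+    # `class_means` holds the per-class means of the `Lag1` and `Lag2` features.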
+    @test round.(class_means', sigdigits=3) == [0.0428 0.0339; -0.0395 -0.0313]
+    ## Check model `predict`
+    preds = predict(lda_model, fitresult, Xtest)
    mce = cross_entropy(preds, ytest) |> mean
    @test 0.685 ≤ mce ≤ 0.695
-    @test round.(class_means', sigdigits=3) == [0.0428 0.0339; -0.0395 -0.0313]
+    ## Check model `transform`
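+    # The MLJ transform should apply exactly the same projection as doing it by hand
+    # with MultivariateStats, hence the exact `==` comparison below.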
+    # MultivariateStats Linear Discriminant Analysis transform
+    proj = fitresult[1].proj
+    XWt = matrix(X) * proj
+    tlda_ms = table(XWt, prototype=X)
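+    # `table(XWt, prototype=X)` wraps the projected matrix as a table of the same type as `X`.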
+    # MLJ Linear Discriminant Analysis transform
+    tlda_mlj = transform(lda_model, fitresult, X)
+    @test tlda_mlj == tlda_ms
+    ## Check model traits
    d = info_dict(LDA)
    @test d[:input_scitype] == Table(Continuous)
    @test d[:target_scitype] == AbstractVector{<:Finite}
    @test d[:name] == "LDA"
end

@testset "MLDA-2" begin
+    ## Data
    Random.seed!(1125)
    X1 = -2 .+ randn(100, 2)
    X2 = randn(100, 2)
    ytrain = selectrows(y, train)
    Xtest = selectrows(X, test)
    ytest = selectrows(y, test)
+
    lda_model = LDA()
+    ## Check model `fit`/`predict`
    fitresult, = fit(lda_model, 1, Xtrain, ytrain)
    preds = predict_mode(lda_model, fitresult, Xtest)
    mcr = misclassification_rate(preds, ytest)
    @test mcr ≤ 0.15
end

@testset "BayesianMulticlassLDA" begin
+    ## Data
    Xfull, y = @load_smarket
    X = selectcols(Xfull, [:Lag1,:Lag2])
    train = selectcols(Xfull, :Year) .< Dates.Date(2005)
    ytrain = selectrows(y, train)
    Xtest = selectrows(X, test)
    ytest = selectrows(y, test)
+
    BLDA_model = BayesianLDA()
+    ## Check model `fit`
    fitresult, = fit(BLDA_model, 1, Xtrain, ytrain)
    class_means, projection_matrix, priors = fitted_params(BLDA_model, fitresult)
+    @test round.(class_means', sigdigits=3) == [0.0428 0.0339; -0.0395 -0.0313]
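+    # Same class means as checked for plain `LDA` above; the per-class means depend
+    # only on the training data, not on the model.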
+    ## Check model `predict`
    preds = predict(BLDA_model, fitresult, Xtest)
    mce = cross_entropy(preds, ytest) |> mean
    @test 0.685 ≤ mce ≤ 0.695
-    @test round.(class_means', sigdigits=3) == [0.0428 0.0339; -0.0395 -0.0313]
+    ## Check model traits
    d = info_dict(BayesianLDA)
    @test d[:input_scitype] == Table(Continuous)
    @test d[:target_scitype] == AbstractVector{<:Finite}
    @test d[:name] == "BayesianLDA"
end

@testset "BayesianSubspaceLDA" begin
+    ## Data
    X, y = @load_iris
    LDA_model = BayesianSubspaceLDA()
+    ## Check model `fit`
    fitresult, _, report = fit(LDA_model, 1, X, y)
    class_means, projection_matrix, prior_probabilities = fitted_params(
        LDA_model, fitresult
    )
-    preds = predict(LDA_model, fitresult, X)
-    predicted_class = predict_mode(LDA_model, fitresult, X)
-    mcr = misclassification_rate(predicted_class, y)
-    mce = cross_entropy(preds, y) |> mean
    @test mean(
        abs.(
            class_means' - [
@@ -101,16 +120,24 @@ end
        )
    ) < 0.05
    @test round.(prior_probabilities, sigdigits=7) == [0.3333333, 0.3333333, 0.3333333]
-    @test round.(mcr, sigdigits=1) == 0.02
    @test round.(report.explained_variance_ratio, digits=4) == [0.9915, 0.0085]
+
+    ## Check model `predict`
+    preds = predict(LDA_model, fitresult, X)
+    predicted_class = predict_mode(LDA_model, fitresult, X)
+    mcr = misclassification_rate(predicted_class, y)
+    mce = cross_entropy(preds, y) |> mean
+    @test round.(mcr, sigdigits=1) == 0.02
    @test 0.04 ≤ mce ≤ 0.045
+    ## Check model traits
    d = info_dict(BayesianSubspaceLDA)
    @test d[:input_scitype] == Table(Continuous)
    @test d[:target_scitype] == AbstractVector{<:Finite}
    @test d[:name] == "BayesianSubspaceLDA"
end

@testset "SubspaceLDA" begin
+    ## Data
    Random.seed!(1125)
    X1 = -2 .+ randn(100, 2)
    X2 = randn(100, 2)
    ytrain = selectrows(y, train)
    Xtest = selectrows(X, test)
    ytest = selectrows(y, test)
+
    lda_model = SubspaceLDA()
+    ## Check model `fit`/`transform`
    fitresult, = fit(lda_model, 1, Xtrain, ytrain)
    preds = predict_mode(lda_model, fitresult, Xtest)
    mcr = misclassification_rate(preds, ytest)
    # MLJ Linear Discriminant Analysis transform
    tlda_mlj = transform(lda_model, fitresult, X)
    @test tlda_mlj == tlda_ms
+    ## Check model traits
    d = info_dict(SubspaceLDA)
    @test d[:input_scitype] == Table(Continuous)
    @test d[:target_scitype] == AbstractVector{<:Finite}
    @test d[:name] == "SubspaceLDA"
-end
+end
+
+@testset "discriminant models checks" begin
+    ## Data to be used for tests
+    y = categorical(["apples", "oranges", "carrots", "mango"])
+    X = (x1 = rand(4), x2 = collect(1:4))
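+    # Only four observations, one per class of `y`, which makes the degenerate cases
+    # below easy to trigger.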
+
+    ## Note: the following tests depend on the order in which they are written,
+    ## so do not change their ordering.
+
+    ## Check that an error is thrown if we only have a single
+    ## unique class during training.
+    model = LDA()
+    # categorical array with the same pool as `y` but only containing "apples"
+    y1 = y[[1, 1, 1, 1]]
+    @test_throws ArgumentError fit(model, 1, X, y1)
+
+    ## Check that an error is thrown if we don't have more samples
+    ## than unique classes during training.
+    @test_throws ArgumentError fit(model, 1, X, y)
+
+    ## Check that an error is thrown if `out_dim` exceeds the number of features
+    ## in the sample matrix used for training.
+    model = LDA(out_dim=3)
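+    # `X` above has only two features (`x1` and `x2`), so an `out_dim` of 3 cannot be honored.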
+    # categorical array with the same pool as `y` but only containing "apples" & "oranges"
+    y2 = y[[1, 2, 1, 2]]
+    @test_throws ArgumentError fit(model, 1, X, y2)
+
+    ## Check that an error is thrown if length(`priors`) != number of classes
+    ## in the common pool of the target vector used for training.
+    model = BayesianLDA(priors=[0.1, 0.5, 0.4])
+    @test_throws ArgumentError fit(model, 1, X, y)
+
+    ## Check that an error is thrown if sum(`priors`) isn't approximately equal to 1.
+    model = BayesianLDA(priors=[0.1, 0.5, 0.4, 0.2])
+    @test_throws ArgumentError fit(model, 1, X, y)
+
+    ## Check that an error is thrown if any of the `priors` lie outside [0, 1].
+    model = BayesianLDA(priors=[-0.1, 0.0, 1.0, 0.1])
+    @test_throws ArgumentError fit(model, 1, X, y)
+    model = BayesianLDA(priors=[1.1, 0.0, 0.0, -0.1])
+    @test_throws ArgumentError fit(model, 1, X, y)
+end