Skip to content

Commit ffba60b

Browse files
committed
upgrade to MS 0.9
1 parent 9403779 commit ffba60b

File tree

4 files changed

+24
-19
lines changed

4 files changed

+24
-19
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJMultivariateStatsInterface"
22
uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728"
33
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -13,7 +13,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1313
[compat]
1414
Distances = "^0.9,^0.10"
1515
MLJModelInterface = "^0.3.5,^0.4, 1.0"
16-
MultivariateStats = "0.7, 0.8"
16+
MultivariateStats = "0.9"
1717
StatsBase = "0.32, 0.33"
1818
julia = "1"
1919

src/models/decomposition_models.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ function MMI.fit(model::PCA, verbosity::Int, X)
4444
)
4545
cache = nothing
4646
report = (
47-
indim=MS.indim(fitresult),
48-
outdim=MS.outdim(fitresult),
47+
indim=MS.size(fitresult,1),
48+
outdim=MS.size(fitresult,2),
4949
tprincipalvar=MS.tprincipalvar(fitresult),
5050
tresidualvar=MS.tresidualvar(fitresult),
51-
tvar=MS.tvar(fitresult),
51+
tvar=MS.var(fitresult),
5252
mean=copy(MS.mean(fitresult)),
5353
principalvars=copy(MS.principalvars(fitresult))
5454
)
@@ -113,9 +113,9 @@ function MMI.fit(model::KernelPCA, verbosity::Int, X)
113113
)
114114
cache = nothing
115115
report = (
116-
indim=MS.indim(fitresult),
117-
outdim=MS.outdim(fitresult),
118-
principalvars=copy(MS.principalvars(fitresult))
116+
indim=MS.size(fitresult,1),
117+
outdim=MS.size(fitresult,2),
118+
principalvars=copy(MS.eigvals(fitresult))
119119
)
120120
return fitresult, cache, report
121121
end
@@ -168,14 +168,19 @@ $ICA_DESCR
168168
end
169169

170170
function MMI.fit(model::ICA, verbosity::Int, X)
171+
icagfun(fname::Symbol, ::Type{T} = Float64) where T<:Real=
172+
fname == :tanh ? MS.Tanh{T}(1.0) :
173+
fname == :gaus ? MS.Gaus{T}() :
174+
error("Unknown gfun $(fname)")
175+
171176
Xarray = MMI.matrix(X)
172177
n, p = size(Xarray)
173178
m = min(n, p)
174179
k = ifelse(model.k m, model.k, m)
175180
fitresult = MS.fit(
176181
MS.ICA, transpose(Xarray), k;
177182
alg=model.alg,
178-
fun=MS.icagfun(model.fun, eltype(Xarray)),
183+
fun=icagfun(model.fun, eltype(Xarray)),
179184
do_whiten=model.do_whiten,
180185
maxiter=model.maxiter,
181186
tol=model.tol,
@@ -184,8 +189,8 @@ function MMI.fit(model::ICA, verbosity::Int, X)
184189
)
185190
cache = nothing
186191
report = (
187-
indim=MS.indim(fitresult),
188-
outdim=MS.outdim(fitresult),
192+
indim=MS.size(fitresult,1),
193+
outdim=MS.size(fitresult,2),
189194
mean=copy(MS.mean(fitresult))
190195
)
191196
return fitresult, cache, report
@@ -244,8 +249,8 @@ function MMI.fit(model::PPCA, verbosity::Int, X)
244249
)
245250
cache = nothing
246251
report = (
247-
indim=MS.indim(fitresult),
248-
outdim=MS.outdim(fitresult),
252+
indim=MS.size(fitresult,1),
253+
outdim=MS.size(fitresult,2),
249254
tvar=MS.var(fitresult),
250255
mean=copy(MS.mean(fitresult)),
251256
loadings=MS.loadings(fitresult)
@@ -308,8 +313,8 @@ function MMI.fit(model::FactorAnalysis, verbosity::Int, X)
308313
)
309314
cache = nothing
310315
report = (
311-
indim=MS.indim(fitresult),
312-
outdim=MS.outdim(fitresult),
316+
indim=MS.size(fitresult,1),
317+
outdim=MS.size(fitresult,2),
313318
variance=MS.var(fitresult),
314319
covariance_matrix=MS.cov(fitresult),
315320
mean=MS.mean(fitresult),
@@ -346,7 +351,7 @@ for (M, MFitResultType) in model_types
346351
@eval function MMI.transform(::$M, fr::$MFitResultType, X)
347352
# X is n x d, need to transpose twice
348353
Xarray = MMI.matrix(X)
349-
Xnew = transpose(MS.transform(fr, transpose(Xarray)))
354+
Xnew = transpose(MS.predict(fr, transpose(Xarray)))
350355
return MMI.table(Xnew, prototype=X)
351356
end
352357

src/models/discriminant_analysis.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function MMI.fit(model::LDA, ::Int, X, y)
5555
cache = nothing
5656
report = (
5757
classes=classes_seen,
58-
out_dim=MS.outdim(core_res),
58+
out_dim=MS.size(core_res)[2],
5959
class_means=MS.classmeans(core_res),
6060
mean=MS.mean(core_res),
6161
class_weights=MS.classweights(core_res),
@@ -221,7 +221,7 @@ function MMI.fit(model::BayesianLDA, ::Int, X, y)
221221
cache = nothing
222222
report = (
223223
classes=classes_seen,
224-
out_dim=MS.outdim(core_res),
224+
out_dim=MS.size(core_res)[2],
225225
class_means=MS.classmeans(core_res),
226226
mean=MS.mean(core_res),
227227
class_weights=MS.classweights(core_res),

test/testutils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
function test_composition_model(ms_model, mlj_model, X, X_array ; test_inverse=true)
2323
mlj_model_type = typeof(mlj_model)
2424
Xtr_ms = permutedims(
25-
MultivariateStats.transform(ms_model, permutedims(X_array))
25+
MultivariateStats.predict(ms_model, permutedims(X_array))
2626
)
2727
fitresult, _, _ = fit(mlj_model, 1, X)
2828
Xtr_mlj_table = transform(mlj_model, fitresult, X)

0 commit comments

Comments
 (0)