Skip to content

Commit e5c9c33

Browse files
authored
Merge pull request #44 from JuliaAI/dev
For a 0.3.2 release
2 parents 6ea3f23 + f39ac97 commit e5c9c33

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJMultivariateStatsInterface"
22
uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728"
33
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/models/decomposition_models.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,13 @@ model_types = [
340340
]
341341

342342
for (M, MFitResultType) in model_types
343-
@eval function MMI.fitted_params(::$M, fr)
344-
return (projection=copy(MS.projection(fr)),)
343+
344+
if M !== ICA # special cased below
345+
quote
346+
function MMI.fitted_params(::$M, fr)
347+
return (projection=copy(MS.projection(fr)),)
348+
end
349+
end |> eval
345350
end
346351

347352
@eval function MMI.transform(::$M, fr::$MFitResultType, X)
@@ -360,3 +365,5 @@ for (M, MFitResultType) in model_types
360365
end
361366
end
362367
end
368+
369+
MMI.fitted_params(::ICA, fr) = (projection=copy(fr.W), mean = copy(MS.mean(fr)))

test/models/decomposition_models.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ end
7575
test_composition_model(ica_ms, ica_mlj, X, X_array, test_inverse=false)
7676
end
7777

78-
7978
@testset "PPCA" begin
8079
X_array = matrix(X)
8180
tolerance = 5.0

test/testutils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,8 @@ function test_composition_model(ms_model, mlj_model, X, X_array ; test_inverse=t
3939
Xinv_mlj = matrix(Xinv_mlj_table)
4040
@test Xinv_ms Xinv_mlj
4141
end
42+
43+
# smoke test for issue #42
44+
fp = MLJBase.fitted_params(mlj_model, fitresult)
45+
:projection in keys(fp)
4246
end

0 commit comments

Comments
 (0)