Skip to content

Commit 1d165a2

Browse files
committed
add front end to dimension_reduction.jl learner
1 parent c602273 commit 1d165a2

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

src/learners/dimension_reduction.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# This file defines `TruncatedSVD(; codim=1)`
22

33
using LearnAPI
4-
using LinearAlgebra
4+
using LinearAlgebra
5+
import LearnDataFrontEnds as FrontEnds
56

67

78
# # DIMENSION REDUCTION USING TRUNCATED SVD DECOMPOSITION
89

910
# Recall that truncated SVD reduction is the same as PCA reduction, but without
10-
# centering. We suppose observations are presented as the columns of a `Real` matrix.
11+
# centering.
1112

1213
# Some struct fields are left abstract for simplicity.
1314

@@ -23,6 +24,11 @@ end
2324
Instantiate a truncated singular value decomposition algorithm for reducing the dimension
2425
of observations by `codim`.
2526
27+
Data can be provided to `fit` or `transform` in any form supported by the `Tarragon` data
28+
front end at LearnDataFrontEnds.jl. However, the outputs of `transform` and
29+
`inverse_transform` are always matrices.
30+
31+
2632
```julia
2733
learner = Truncated()
2834
X = rand(3, 100) # 100 observations in 3-space
@@ -49,10 +55,21 @@ end
4955

5056
LearnAPI.learner(model::TruncatedSVDFitted) = model.learner
5157

52-
function LearnAPI.fit(learner::TruncatedSVD, X; verbosity=1)
58+
# add a canned data front end; `obs` will return objects of type `FrontEnds.Obs`:
59+
LearnAPI.obs(learner::TruncatedSVD, data) =
60+
FrontEnds.fitobs(learner, data, FrontEnds.Tarragon())
61+
LearnAPI.obs(model::TruncatedSVDFitted, data) =
62+
obs(model, data, FrontEnds.Tarragon())
63+
64+
# training data deconstructor:
65+
LearnAPI.features(learner::TruncatedSVD, data) =
66+
LearnAPI.features(learner, data, FrontEnds.Tarragon())
67+
68+
function LearnAPI.fit(learner::TruncatedSVD, observations::FrontEnds.Obs; verbosity=1)
5369

5470
# unpack hyperparameters:
5571
codim = learner.codim
72+
X = observations.features
5673
p, n = size(X)
5774
n p || error("Insufficient number observations. ")
5875
outdim = p - codim
@@ -70,14 +87,19 @@ function LearnAPI.fit(learner::TruncatedSVD, X; verbosity=1)
7087
return TruncatedSVDFitted(learner, U, Ut, singular_values)
7188

7289
end
90+
LearnAPI.fit(learner::TruncatedSVD, data; kwargs...) =
91+
LearnAPI.fit(learner, LearnAPI.obs(learner, data); kwargs...)
7392

74-
LearnAPI.transform(model::TruncatedSVDFitted, X) = model.Ut*X
93+
LearnAPI.transform(model::TruncatedSVDFitted, observations::FrontEnds.Obs) =
94+
model.Ut*(observations.features)
95+
LearnAPI.transform(model::TruncatedSVDFitted, data) =
96+
LearnAPI.transform(model, obs(model, data))
7597

7698
# convenience fit-transform:
77-
LearnAPI.transform(learner::TruncatedSVD, X; kwargs...) =
78-
transform(fit(learner, X; kwargs...), X)
99+
LearnAPI.transform(learner::TruncatedSVD, data; kwargs...) =
100+
transform(fit(learner, data; kwargs...), data)
79101

80-
LearnAPI.inverse_transform(model::TruncatedSVDFitted, W) = model.U*W
102+
LearnAPI.inverse_transform(model::TruncatedSVDFitted, W::AbstractMatrix) = model.U*W
81103

82104
# accessor function:
83105
function LearnAPI.extras(model::TruncatedSVDFitted)

src/learners/regression.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ end
2121
"""
2222
Ridge(; lambda=0.1)
2323
24-
Instantiate a ridge regression learner, with regularization of `lambda`.
24+
Instantiate a ridge regression learner, with regularization of `lambda`. Data can be
25+
provided to `fit` or `predict` in any form supported by the `Saffron` data front end at
26+
LearnDataFrontEnds.jl.
2527
2628
"""
2729
Ridge(; lambda=0.1) = Ridge(lambda) # LearnAPI.constructor defined later

test/learners/dimension_reduction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using LearnAPI
33
using LearnTestAPI
44
using StableRNGs
55
using Statistics
6+
using LinearAlgebra
67

78
# synthesize test data:
89
rng = StableRNG(123)

0 commit comments

Comments
 (0)