11# This file defines `TruncatedSVD(; codim=1)`
22
33using LearnAPI
4- using LinearAlgebra
4+ using LinearAlgebra
5+ import LearnDataFrontEnds as FrontEnds
56
67
78# # DIMENSION REDUCTION USING TRUNCATED SVD DECOMPOSITION
89
910# Recall that truncated SVD reduction is the same as PCA reduction, but without
10- # centering. We suppose observations are presented as the columns of a `Real` matrix.
11+ # centering.
1112
1213# Some struct fields are left abstract for simplicity.
1314
2324Instantiate a truncated singular value decomposition algorithm for reducing the dimension
2425of observations by `codim`.
2526
27+ Data can be provided to `fit` or `transform` in any form supported by the `Tarragon` data
28+ front end at LearnDataFrontEnds.jl. However, the outputs of `transform` and
29+ `inverse_transform` are always matrices.
30+
31+
2632```julia
2733learner = Truncated()
2834X = rand(3, 100) # 100 observations in 3-space
4955
5056LearnAPI. learner (model:: TruncatedSVDFitted ) = model. learner
5157
52- function LearnAPI. fit (learner:: TruncatedSVD , X; verbosity= 1 )
58+ # add a canned data front end; `obs` will return objects of type `FrontEnds.Obs`:
59+ LearnAPI. obs (learner:: TruncatedSVD , data) =
60+ FrontEnds. fitobs (learner, data, FrontEnds. Tarragon ())
61+ LearnAPI. obs (model:: TruncatedSVDFitted , data) =
62+ obs (model, data, FrontEnds. Tarragon ())
63+
64+ # training data deconstructor:
65+ LearnAPI. features (learner:: TruncatedSVD , data) =
66+ LearnAPI. features (learner, data, FrontEnds. Tarragon ())
67+
68+ function LearnAPI. fit (learner:: TruncatedSVD , observations:: FrontEnds.Obs ; verbosity= 1 )
5369
5470 # unpack hyperparameters:
5571 codim = learner. codim
72+ X = observations. features
5673 p, n = size (X)
5774 n ≥ p || error (" Insufficient number observations. " )
5875 outdim = p - codim
@@ -70,14 +87,19 @@ function LearnAPI.fit(learner::TruncatedSVD, X; verbosity=1)
7087 return TruncatedSVDFitted (learner, U, Ut, singular_values)
7188
7289end
90+ LearnAPI. fit (learner:: TruncatedSVD , data; kwargs... ) =
91+ LearnAPI. fit (learner, LearnAPI. obs (learner, data); kwargs... )
7392
74- LearnAPI. transform (model:: TruncatedSVDFitted , X) = model. Ut* X
93+ LearnAPI. transform (model:: TruncatedSVDFitted , observations:: FrontEnds.Obs ) =
94+ model. Ut* (observations. features)
95+ LearnAPI. transform (model:: TruncatedSVDFitted , data) =
96+ LearnAPI. transform (model, obs (model, data))
7597
7698# convenience fit-transform:
77- LearnAPI. transform (learner:: TruncatedSVD , X ; kwargs... ) =
78- transform (fit (learner, X ; kwargs... ), X )
99+ LearnAPI. transform (learner:: TruncatedSVD , data ; kwargs... ) =
100+ transform (fit (learner, data ; kwargs... ), data )
79101
80- LearnAPI. inverse_transform (model:: TruncatedSVDFitted , W) = model. U* W
102+ LearnAPI. inverse_transform (model:: TruncatedSVDFitted , W:: AbstractMatrix ) = model. U* W
81103
82104# accessor function:
83105function LearnAPI. extras (model:: TruncatedSVDFitted )
0 commit comments