Skip to content

Commit f6bc4d0

Browse files
committed
Split LinearRegresor and RidgeRegressor into MultitargetLinearRegressor, LinearRegressor, MultitargetRidgeRegressor, RidgeRegressor
1 parent a6f9425 commit f6bc4d0

File tree

3 files changed

+188
-115
lines changed

3 files changed

+188
-115
lines changed

src/MLJMultivariateStatsInterface.jl

Lines changed: 74 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -34,70 +34,84 @@ const FactorAnalysisResultType = MS.FactorAnalysis
3434
const default_kernel = (x, y) -> x'y #default kernel used in KernelPCA
3535

3636
# Definitions of model descriptions for use in model doc-strings.
37-
const PCA_DESCR = """Principal component analysis. Learns a linear transformation to
38-
project the data on a lower dimensional space while preserving most of the initial
39-
variance.
40-
"""
37+
const PCA_DESCR = """
38+
Principal component analysis. Learns a linear transformation to
39+
project the data on a lower dimensional space while preserving most of the initial
40+
variance.
41+
"""
4142
const KPCA_DESCR = "Kernel principal component analysis."
4243
const ICA_DESCR = "Independent component analysis."
4344
const PPCA_DESCR = "Probabilistic principal component analysis"
4445
const FactorAnalysis_DESCR = "Factor Analysis"
45-
const LDA_DESCR = """Multiclass linear discriminant analysis. The algorithm learns a
46-
projection matrix `P` that projects a feature matrix `Xtrain` onto a lower dimensional
47-
space of dimension `out_dim` such that the trace of the transformed between-class scatter
48-
matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the transformed within-class
49-
scatter matrix (`Pᵀ*Sw*P`).The projection matrix is scaled such that `Pᵀ*Sw*P=I` or
50-
`Pᵀ*Σw*P=I`(where `Σw` is the within-class covariance matrix) .
51-
Predicted class posterior probability for feature matrix `Xtest` are derived by applying
52-
a softmax transformationto a matrix `Pr`, such that rowᵢ of `Pr` contains computed
53-
distances(based on a distance metric) in the transformed space of rowᵢ in `Xtest` to the
54-
centroid of each class.
55-
"""
56-
const BayesianLDA_DESCR = """Bayesian Multiclass linear discriminant analysis. The algorithm
57-
learns a projection matrix `P` that projects a feature matrix `Xtrain` onto a lower
58-
dimensional space of dimension `out_dim` such that the trace of the transformed
59-
between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the
60-
transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is scaled such
61-
that `Pᵀ*Sw*P = n` or `Pᵀ*Σw*P=I` (Where `n` is the number of training samples and `Σw`
62-
is the within-class covariance matrix).
63-
Predicted class posterior probability distibution are derived by applying Bayes rule with
64-
a multivariate Gaussian class-conditional distribution.
65-
"""
66-
const SubspaceLDA_DESCR = """Multiclass linear discriminant analysis. Suitable for high
67-
dimensional data (Avoids computing scatter matrices `Sw` ,`Sb`). The algorithm learns a
68-
projection matrix `P = W*L` that projects a feature matrix `Xtrain` onto a lower
69-
dimensional space of dimension `nc - 1` such that the trace of the transformed
70-
between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the
71-
transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is scaled such
72-
that `Pᵀ*Sw*P = mult*I` or `Pᵀ*Σw*P=mult/(n-nc)*I` (where `n` is the number of training
73-
samples, mult` is one of `n` or `1` depending on whether `Sb` is normalized, `Σw` is the
74-
within-class covariance matrix, and `nc` is the number of unique classes in `y`) and also
75-
obeys `Wᵀ*Sb*p = λ*Wᵀ*Sw*p`, for every column `p` in `P`.
76-
Predicted class posterior probability for feature matrix `Xtest` are derived by applying a
77-
softmax transformation to a matrix `Pr`, such that rowᵢ of `Pr` contains computed
78-
distances(based on a distance metric) in the transformed space of rowᵢ in `Xtest` to the
79-
centroid of each class.
80-
"""
81-
const BayesianSubspaceLDA_DESCR = """Bayesian Multiclass linear discriminant analysis.
82-
Suitable for high dimensional data (Avoids computing scatter matrices `Sw` ,`Sb`). The
83-
algorithm learns a projection matrix `P = W*L` (`Sw`), that projects a feature matrix
84-
`Xtrain` onto a lower dimensional space of dimension `nc-1` such that the trace of the
85-
transformed between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace
86-
of the transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is
87-
scaled such that `Pᵀ*Sw*P = mult*I` or `Pᵀ*Σw*P=mult/(n-nc)*I` (where `n` is the number of
88-
training samples, `mult` is one of `n` or `1` depending on whether `Sb` is normalized,
89-
`Σw` is the within-class covariance matrix, and `nc` is the number of unique classes in
90-
`y`) and also obeys `Wᵀ*Sb*p = λ*Wᵀ*Sw*p`, for every column `p` in `P`.
91-
Posterior class probability distibution are derived by applying Bayes rule with a
92-
multivariate Gaussian class-conditional distribution
93-
"""
94-
const LINEAR_DESCR = """Linear regression. Learns a linear combination(s) of given
95-
variables to fit the responses by minimizing the squared error between.
96-
"""
97-
const RIDGE_DESCR = """Ridge regressor with regularization parameter lambda. Learns a
98-
linear regression with a penalty on the l2 norm of the coefficients.
99-
"""
100-
46+
const LDA_DESCR = """
47+
Multiclass linear discriminant analysis. The algorithm learns a
48+
projection matrix `P` that projects a feature matrix `Xtrain` onto a lower dimensional
49+
space of dimension `out_dim` such that the trace of the transformed between-class
50+
scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the transformed
51+
within-class scatter matrix (`Pᵀ*Sw*P`).The projection matrix is scaled such that
52+
`Pᵀ*Sw*P=I` or `Pᵀ*Σw*P=I`(where `Σw` is the within-class covariance matrix) .
53+
Predicted class posterior probability for feature matrix `Xtest` are derived by
54+
applying a softmax transformation to a matrix `Pr`, such that rowᵢ of `Pr` contains
55+
computed distances(based on a distance metric) in the transformed space of rowᵢ in
56+
`Xtest` to the centroid of each class.
57+
"""
58+
const BayesianLDA_DESCR = """
59+
Bayesian Multiclass linear discriminant analysis. The algorithm
60+
learns a projection matrix `P` that projects a feature matrix `Xtrain` onto a lower
61+
dimensional space of dimension `out_dim` such that the trace of the transformed
62+
between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the
63+
transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is scaled
64+
such that `Pᵀ*Sw*P = n` or `Pᵀ*Σw*P=I` (Where `n` is the number of training samples
65+
and `Σw` is the within-class covariance matrix).
66+
Predicted class posterior probability distributions are derived by applying Bayes rule
67+
with a multivariate Gaussian class-conditional distribution.
68+
"""
69+
const SubspaceLDA_DESCR = """
70+
Multiclass linear discriminant analysis. Suitable for high
71+
dimensional data (Avoids computing scatter matrices `Sw` ,`Sb`). The algorithm learns a
72+
projection matrix `P = W*L` that projects a feature matrix `Xtrain` onto a lower
73+
dimensional space of dimension `nc - 1` such that the trace of the transformed
74+
between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the
75+
transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is scaled
76+
such that `Pᵀ*Sw*P = mult*I` or `Pᵀ*Σw*P=mult/(n-nc)*I` (where `n` is the number of
77+
training samples, `mult` is one of `n` or `1` depending on whether `Sb` is normalized,
78+
`Σw` is the within-class covariance matrix, and `nc` is the number of unique classes
79+
in `y`) and also obeys `Wᵀ*Sb*p = λ*Wᵀ*Sw*p`, for every column `p` in `P`.
80+
Predicted class posterior probability for feature matrix `Xtest` are derived by
81+
applying a softmax transformation to a matrix `Pr`, such that rowᵢ of `Pr` contains
82+
computed distances(based on a distance metric) in the transformed space of rowᵢ in
83+
`Xtest` to the centroid of each class.
84+
"""
85+
const BayesianSubspaceLDA_DESCR = """
86+
Bayesian Multiclass linear discriminant analysis. Suitable for high dimensional data
87+
(Avoids computing scatter matrices `Sw` ,`Sb`). The algorithm learns a projection
88+
matrix `P = W*L` (`Sw`), that projects a feature matrix `Xtrain` onto a lower
89+
dimensional space of dimension `nc-1` such that the trace of the transformed
90+
between-class scatter matrix(`Pᵀ*Sb*P`) is maximized relative to the trace of the
91+
transformed within-class scatter matrix (`Pᵀ*Sw*P`). The projection matrix is scaled
92+
such that `Pᵀ*Sw*P = mult*I` or `Pᵀ*Σw*P=mult/(n-nc)*I` (where `n` is the number of
93+
training samples, `mult` is one of `n` or `1` depending on whether `Sb` is normalized,
94+
`Σw` is the within-class covariance matrix, and `nc` is the number of unique classes in
95+
`y`) and also obeys `Wᵀ*Sb*p = λ*Wᵀ*Sw*p`, for every column `p` in `P`.
96+
Posterior class probability distributions are derived by applying Bayes rule with a
97+
multivariate Gaussian class-conditional distribution
98+
"""
99+
const LinearRegressor_DESCR = """
100+
Linear Regression. Learns a linear combination of given
101+
variables to fit the response by minimizing the squared error between the predictions
and the response.
102+
"""
103+
const MultitargetLinearRegressor_DESCR = """
104+
Multitarget Linear Regression. Learns linear combinations of given
105+
variables to fit the responses by minimizing the squared error between the predictions
and the responses.
106+
"""
107+
const RidgeRegressor_DESCR = """
108+
Ridge regressor with regularization parameter lambda. Learns a
109+
linear regression with a penalty on the l2 norm of the coefficients.
110+
"""
111+
const MultitargetRidgeRegressor_DESCR = """
112+
Multitarget Ridge regressor with regularization parameter lambda. Learns a
113+
multitarget linear regression with a penalty on the l2 norm of the coefficients.
114+
"""
101115
const PKG = "MLJMultivariateStatsInterface"
102116

103117
# ===================================================================

src/models/linear_models.jl

Lines changed: 106 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,6 @@
1-
####
2-
#### LinearRegressor
3-
####
4-
5-
"""
6-
LinearRegressor(; bias::Bool=true)
7-
8-
$LINEAR_DESCR
9-
10-
# Keyword Parameters
11-
12-
- `bias::Bool=true`: if true includes a bias term else fits without bias term.
13-
"""
14-
@mlj_model mutable struct LinearRegressor <: MMI.Deterministic
15-
bias::Bool = true
16-
end
17-
1+
#######
2+
## Common Regressor methods
3+
########
184
struct LinearFitresult{T, F<:Real, M<:AbstractArray{F}} <: MMI.MLJType
195
sol_matrix::M
206
bias::Bool
@@ -37,15 +23,6 @@ function _matrix(X, target)
3723
return Xmatrix, Y, _names(target)
3824
end
3925

40-
function MMI.fit(model::LinearRegressor, verbosity::Int, X, y)
41-
Xmatrix, y_, target_header= _matrix(X, y)
42-
θ = MS.llsq(Xmatrix, y_; bias=model.bias)
43-
fitresult = LinearFitresult(θ, model.bias, target_header)
44-
report = NamedTuple()
45-
cache = nothing
46-
return fitresult, cache, report
47-
end
48-
4926
function _regressor_fitted_params(fr::LinearFitresult{Nothing, <:Real, <:AbstractVector})
5027
return (
5128
coefficients=fr.sol_matrix[1:end-Int(fr.bias)],
@@ -60,10 +37,6 @@ function _regressor_fitted_params(fr::LinearFitresult{<:Vector, <:Real, <:Abstra
6037
)
6138
end
6239

63-
function MMI.fitted_params(::LinearRegressor, fr)
64-
return _regressor_fitted_params(fr)
65-
end
66-
6740
function _predict_regressor(
6841
fr::LinearFitresult{Nothing, <:Real, <:AbstractVector},
6942
Xmat_new::AbstractMatrix,
@@ -98,30 +71,66 @@ function _predict_regressor(
9871
end
9972
end
10073

101-
function MMI.predict(::LinearRegressor, fr, Xnew)
74+
####
75+
#### LinearRegressor & MultitargetLinearRegressor
76+
####
77+
78+
"""
79+
LinearRegressor(; bias::Bool=true)
80+
81+
$LinearRegressor_DESCR
82+
83+
# Keyword Parameters
84+
85+
- `bias::Bool=true`: if true includes a bias term else fits without bias term.
86+
"""
87+
@mlj_model mutable struct LinearRegressor <: MMI.Deterministic
88+
bias::Bool = true
89+
end
90+
91+
"""
92+
MultitargetLinearRegressor(; bias::Bool=true)
93+
94+
$MultitargetLinearRegressor_DESCR
95+
96+
# Keyword Parameters
97+
98+
- `bias::Bool=true`: if true includes a bias term else fits without bias term.
99+
"""
100+
@mlj_model mutable struct MultitargetLinearRegressor <: MMI.Deterministic
101+
bias::Bool = true
102+
end
103+
104+
const LINREG = Union{LinearRegressor, MultitargetLinearRegressor}
105+
106+
function MMI.fit(model::LINREG, verbosity::Int, X, y)
107+
Xmatrix, y_, target_header= _matrix(X, y)
108+
θ = MS.llsq(Xmatrix, y_; bias=model.bias)
109+
fitresult = LinearFitresult(θ, model.bias, target_header)
110+
report = NamedTuple()
111+
cache = nothing
112+
return fitresult, cache, report
113+
end
114+
115+
function MMI.fitted_params(::LINREG, fr)
116+
return _regressor_fitted_params(fr)
117+
end
118+
119+
function MMI.predict(::LINREG, fr, Xnew)
102120
Xmat_new = MMI.matrix(Xnew)
103121
return _predict_regressor(fr, Xmat_new, Xnew)
104122
end
105123

106-
metadata_model(
107-
LinearRegressor,
108-
input=Table(Continuous),
109-
target=Union{Table(Continuous), AbstractVector{Continuous}},
110-
weights=false,
111-
descr=LINEAR_DESCR,
112-
path="$(PKG).LinearRegressor"
113-
)
114-
115124
####
116-
#### RidgeRegressor
125+
#### RidgeRegressor & MultitargetRidgeRegressor
117126
####
118127

119128
_check_typeof_lambda(x) = x isa AbstractVecOrMat || (x isa Real && x >= 0)
120129

121130
"""
122131
RidgeRegressor(; lambda::Union{Real, AbstractVecOrMat}=1.0, bias::Bool=true)
123132
124-
$RIDGE_DESCR
133+
$RidgeRegressor_DESCR
125134
126135
# Keyword Parameters
127136
@@ -134,7 +143,25 @@ $RIDGE_DESCR
134143
bias::Bool = true
135144
end
136145

137-
function MMI.fit(model::RidgeRegressor, verbosity::Int, X, y)
146+
"""
147+
MultitargetRidgeRegressor(; lambda::Union{Real, AbstractVecOrMat}=1.0, bias::Bool=true)
148+
149+
$MultitargetRidgeRegressor_DESCR
150+
151+
# Keyword Parameters
152+
153+
- `lambda::Union{Real, AbstractVecOrMat}=1.0`: non-negative parameter for the
154+
regularization strength.
155+
- `bias::Bool=true`: if true includes a bias term else fits without bias term.
156+
"""
157+
@mlj_model mutable struct MultitargetRidgeRegressor <: MMI.Deterministic
158+
lambda::Union{Real, AbstractVecOrMat} = 1.0::(_check_typeof_lambda(_))
159+
bias::Bool = true
160+
end
161+
162+
const RIDGEREG = Union{RidgeRegressor, MultitargetRidgeRegressor}
163+
164+
function MMI.fit(model::RIDGEREG, verbosity::Int, X, y)
138165
Xmatrix, y_, target_header = _matrix(X, y)
139166
θ = MS.ridge(Xmatrix, y_, model.lambda; bias=model.bias)
140167
fitresult = LinearFitresult(θ, model.bias, target_header)
@@ -143,20 +170,52 @@ function MMI.fit(model::RidgeRegressor, verbosity::Int, X, y)
143170
return fitresult, cache, report
144171
end
145172

146-
function MMI.fitted_params(::RidgeRegressor, fr)
173+
function MMI.fitted_params(::RIDGEREG, fr)
147174
return _regressor_fitted_params(fr)
148175
end
149176

150-
function MMI.predict(::RidgeRegressor, fr, Xnew)
177+
function MMI.predict(::RIDGEREG, fr, Xnew)
151178
Xmat_new = MMI.matrix(Xnew)
152179
return _predict_regressor(fr, Xmat_new, Xnew)
153180
end
154181

182+
183+
############
184+
### Models Metadata
185+
############
186+
metadata_model(
187+
LinearRegressor,
188+
input=Table(Continuous),
189+
target=AbstractVector{Continuous},
190+
weights=false,
191+
descr=LinearRegressor_DESCR,
192+
path="$(PKG).LinearRegressor"
193+
)
194+
195+
metadata_model(
196+
MultitargetLinearRegressor,
197+
input=Table(Continuous),
198+
target=Table(Continuous),
199+
weights=false,
200+
descr=MultitargetLinearRegressor_DESCR,
201+
path="$(PKG).MultitargetLinearRegressor"
202+
)
203+
155204
metadata_model(
156205
RidgeRegressor,
157206
input=Table(Continuous),
158-
target=Union{Table(Continuous), AbstractVector{Continuous}},
207+
target=AbstractVector{Continuous},
159208
weights=false,
160-
descr=RIDGE_DESCR,
209+
descr=RidgeRegressor_DESCR,
161210
path="$(PKG).RidgeRegressor"
162211
)
212+
213+
metadata_model(
214+
MultitargetRidgeRegressor,
215+
input=Table(Continuous),
216+
target=Table(Continuous),
217+
weights=false,
218+
descr=MultitargetRidgeRegressor_DESCR,
219+
path="$(PKG).MultitargetRidgeRegressor"
220+
)
221+

0 commit comments

Comments
 (0)