Skip to content

Commit 2111955

Browse files
committed
addresses #17
1 parent d9186ca commit 2111955

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

src/mlj/classifiers.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,32 @@
22
LOGISTIC CLASSIFIER
33
=================== =#
44

5+
"""
6+
$SIGNATURES
7+
8+
Logistic Classifier (typically called "Logistic Regression"). This model is
9+
a standard classifier for both binary and multiclass classification.
10+
In the binary case it corresponds to the LogisticLoss, in the multiclass to the
11+
Multinomial (softmax) loss. An elastic net penalty can be applied with
12+
overall objective function
13+
14+
``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``
15+
16+
Where `L` is either the logistic or multinomial loss and `λ` and `γ` indicate
17+
the strength of the L2 (resp. L1) regularisation components.
18+
19+
## Parameters
20+
* `penalty` (Symbol or String): the penalty to use, either `:l2`, `:l1`, `:en`
21+
(elastic net) or `:none`. (Default: `:l2`)
22+
* `lambda` (Real): strength of the regulariser if `penalty` is `:l2` or `:l1`.
23+
Strength of the L2 regulariser if `penalty` is `:en`.
24+
* `gamma` (Real): strength of the L1 regulariser if `penalty` is `:en`.
25+
* `fit_intercept` (Bool): whether to fit an intercept (Default: `true`)
26+
* `penalize_intercept` (Bool): whether to penalize intercept (Default: `false`)
27+
* `solver` (Solver): type of solver to use, default if `nothing`.
28+
* `multi_class` (Bool): whether it's a binary or multi class classification
29+
problem. This is usually set automatically.
30+
"""
531
@with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
632
lambda::Real = 1.0
733
gamma::Real = 0.0
@@ -25,6 +51,12 @@ descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss functi
2551
MULTINOMIAL CLASSIFIER
2652
====================== =#
2753

54+
"""
55+
$SIGNATURES
56+
57+
See `LogisticClassifier`, it's the same except that `multi_class` is set
58+
to `true` by default. The other parameters are the same.
59+
"""
2860
@with_kw_noshow mutable struct MultinomialClassifier <: MMI.Probabilistic
2961
lambda::Real = 1.0
3062
gamma::Real = 0.0

src/mlj/regressors.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ descr(::Type{LinearRegressor}) = "Regression with objective function ``|Xθ - y|
2828
RIDGE REGRESSOR
2929
=============== =#
3030

31+
"""
32+
$SIGNATURES
33+
34+
Ridge regression model with objective function
35+
36+
``|Xθ - y|₂²/2 + λ|θ|₂²/2``
37+
38+
## Parameters
39+
40+
* `lambda` (Real): strength of the L2 regularisation.
41+
* `fit_intercept` (Bool): whether to fit the intercept or not.
42+
* `penalize_intercept` (Bool): whether to penalize the intercept.
43+
* `solver`: type of solver to use (if `nothing` the default is used). The
44+
solver is Cholesky by default but can be Conjugate-Gradient as
45+
well. See `?Analytical` for more information.
46+
"""
3147
@with_kw_noshow mutable struct RidgeRegressor <: MMI.Deterministic
3248
lambda::Real = 1.0
3349
fit_intercept::Bool = true
@@ -46,6 +62,22 @@ descr(::Type{RidgeRegressor}) = "Regression with objective function ``|Xθ - y|
4662
LASSO REGRESSOR
4763
=============== =#
4864

65+
"""
66+
$SIGNATURES
67+
68+
Lasso regression model with objective function
69+
70+
``|Xθ - y|₂²/2 + λ|θ|₁``
71+
72+
## Parameters
73+
74+
* `lambda` (Real): strength of the L1 regularisation.
75+
* `fit_intercept` (Bool): whether to fit the intercept or not.
76+
* `penalize_intercept` (Bool): whether to penalize the intercept.
77+
* `solver`: type of solver to use (if `nothing` the default is used). Either
78+
`FISTA` or `ISTA` can be used (proximal methods, with/without
79+
acceleration).
80+
"""
4981
@with_kw_noshow mutable struct LassoRegressor <: MMI.Deterministic
5082
lambda::Real = 1.0
5183
fit_intercept::Bool = true
@@ -64,6 +96,23 @@ descr(::Type{LassoRegressor}) = "Regression with objective function ``|Xθ - y|
6496
ELASTIC NET REGRESSOR
6597
===================== =#
6698

99+
"""
100+
$SIGNATURES
101+
102+
Elastic net regression model with objective function
103+
104+
``|Xθ - y|₂²/2 + λ|θ|₂²/2 + γ|θ|₁``
105+
106+
## Parameters
107+
108+
* `lambda` (Real): strength of the L2 regularisation.
109+
* `gamma` (Real): strength of the L1 regularisation.
110+
* `fit_intercept` (Bool): whether to fit the intercept or not.
111+
* `penalize_intercept` (Bool): whether to penalize the intercept.
112+
* `solver`: type of solver to use (if `nothing` the default is used). Either
113+
`FISTA` or `ISTA` can be used (proximal methods, with/without
114+
acceleration).
115+
"""
67116
@with_kw_noshow mutable struct ElasticNetRegressor <: MMI.Deterministic
68117
lambda::Real = 1.0
69118
gamma::Real = 0.0
@@ -83,6 +132,28 @@ descr(::Type{ElasticNetRegressor}) = "Regression with objective function ``|Xθ
83132
ROBUST REGRESSOR (General)
84133
========================== =#
85134

135+
"""
136+
$SIGNATURES
137+
138+
Robust regression model with objective function
139+
140+
``∑ρ(Xθ - y) + λ|θ|₂² + γ|θ|₁``
141+
142+
where `ρ` is a robust loss function (e.g. the Huber function).
143+
144+
## Parameters
145+
146+
* `rho` (RobustRho): the type of robust loss to use (see `HuberRho`,
147+
`TalwarRho`, ...)
148+
* `penalty` (Symbol or String): the penalty to use, either `:l2`, `:l1`, `:en`
149+
(elastic net) or `:none`. (Default: `:l2`)
150+
* `lambda` (Real): strength of the regulariser if `penalty` is `:l2` or `:l1`.
151+
Strength of the L2 regulariser if `penalty` is `:en`.
152+
* `gamma` (Real): strength of the L1 regulariser if `penalty` is `:en`.
153+
* `fit_intercept` (Bool): whether to fit an intercept (Default: `true`)
154+
* `penalize_intercept` (Bool): whether to penalize intercept (Default: `false`)
155+
* `solver` (Solver): type of solver to use, default if `nothing`.
156+
"""
86157
@with_kw_noshow mutable struct RobustRegressor <: MMI.Deterministic
87158
rho::RobustRho = HuberRho(0.1)
88159
lambda::Real = 1.0
@@ -105,6 +176,14 @@ descr(::Type{RobustRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
105176
HUBER REGRESSOR
106177
=============== =#
107178

179+
"""
180+
$SIGNATURES
181+
182+
Huber Regression, see `RobustRegressor`, it's the same but with the robust loss
183+
set to `HuberRho`. The parameters are the same apart from `delta` which
184+
parametrises the `HuberRho` function (radius of the ball within which the loss
185+
is a quadratic loss).
186+
"""
108187
@with_kw_noshow mutable struct HuberRegressor <: MMI.Deterministic
109188
delta::Real = 0.5
110189
lambda::Real = 1.0
@@ -127,6 +206,14 @@ descr(::Type{HuberRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
127206
QUANTILE REGRESSOR
128207
================== =#
129208

209+
"""
210+
$SIGNATURES
211+
212+
Quantile Regression, see `RobustRegressor`, it's the same but with the robust
213+
loss set to `QuantileRho`. The parameters are the same apart from `delta`
214+
which parametrises the `QuantileRho` function (indicating the quantile to use
215+
with default `0.5` for the median regression).
216+
"""
130217
@with_kw_noshow mutable struct QuantileRegressor <: MMI.Deterministic
131218
delta::Real = 0.5
132219
lambda::Real = 1.0
@@ -149,6 +236,17 @@ descr(::Type{QuantileRegressor}) = "Robust regression with objective ``∑ρ(Xθ
149236
LEAST ABSOLUTE DEVIATION REGRESSOR
150237
================================== =#
151238

239+
"""
240+
$SIGNATURES
241+
242+
Least Absolute Deviation regression with with objective function
243+
244+
``∑ρ(Xθ - y) + λ|θ|₂² + γ|θ|₁``
245+
246+
where `ρ` is the absolute loss.
247+
248+
See also `RobustRegressor`.
249+
"""
152250
@with_kw_noshow mutable struct LADRegressor <: MMI.Deterministic
153251
lambda::Real = 1.0
154252
gamma::Real = 0.0

0 commit comments

Comments
 (0)