Skip to content

Commit 422c8f8

Browse files
authored
fix default solver for robust regression, MLJ issue 401 (#44)
1 parent 85717c9 commit 422c8f8

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/fit/default.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,25 @@ export fit
55
# TODO: in the future, have cases where if the things are too big, take another default.
66
# also should check if p > n in which case should do dual stuff (or other appropriate alternative)
77

8+
# Linear, Ridge
89
_solver(::GLR{L2Loss,<:L2R}, np::NTuple{2,Int}) = Analytical()
910

11+
# Logistic, Multinomial
1012
_solver(::GLR{LogisticLoss,<:L2R}, np::NTuple{2,Int}) = LBFGS()
1113
_solver(::GLR{MultinomialLoss,<:L2R}, np::NTuple{2,Int}) = LBFGS()
1214

13-
function _solver(glr::GLR{<:SMOOTH_LOSS,<:ENR}, np::NTuple{2,Int})
15+
# Lasso, ElasticNet, Logistic, Multinomial
16+
function _solver(glr::GLR{<:SmoothLoss,<:ENR}, np::NTuple{2,Int})
1417
(is_l1(glr.penalty) || is_elnet(glr.penalty)) && return FISTA()
15-
@error "Not yet implemented"
18+
@error "Not yet implemented."
1619
end
1720

18-
_solver(::GLR{RobustLoss,<:L2R}, np::NTuple{2,Int}) = LBFGS()
19-
#_solver(::GLR{L1Loss,<:L2R}, np::NTuple{2,Int}) = FADMM()
21+
# Robust, Quantile
22+
_solver(::GLR{<:RobustLoss,<:L2R}, np::NTuple{2,Int}) = LBFGS()
2023

2124
# Fallback NOTE: should revisit bc with non-smooth, wouldn't work probably PGD/PSGD
2225
# depending on how much data there is
23-
_solver(::GLR, np::NTuple{2,Int}) = @error "Not yet implemented"
26+
_solver(::GLR, np::NTuple{2,Int}) = @error "Not yet implemented."
2427

2528

2629
"""

src/glr/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export objective, smooth_objective
22

33
# NOTE: RobustLoss are not always everywhere smooth but "smooth-enough".
4-
const SMOOTH_LOSS = Union{L2Loss, LogisticLoss, MultinomialLoss, RobustLoss}
4+
const SmoothLoss = Union{L2Loss, LogisticLoss, MultinomialLoss, RobustLoss}
55

66
"""
77
$SIGNATURES
@@ -34,7 +34,7 @@ $SIGNATURES
3434
3535
Return the smooth part of the objective function of a GLR.
3636
"""
37-
smooth_objective(glr::GLR{<:SMOOTH_LOSS,<:ENR}) = glr.loss + get_l2(glr.penalty)
37+
smooth_objective(glr::GLR{<:SmoothLoss,<:ENR}) = glr.loss + get_l2(glr.penalty)
3838
smooth_objective(::GLR) = @error "Case not implemented yet."
3939

4040
"""

0 commit comments

Comments
 (0)