
Commit 4b5a703 (parent: f199300)

Using the Light Interface (#51)

18 files changed: +210 −154 lines

Project.toml

Lines changed: 8 additions & 7 deletions
@@ -1,35 +1,36 @@
 name = "MLJLinearModels"
 uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 authors = ["Thibaut Lienart <[email protected]>"]
-version = "0.2.4"
+version = "0.3.0"

 [deps]
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
-MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

 [compat]
 DocStringExtensions = "^0.8"
 IterativeSolvers = "^0.8"
-LinearMaps = "^2.5"
-MLJBase = "^0.9"
-Optim = "^0.19"
+LinearMaps = "^2.6"
+MLJModelInterface = "^0.1"
+Optim = "^0.20"
 Parameters = "^0.12"
 Tables = "^0.2"
-julia = "^1.0.0"
+julia = "^1"

 [extras]
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
 RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
 RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["DelimitedFiles", "PyCall", "Test", "Random", "RDatasets", "RCall"]
+test = ["DelimitedFiles", "PyCall", "Test", "Random", "RDatasets", "RCall", "MLJBase"]

src/MLJLinearModels.jl

Lines changed: 3 additions & 3 deletions
@@ -6,12 +6,12 @@ import LinearMaps: LinearMap
 import IterativeSolvers: cg
 import Optim

-import MLJBase
+import MLJModelInterface

 import Base.+, Base.-, Base.*, Base./, Base.convert

-const AVR = AbstractVector{<:Real}
-
+const MMI = MLJModelInterface
+const AVR = AbstractVector{<:Real}
 const Option{T} = Union{Nothing,T}

 include("scratchspace.jl")

src/fit/analytical.jl

Lines changed: 8 additions & 4 deletions
@@ -3,14 +3,17 @@
 """
 $SIGNATURES

-Fit a least square regression either with no penalty (OLS) or with a L2 penalty (Ridge).
+Fit a least square regression either with no penalty (OLS) or with a L2 penalty
+(Ridge).

 ## Complexity

 Assuming `n` dominates `p`,

-* non-iterative (full solve): O(np²) - dominated by the construction of the Hessian X'X.
-* iterative (conjugate gradient): O(κnp) - with κ the number of CG steps (κ ≤ p).
+* non-iterative (full solve): O(np²) - dominated by the construction of the
+  Hessian X'X.
+* iterative (conjugate gradient): O(κnp) - with κ the number of CG steps
+  (κ ≤ p).
 """
 function _fit(glr::GLR{L2Loss,<:L2R}, solver::Analytical, X, y)
     # full solve
@@ -34,7 +37,8 @@ function _fit(glr::GLR{L2Loss,<:L2R}, solver::Analytical, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
     max_cg_steps = min(solver.max_inner, p)
     # Form the Hessian map, cost of application H*v is O(np)
-    Hm = LinearMap(Hv!(glr, X, y), p; ismutating=true, isposdef=true, issymmetric=true)
+    Hm = LinearMap(Hv!(glr, X, y), p;
+                   ismutating=true, isposdef=true, issymmetric=true)
     b = X'y
     glr.fit_intercept && (b = vcat(b, sum(y)))
     return cg(Hm, b; maxiter=max_cg_steps)
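
The two complexity regimes described in the docstring can be reproduced with a standalone snippet: a direct solve of the ridge normal equations versus a matrix-free conjugate-gradient solve through a `LinearMap`, as in `_fit` above. This is a minimal sketch with a hand-written Hessian-vector product and no intercept handling; it does not use the package's `Hv!` helper.

```julia
using LinearAlgebra, LinearMaps, IterativeSolvers

# Solve (X'X + λI)θ = X'y once directly (O(np²)) and once with CG through a
# matrix-free Hessian map (O(κnp)).
n, p, λ = 1_000, 20, 0.5
X, y = randn(n, p), randn(n)

θ_direct = (X'X + λ * I) \ (X'y)

# mutating map: out <- (X'X + λI) * v, costs O(np) per application
hessian_vec!(out, v) = (mul!(out, X', X * v); out .+= λ .* v; out)
Hm = LinearMap(hessian_vec!, p; ismutating=true, issymmetric=true, isposdef=true)
θ_cg = cg(Hm, X'y; maxiter=p)

@assert isapprox(θ_direct, θ_cg; rtol=1e-5)
```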

src/fit/default.jl

Lines changed: 5 additions & 4 deletions
@@ -2,8 +2,9 @@ export fit

 # Default solvers

-# TODO: in the future, have cases where if the things are too big, take another default.
-# also should check if p > n in which case should do dual stuff (or other appropriate alternative)
+# TODO: in the future, have cases where if the things are too big, take another
+# default. Also should check if p > n in which case should do dual stuff (or
+# other appropriate alternative)

 # Linear, Ridge
 _solver(::GLR{L2Loss,<:L2R}, np::NTuple{2,Int}) = Analytical()
@@ -21,8 +22,8 @@ end
 # Robust, Quantile
 _solver(::GLR{<:RobustLoss,<:L2R}, np::NTuple{2,Int}) = LBFGS()

-# Fallback NOTE: should revisit bc with non-smooth, wouldn't work probably PGD/PSGD
-# depending on how much data there is
+# Fallback NOTE: should revisit bc with non-smooth, wouldn't work probably
+# PGD/PSGD depending on how much data there is
 _solver(::GLR, np::NTuple{2,Int}) = @error "Not yet implemented."


src/fit/iwls.jl

Lines changed: 4 additions & 2 deletions
@@ -16,7 +16,8 @@ function _fit(glr::GLR{RobustLoss{ρ},<:L2R}, solver::IWLSCG, X, y) where {ρ}
     # update the weights and retrieve the application function
     # Mθv! corresponds to the current application of (X'WX + λI) on v
     Mθv! = _Mv!(ω, θ)
-    Mm = LinearMap(Mθv!, p; ismutating=true, isposdef=true, issymmetric=true)
+    Mm = LinearMap(Mθv!, p;
+                   ismutating=true, isposdef=true, issymmetric=true)
     Wy = ω .* y
     b = X'Wy
     if glr.fit_intercept
@@ -30,6 +31,7 @@ function _fit(glr::GLR{RobustLoss{ρ},<:L2R}, solver::IWLSCG, X, y) where {ρ}
         copyto!(θ_, θ)
         k += 1
     end
-    tol ≤ solver.tol || @warn "IWLS did not converge in $(solver.max_iter) iterations."
+    tol ≤ solver.tol ||
+        @warn "IWLS did not converge in $(solver.max_iter) iterations."
    return θ
 end
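
For reference, the idea behind the IWLS loop above, stripped of the CG/LinearMap machinery and of the package's weighting helpers, is iteratively re-solving a weighted ridge system. A generic sketch, assuming Huber-style weights with δ = 0.5 (not the package's `_Mv!`/`ω` code):

```julia
using LinearAlgebra

# Generic IRLS: repeatedly solve (X'WX + λI)θ = X'Wy with weights ω that
# depend on the current residuals (the package applies this system
# matrix-free via CG instead of forming it).
function irls(X, y, λ, ω_fun; max_iter=100, tol=1e-6)
    θ = zeros(size(X, 2))
    for _ in 1:max_iter
        ω  = ω_fun.(y .- X * θ)                      # per-observation weights
        W  = Diagonal(ω)
        θ_ = (X' * W * X + λ * I) \ (X' * (ω .* y))
        norm(θ_ - θ) ≤ tol * max(norm(θ_), eps()) && return θ_
        θ = θ_
    end
    return θ
end

# Huber-style weights: 1 on small residuals, δ/|r| on large ones.
huber_w(r; δ=0.5) = abs(r) ≤ δ ? 1.0 : δ / abs(r)
θ = irls(randn(100, 5), randn(100), 0.1, huber_w)
```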

src/fit/newton.jl

Lines changed: 25 additions & 19 deletions
@@ -9,10 +9,11 @@ Fit a GLR using Newton's method.

 ## Complexity

-Assuming `n` dominates `p`, O(κnp²), dominated by the construction of the Hessian at each step with
-κ the number of Newton steps.
+Assuming `n` dominates `p`, O(κnp²), dominated by the construction of the
+Hessian at each step with κ the number of Newton steps.
 """
-function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R}, solver::Newton, X, y)
+function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R},
+              solver::Newton, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
     θ₀ = zeros(p)
     _fgh! = fgh!(glr, X, y)
@@ -24,20 +25,21 @@ end
 """
 $SIGNATURES

-Fit a GLR using Newton's method combined with an iterative solver (conjugate gradient) to solve
-the Newton steps (∇²f)⁻¹∇f.
+Fit a GLR using Newton's method combined with an iterative solver (conjugate
+gradient) to solve the Newton steps (∇²f)⁻¹∇f.

 ## Complexity

-Assuming `n` dominates `p`, O(κ₁κ₂np), dominated by the application of the Hessian at each step
-where κ₁ is the number of Newton steps and κ₂ is the average number of CG steps per Newton step
-(which is at most p).
+Assuming `n` dominates `p`, O(κ₁κ₂np), dominated by the application of the
+Hessian at each step where κ₁ is the number of Newton steps and κ₂ is the
+average number of CG steps per Newton step (which is at most p).
 """
-function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R}, solver::NewtonCG, X, y)
+function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R},
+              solver::NewtonCG, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
     θ₀ = zeros(p)
     _f = objective(glr, X, y)
-    _fg! = (g, θ) -> fgh!(glr, X, y)(0.0, g, nothing, θ) # XXX: Optim.jl/issues/738
+    _fg! = (g, θ) -> fgh!(glr, X, y)(0.0, g, nothing, θ) # Optim.jl/issues/738
     _Hv! = Hv!(glr, X, y)
     opt = Optim.TwiceDifferentiableHV(_f, _fg!, _Hv!, θ₀)
     res = Optim.optimize(opt, θ₀, Optim.KrylovTrustRegion())
@@ -51,10 +53,11 @@ Fit a GLR using LBFGS.

 ## Complexity

-Assuming `n` dominates `p`, O(κnp), dominated by the computation of the gradient at each step with
-κ the number of LBFGS steps.
+Assuming `n` dominates `p`, O(κnp), dominated by the computation of the
+gradient at each step with κ the number of LBFGS steps.
 """
-function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R}, solver::LBFGS, X, y)
+function _fit(glr::GLR{<:Union{LogisticLoss,RobustLoss},<:L2R},
+              solver::LBFGS, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
     θ₀ = zeros(p)
     _fg! = (f, g, θ) -> fgh!(glr, X, y)(f, g, nothing, θ)
@@ -69,13 +72,15 @@ end
 """
 $SIGNATURES

-Fit a multiclass GLR using Newton's method with an iterative solver (conjugate gradient).
+Fit a multiclass GLR using Newton's method with an iterative solver (conjugate
+gradient).

 ## Complexity

-Assuming `n` dominates `p`, O(κ₁κ₂npc), where `c` is the number of classes. The computations are
-dominated by the application of the Hessian at each step with κ₁ the number of Newton steps and κ₂
-the average number of CG steps per Newton step.
+Assuming `n` dominates `p`, O(κ₁κ₂npc), where `c` is the number of classes. The
+computations are dominated by the application of the Hessian at each step with
+κ₁ the number of Newton steps and κ₂ the average number of CG steps per Newton
+step.
 """
 function _fit(glr::GLR{MultinomialLoss,<:L2R}, solver::NewtonCG, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
@@ -96,8 +101,9 @@ Fit a multiclass GLR using LBFGS.

 ## Complexity

-Assuming `n` dominates `p`, O(κnpc), with `c` the number of classes, dominated by the computation
-of the gradient at each step with κ the number of LBFGS steps.
+Assuming `n` dominates `p`, O(κnpc), with `c` the number of classes, dominated
+by the computation of the gradient at each step with κ the number of LBFGS
+steps.
 """
 function _fit(glr::GLR{MultinomialLoss,<:L2R}, solver::LBFGS, X, y)
     p = size(X, 2) + Int(glr.fit_intercept)
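
The LBFGS path boils down to handing Optim.jl an `fg!` that writes the gradient in place and returns the objective. A self-contained sketch for L2-regularised logistic regression, assuming labels in {-1, 1}; the package's `fgh!` machinery is replaced here by an explicit objective/gradient pair passed through `Optim.only_fg!`:

```julia
using Optim

σ(t) = 1 / (1 + exp(-t))

function logistic_lbfgs(X, y, λ)
    function fg!(f, g, θ)
        Xθ = X * θ
        if g !== nothing
            # ∇ of ∑ log(1 + exp(-yᵢ xᵢ'θ)) plus the gradient of λ|θ|²/2
            g .= X' * (-y .* σ.(-y .* Xθ)) .+ λ .* θ
        end
        f === nothing && return nothing
        return sum(log1p.(exp.(-y .* Xθ))) + λ * sum(abs2, θ) / 2
    end
    res = Optim.optimize(Optim.only_fg!(fg!), zeros(size(X, 2)), Optim.LBFGS())
    return Optim.minimizer(res)
end

θ = logistic_lbfgs(randn(200, 4), rand([-1.0, 1.0], 200), 1.0)
```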

src/fit/proxgrad.jl

Lines changed: 6 additions & 3 deletions
@@ -22,7 +22,8 @@ function _fit(glr::GLR, solver::ProxGrad, X, y)
     _f = smooth_objective(glr, X, y; c=c)
     _fg! = smooth_fg!(glr, X, y)
     _prox! = prox!(glr)
-    bt_cond = θ̂ -> _f(θ̂) > fθ̄ + dot(θ̂ .- θ̄, ∇fθ̄) + sum(abs2.(θ̂ .- θ̄))/(2η)
+    bt_cond = θ̂ ->
+        _f(θ̂) > fθ̄ + dot(θ̂ .- θ̄, ∇fθ̄) + sum(abs2.(θ̂ .- θ̄)) / (2η)
     # loop-related
     k, tol = 1, Inf
     while k ≤ solver.max_iter && tol > solver.tol
@@ -46,7 +47,8 @@ function _fit(glr::GLR, solver::ProxGrad, X, y)
             inner += 1
         end
         if inner == solver.max_inner
-            @warn "No appropriate stepsize found via backtracking; interrupting."
+            @warn "No appropriate stepsize found via backtracking; " *
+                  "interrupting."
             break
         end
         # update caches
@@ -59,6 +61,7 @@ function _fit(glr::GLR, solver::ProxGrad, X, y)
         # update niter
         k += 1
     end
-    tol ≤ solver.tol || @warn "Proximal GD did not converge in $(solver.max_iter) iterations."
+    tol ≤ solver.tol || @warn "Proximal GD did not converge in " *
+                              "$(solver.max_iter) iterations."
     return θ
 end
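
Stripped of the caching and convergence bookkeeping, the loop above is proximal gradient descent with backtracking on a quadratic upper bound (the same sufficient-decrease test as `bt_cond`). A standalone ISTA sketch for the lasso objective ``|Xθ - y|₂²/2 + γ|θ|₁``, with a soft-thresholding prox and no acceleration, purely for illustration:

```julia
using LinearAlgebra

soft_threshold(z, t) = sign(z) * max(abs(z) - t, 0.0)

function ista(X, y, γ; η=1.0, β=0.8, max_iter=1000, max_inner=100)
    f(θ)  = sum(abs2, X * θ - y) / 2
    ∇f(θ) = X' * (X * θ - y)
    θ = zeros(size(X, 2))
    for _ in 1:max_iter
        g, fθ = ∇f(θ), f(θ)
        θ̂ = soft_threshold.(θ .- η .* g, η * γ)
        inner = 0
        # shrink η while the quadratic upper bound at the current step is violated
        while f(θ̂) > fθ + dot(θ̂ - θ, g) + sum(abs2, θ̂ - θ) / (2η) && inner < max_inner
            η *= β
            θ̂ = soft_threshold.(θ .- η .* g, η * γ)
            inner += 1
        end
        θ = θ̂
    end
    return θ
end

θ = ista(randn(100, 10), randn(100), 0.1)
```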

src/fit/solvers.jl

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ struct LBFGS <: Solver end
 @with_kw struct ProxGrad <: Solver
     accel::Bool = false  # use Nesterov style acceleration (see also FISTA)
     max_iter::Int = 1000 # max number of overall iterations
-    tol::Float64 = 1e-4  # tolerance over relative change of θ i.e. norm(θ-θ_)/norm(θ)
+    tol::Float64 = 1e-4  # tol relative change of θ i.e. norm(θ-θ_)/norm(θ)
     max_inner::Int = 100 # β^max_inner should be > 1e-10
     beta::Float64 = 0.8  # in (0, 1); shrinkage in the backtracking step
 end
@@ -53,7 +53,7 @@ ISTA(; kwa...) = ProxGrad(;accel = false, kwa...)
     max_inner::Int = 200
    tol::Float64 = 1e-4
     damping::Float64 = 1.0   # should be between 0 and 1, 1 = trust iterates
-    threshold::Float64 = 1e-6 # threshold for the residuals used for instance in quantile reg
+    threshold::Float64 = 1e-6 # thresh for residuals; used eg in quantile reg
 end

 # ===================== admm.jl
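
Since the solver structs are defined with `Parameters.@with_kw`, they are constructed by keyword. A brief, hedged usage note (assumes MLJLinearModels 0.3.0 is installed; the names are imported qualified in case they are not exported):

```julia
import MLJLinearModels: ProxGrad, ISTA

fista_like = ProxGrad(accel=true, tol=1e-6, max_iter=2000)  # Nesterov acceleration
plain_ista = ISTA(tol=1e-6)                                 # shorthand for ProxGrad(accel=false, ...)
```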

src/glr/constructors.jl

Lines changed: 21 additions & 13 deletions
@@ -6,14 +6,14 @@ export GeneralizedLinearRegression, GLR,
     RobustRegression, HuberRegression, QuantileRegression

 """
-        GeneralizedLinearRegression{L<:Loss, P<:Penalty}
+    GeneralizedLinearRegression{L<:Loss, P<:Penalty}

 Generalized Linear Regression (GLR) model with objective function:

 ``L(y, Xθ) + P(θ)``

-where `L` is a loss function, `P` a penalty, `y` is the vector of observed response, `X` is
-the feature matrix and `θ` the vector of parameters.
+where `L` is a loss function, `P` a penalty, `y` is the vector of observed
+response, `X` is the feature matrix and `θ` the vector of parameters.

 Special cases include:

@@ -74,8 +74,10 @@ $SIGNATURES

 Objective function: ``|Xθ - y|₂²/2 + λ|θ|₂²/2 + γ|θ|₁``.
 """
-function ElasticNetRegression(λ::Real=1.0, γ::Real=1.0; lambda::Real=λ, gamma::Real=γ,
-                              fit_intercept::Bool=true, penalize_intercept::Bool=false)
+function ElasticNetRegression(λ::Real=1.0, γ::Real=1.0;
+                              lambda::Real=λ, gamma::Real=γ,
+                              fit_intercept::Bool=true,
+                              penalize_intercept::Bool=false)
     check_pos.((lambda, gamma))
     GLR(penalty=lambda*L2Penalty()+gamma*L1Penalty(),
         fit_intercept=fit_intercept,
@@ -108,10 +110,11 @@ end
 """
 $SIGNATURES

-Objective function: ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is either the logistic loss in the
-binary case or the multinomial loss otherwise.
+Objective function: ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is either the
+logistic loss in the binary case or the multinomial loss otherwise.
 """
-function LogisticRegression(λ::Real=1.0, γ::Real=0.0; lambda::Real=λ, gamma::Real=γ,
+function LogisticRegression(λ::Real=1.0, γ::Real=0.0;
+                            lambda::Real=λ, gamma::Real=γ,
                             penalty::Symbol=iszero(gamma) ? :l2 : :en,
                             multi_class::Bool=false, fit_intercept::Bool=true,
                             penalize_intercept::Bool=false)
@@ -131,11 +134,13 @@ MultinomialRegression(a...; kwa...) = LogisticRegression(a...; multi_class=true,
 """
 $SIGNATURES

-Objective function: ``∑ρ(Xθ - y) + λ|θ|₂² + γ|θ|₁`` where ρ is a given function on the residuals.
+Objective function: ``∑ρ(Xθ - y) + λ|θ|₂² + γ|θ|₁`` where ρ is a given function
+on the residuals.
 """
 function RobustRegression(ρ::RobustRho=HuberRho(0.1), λ::Real=1.0, γ::Real=0.0;
                           rho::RobustRho=ρ, lambda::Real=λ, gamma::Real=γ,
-                          penalty::Symbol=iszero(gamma) ? :l2 : :en, fit_intercept::Bool=true,
+                          penalty::Symbol=iszero(gamma) ? :l2 : :en,
+                          fit_intercept::Bool=true,
                           penalize_intercept::Bool=false)
     penalty = _l1l2en(lambda, gamma, penalty, "Robust regression")
     GLR(loss=RobustLoss(rho),
@@ -151,12 +156,14 @@ Huber Regression with objective:

 ``∑ρ(Xθ - y) + λ|θ|₂²/2 + γ|θ|₁``

-Where `ρ` is the Huber function `ρ(r) = r²/2` if `|r|≤δ` and `ρ(r)=δ(|r|-δ/2)` otherwise.
+Where `ρ` is the Huber function `ρ(r) = r²/2` if `|r|≤δ` and
+`ρ(r)=δ(|r|-δ/2)` otherwise.
 """
 function HuberRegression(δ::Real=0.5, λ::Real=1.0, γ::Real=0.0;
                          delta::Real=δ, lambda::Real=λ, gamma::Real=γ,
                          penalty::Symbol=iszero(gamma) ? :l2 : :en,
-                         fit_intercept::Bool=true, penalize_intercept::Bool=false)
+                         fit_intercept::Bool=true,
+                         penalize_intercept::Bool=false)
     return RobustRegression(HuberRho(delta), lambda, gamma;
                             penalty=penalty, fit_intercept=fit_intercept,
                             penalize_intercept=penalize_intercept)
@@ -174,7 +181,8 @@ Where `ρ` is the check function `ρ(r) = r(δ - 1(r < 0))`.
 function QuantileRegression(δ::Real=0.5, λ::Real=1.0, γ::Real=0.0;
                             delta::Real=δ, lambda::Real=λ, gamma::Real=γ,
                             penalty::Symbol=iszero(gamma) ? :l2 : :en,
-                            fit_intercept::Bool=true, penalize_intercept::Bool=false)
+                            fit_intercept::Bool=true,
+                            penalize_intercept::Bool=false)
     return RobustRegression(QuantileRho(delta), lambda, gamma;
                             penalty=penalty, fit_intercept=fit_intercept,
                             penalize_intercept=penalize_intercept)
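
Putting the constructors together with the exported `fit` (see `export fit` in src/fit/default.jl, which picks a default solver per loss/penalty), a typical use looks like the sketch below. It assumes the `fit(glr, X, y)` entry point documented by the package and is shown only to make the positional/keyword hyperparameter pairs (`λ`/`lambda`, `γ`/`gamma`, `δ`/`delta`) concrete.

```julia
using MLJLinearModels

X, y = randn(500, 10), randn(500)

enet  = ElasticNetRegression(0.5, 0.3)               # positional λ, γ ...
enet2 = ElasticNetRegression(lambda=0.5, gamma=0.3)  # ... or keyword equivalents
huber = HuberRegression(delta=0.5, lambda=1.0, penalize_intercept=false)

θ_enet  = fit(enet, X, y)   # default solver chosen by _solver (src/fit/default.jl)
θ_huber = fit(huber, X, y)  # robust losses default to LBFGS
```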
