Skip to content

Commit 202e058

Browse files
authored
Interface with MLJ (#34)
* add interface with MLJ + tests & version bump
1 parent 1e5af70 commit 202e058

21 files changed

+514
-95
lines changed

Project.toml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
88
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
11+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1112
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1213
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
14+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1315

1416
[compat]
15-
DocStringExtensions = ">= 0.7.0"
16-
Optim = ">= 0.19"
17-
IterativeSolvers = ">= 0.8"
18-
Parameters = ">= 0.10"
19-
LinearMaps = ">= 2.5"
17+
DocStringExtensions = "^0.8"
18+
IterativeSolvers = "^0.8"
19+
LinearMaps = "^2.5"
20+
MLJBase = "^0.7"
21+
Optim = "^0.19"
22+
Parameters = "^0.12"
23+
Tables = "^0.2"
2024
julia = "^1.0.0"
2125

2226
[extras]

src/MLJLinearModels.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ import LinearMaps: LinearMap
66
import IterativeSolvers: cg
77
import Optim
88

9+
import MLJBase
10+
911
import Base.+, Base.-, Base.*, Base./, Base.convert
1012

1113
const AVR = AbstractVector{<:Real}
1214

15+
const Option{T} = Union{Nothing,T}
16+
1317
include("scratchspace.jl")
1418

1519
include("utils.jl")
@@ -39,4 +43,7 @@ include("fit/proxgrad.jl")
3943
include("fit/iwls.jl")
4044
# include("fit/admm.jl")
4145

46+
# > Interface <
47+
include("mlj/interface.jl")
48+
4249
end # module

src/glr/constructors.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
"""
5959
$SIGNATURES
6060
61-
Objective function: ``|Xθ - y|₂²/2 + λ|θ|₁``
61+
Objective function: ``|Xθ - y|₂²/2 + λ|θ|₁``.
6262
"""
6363
function LassoRegression(λ::Real=1.0; lambda::Real=λ, fit_intercept::Bool=true,
6464
penalize_intercept::Bool=false)
@@ -72,7 +72,7 @@ end
7272
"""
7373
$SIGNATURES
7474
75-
Objective function: ``|Xθ - y|₂²/2 + λ|θ|₂²/2 + γ|θ|₁``
75+
Objective function: ``|Xθ - y|₂²/2 + λ|θ|₂²/2 + γ|θ|₁``.
7676
"""
7777
function ElasticNetRegression(λ::Real=1.0, γ::Real=1.0; lambda::Real=λ, gamma::Real=γ,
7878
fit_intercept::Bool=true, penalize_intercept::Bool=false)
@@ -131,9 +131,7 @@ MultinomialRegression(a...; kwa...) = LogisticRegression(a...; multi_class=true,
131131
"""
132132
$SIGNATURES
133133
134-
Objective function: ``∑ρ(Xθ - y) + λ|θ|₂²`` where ρ is a given function on the residuals and
135-
δ a positive tuning parameter for the function in question (e.g. for Huber it corresponds to the
136-
radius of the ball in which residuals are weighed quadratically).
134+
Objective function: ``∑ρ(Xθ - y) + λ|θ|₂² + γ|θ|₁`` where ρ is a given function on the residuals.
137135
"""
138136
function RobustRegression(ρ::RobustRho=HuberRho(0.1), λ::Real=1.0, γ::Real=0.0;
139137
rho::RobustRho=ρ, lambda::Real=λ, gamma::Real=γ,
@@ -151,7 +149,7 @@ $SIGNATURES
151149
152150
Huber Regression with objective:
153151
154-
``∑ρ(Xθ - y) + λ|θ|₂²/2 + γ|θ|``
152+
``∑ρ(Xθ - y) + λ|θ|₂²/2 + γ|θ|``
155153
156154
Where `ρ` is the Huber function `ρ(r) = r²/2` if `|r|≤δ` and `ρ(r)=δ(|r|-δ/2)` otherwise.
157155
"""
@@ -169,7 +167,7 @@ $SIGNATURES
169167
170168
Quantile Regression with objective:
171169
172-
``∑ρ(Xθ - y) + λ|θ|₂²/2 + γ|θ|``
170+
``∑ρ(Xθ - y) + λ|θ|₂²/2 + γ|θ|``
173171
174172
Where `ρ` is the check function `ρ(r) = r(δ - 1(r < 0))`.
175173
"""
@@ -187,7 +185,7 @@ $SIGNATURES
187185
188186
Least Absolute Deviation regression with objective:
189187
190-
``|Xθ - y|₁ + λ|θ|₂²/2 + γ|θ|``
188+
``|Xθ - y|₁ + λ|θ|₂²/2 + γ|θ|``
191189
192190
This is a specific type of Quantile Regression with `δ=0.5` (median).
193191
"""

src/mlj/classifiers.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#= ===================
2+
LOGISTIC CLASSIFIER
3+
=================== =#
4+
5+
@with_kw_noshow mutable struct LogisticClassifier <: MLJBase.Probabilistic
6+
lambda::Real = 1.0
7+
gamma::Real = 0.0
8+
penalty::Symbol = :l2
9+
fit_intercept::Bool = true
10+
penalize_intercept::Bool = false
11+
solver::Option{Solver} = nothing
12+
multi_class::Bool = false
13+
end
14+
15+
glr(m::LogisticClassifier) = LogisticRegression(m.lambda, m.gamma; penalty=m.penalty,
16+
multi_class=m.multi_class,
17+
fit_intercept=m.fit_intercept,
18+
penalize_intercept=m.penalize_intercept)
19+
20+
descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the logistic loss."
21+
22+
#= ======================
23+
MULTINOMIAL CLASSIFIER
24+
====================== =#
25+
26+
@with_kw_noshow mutable struct MultinomialClassifier <: MLJBase.Probabilistic
27+
lambda::Real = 1.0
28+
gamma::Real = 0.0
29+
penalty::Symbol = :l2
30+
fit_intercept::Bool = true
31+
penalize_intercept::Bool = false
32+
solver::Option{Solver} = nothing
33+
end
34+
35+
glr(m::MultinomialClassifier) = MultinomialRegression(m.lambda, m.gamma; penalty=m.penalty,
36+
fit_intercept=m.fit_intercept,
37+
penalize_intercept=m.penalize_intercept)
38+
39+
descr(::Type{MultinomialClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the multinomial loss."

src/mlj/interface.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
export LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,
2+
RobustRegressor, HuberRegressor, QuantileRegressor, LADRegressor,
3+
LogisticClassifier, MultinomialClassifier
4+
5+
include("regressors.jl")
6+
include("classifiers.jl")
7+
8+
const REG_MODELS = (LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,
9+
RobustRegressor, HuberRegressor, QuantileRegressor, LADRegressor)
10+
const CLF_MODELS = (LogisticClassifier, MultinomialClassifier)
11+
const ALL_MODELS = (REG_MODELS..., CLF_MODELS...)
12+
13+
#= ==========
14+
REGRESSORS
15+
========== =#
16+
17+
function MLJBase.fit(m::Union{REG_MODELS...}, verb::Int, X, y)
18+
Xmatrix = MLJBase.matrix(X)
19+
reg = glr(m)
20+
solver = m.solver === nothing ? _solver(reg, size(Xmatrix)) : m.solver
21+
# get the parameters
22+
θ = fit(reg, Xmatrix, y; solver=solver)
23+
# return
24+
return θ, nothing, NamedTuple{}()
25+
end
26+
27+
MLJBase.predict(m::Union{REG_MODELS...}, θ, Xnew) = apply_X(MLJBase.matrix(Xnew), θ)
28+
29+
function MLJBase.fitted_params(m::Union{REG_MODELS...}, θ)
30+
m.fit_intercept && return (coefs = θ[1:end-1], intercept = θ[end])
31+
return (coefs = θ, intercept = nothing)
32+
end
33+
34+
#= ===========
35+
CLASSIFIERS
36+
=========== =#
37+
38+
function MLJBase.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
39+
Xmatrix = MLJBase.matrix(X)
40+
yplain = convert.(Int, MLJBase.int(y))
41+
classes = MLJBase.classes(y[1])
42+
nclasses = length(classes)
43+
if nclasses == 2
44+
# recode
45+
yplain[yplain .== 1] .= -1
46+
yplain[yplain .== 2] .= 1
47+
c = 1
48+
else
49+
c = nclasses
50+
end
51+
clf = glr(m)
52+
solver = m.solver === nothing ? _solver(clf, size(Xmatrix)) : m.solver
53+
# get the parameters
54+
θ = fit(clf, Xmatrix, yplain, solver=solver)
55+
# return
56+
return (θ, c, classes), nothing, NamedTuple{}()
57+
end
58+
59+
function MLJBase.predict(m::Union{CLF_MODELS...}, (θ, c, classes), Xnew)
60+
Xmatrix = MLJBase.matrix(Xnew)
61+
preds = apply_X(Xmatrix, θ, c)
62+
# binary classification
63+
if c == 1
64+
preds .= sigmoid.(preds)
65+
preds = hcat(1.0 .- preds, preds) # scores for -1 and 1
66+
return [MLJBase.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
67+
end
68+
# multiclass
69+
preds .= softmax(preds)
70+
return [MLJBase.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
71+
end
72+
73+
function MLJBase.fitted_params(m::Union{CLF_MODELS...}, (θ, c, classes))
74+
if c > 1
75+
if m.fit_intercept
76+
W = reshape(θ, div(length(θ), c), c)
77+
return (coefs = W, intercept = nothing)
78+
end
79+
W = reshape(θ, p+1, c)
80+
return (coefs = W[1:p, :], intercept = W[end, :])
81+
end
82+
# single class
83+
m.fit_intercept && return (coefs = θ[1:end-1], intercept = θ[end])
84+
return (coefs = θ, intercept = nothing)
85+
end
86+
87+
#= =======================
88+
METADATA FOR ALL MODELS
89+
======================= =#
90+
91+
MLJBase.metadata_pkg.(ALL_MODELS,
92+
name="MLJLinearModels",
93+
uuid="6ee0df7b-362f-4a72-a706-9e79364fb692",
94+
url="https://github.com/alan-turing-institute/MLJLinearModels.jl",
95+
julia=true,
96+
license="MIT",
97+
is_wrapper=false)
98+
99+
descr_(M) = descr(M) *
100+
"\n→ based on [MLJLinearModels](https://github.com/alan-turing-institute/MLJLinearModels.jl)" *
101+
"\n→ do `@load $(MLJBase.name(M)) pkg=\"MLJLinearModels\"` to use the model." *
102+
"\n→ do `?$(MLJBase.name(M))` for documentation."
103+
lp_(M) = "MLJLinearModels.$(MLJBase.name(M))"
104+
105+
for M in REG_MODELS
106+
MLJBase.metadata_model(M,
107+
input=MLJBase.Table(MLJBase.Continuous),
108+
target=AbstractVector{MLJBase.Continuous},
109+
weights=false,
110+
descr=descr_(M), path=lp_(M))
111+
end
112+
for M in CLF_MODELS
113+
MLJBase.metadata_model(M,
114+
input=MLJBase.Table(MLJBase.Continuous),
115+
target=AbstractVector{<:MLJBase.Finite},
116+
weights=false,
117+
descr=descr_(M), path=lp_(M))
118+
end

0 commit comments

Comments
 (0)