Skip to content

Commit 9b885c3

Browse files
committed
MLJBase->MLJModelInterface
1 parent f199300 commit 9b885c3

File tree

6 files changed

+47
-45
lines changed

6 files changed

+47
-45
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
88
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
11-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
11+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1212
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1313
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
1414
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1515

1616
[compat]
1717
DocStringExtensions = "^0.8"
1818
IterativeSolvers = "^0.8"
19-
LinearMaps = "^2.5"
20-
MLJBase = "^0.9"
21-
Optim = "^0.19"
19+
LinearMaps = "^2.6"
20+
MLJModelInterface = "^0.1"
21+
Optim = "^0.20"
2222
Parameters = "^0.12"
2323
Tables = "^0.2"
24-
julia = "^1.0.0"
24+
julia = "^1"
2525

2626
[extras]
2727
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"

src/MLJLinearModels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import LinearMaps: LinearMap
66
import IterativeSolvers: cg
77
import Optim
88

9-
import MLJBase
9+
import MLJModelInterface
1010

1111
import Base.+, Base.-, Base.*, Base./, Base.convert
1212

13-
const AVR = AbstractVector{<:Real}
14-
13+
const MMI = MLJModelInterface
14+
const AVR = AbstractVector{<:Real}
1515
const Option{T} = Union{Nothing,T}
1616

1717
include("scratchspace.jl")

src/mlj/classifiers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
LOGISTIC CLASSIFIER
33
=================== =#
44

5-
@with_kw_noshow mutable struct LogisticClassifier <: MLJBase.Probabilistic
5+
@with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
66
lambda::Real = 1.0
77
gamma::Real = 0.0
88
penalty::SymStr = :l2
@@ -23,7 +23,7 @@ descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss functi
2323
MULTINOMIAL CLASSIFIER
2424
====================== =#
2525

26-
@with_kw_noshow mutable struct MultinomialClassifier <: MLJBase.Probabilistic
26+
@with_kw_noshow mutable struct MultinomialClassifier <: MMI.Probabilistic
2727
lambda::Real = 1.0
2828
gamma::Real = 0.0
2929
penalty::SymStr = :l2

src/mlj/interface.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@ const SymStr = Union{Symbol,String}
77
include("regressors.jl")
88
include("classifiers.jl")
99

10-
const REG_MODELS = (LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,
11-
RobustRegressor, HuberRegressor, QuantileRegressor, LADRegressor)
10+
const REG_MODELS = (LinearRegressor, RidgeRegressor, LassoRegressor,
11+
ElasticNetRegressor, RobustRegressor, HuberRegressor,
12+
QuantileRegressor, LADRegressor)
1213
const CLF_MODELS = (LogisticClassifier, MultinomialClassifier)
1314
const ALL_MODELS = (REG_MODELS..., CLF_MODELS...)
1415

1516
#= ==========
1617
REGRESSORS
1718
========== =#
1819

19-
function MLJBase.fit(m::Union{REG_MODELS...}, verb::Int, X, y)
20-
Xmatrix = MLJBase.matrix(X)
20+
function MMI.fit(m::Union{REG_MODELS...}, verb::Int, X, y)
21+
Xmatrix = MMI.matrix(X)
2122
reg = glr(m)
2223
solver = m.solver === nothing ? _solver(reg, size(Xmatrix)) : m.solver
2324
# get the parameters
@@ -26,9 +27,9 @@ function MLJBase.fit(m::Union{REG_MODELS...}, verb::Int, X, y)
2627
return θ, nothing, NamedTuple{}()
2728
end
2829

29-
MLJBase.predict(m::Union{REG_MODELS...}, θ, Xnew) = apply_X(MLJBase.matrix(Xnew), θ)
30+
MMI.predict(m::Union{REG_MODELS...}, θ, Xnew) = apply_X(MMI.matrix(Xnew), θ)
3031

31-
function MLJBase.fitted_params(m::Union{REG_MODELS...}, θ)
32+
function MMI.fitted_params(m::Union{REG_MODELS...}, θ)
3233
m.fit_intercept && return (coefs = θ[1:end-1], intercept = θ[end])
3334
return (coefs = θ, intercept = nothing)
3435
end
@@ -37,10 +38,10 @@ end
3738
CLASSIFIERS
3839
=========== =#
3940

40-
function MLJBase.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
41-
Xmatrix = MLJBase.matrix(X)
42-
yplain = convert.(Int, MLJBase.int(y))
43-
classes = MLJBase.classes(y[1])
41+
function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
42+
Xmatrix = MMI.matrix(X)
43+
yplain = convert.(Int, MMI.int(y))
44+
classes = MMI.classes(y[1])
4445
nclasses = length(classes)
4546
if nclasses == 2
4647
# recode
@@ -58,21 +59,21 @@ function MLJBase.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
5859
return (θ, c, classes), nothing, NamedTuple{}()
5960
end
6061

61-
function MLJBase.predict(m::Union{CLF_MODELS...}, (θ, c, classes), Xnew)
62-
Xmatrix = MLJBase.matrix(Xnew)
62+
function MMI.predict(m::Union{CLF_MODELS...}, (θ, c, classes), Xnew)
63+
Xmatrix = MMI.matrix(Xnew)
6364
preds = apply_X(Xmatrix, θ, c)
6465
# binary classification
6566
if c == 1
6667
preds .= sigmoid.(preds)
6768
preds = hcat(1.0 .- preds, preds) # scores for -1 and 1
68-
return [MLJBase.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
69+
return [MMI.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
6970
end
7071
# multiclass
7172
preds .= softmax(preds)
72-
return [MLJBase.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
73+
return [MMI.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
7374
end
7475

75-
function MLJBase.fitted_params(m::Union{CLF_MODELS...}, (θ, c, classes))
76+
function MMI.fitted_params(m::Union{CLF_MODELS...}, (θ, c, classes))
7677
if c > 1
7778
if m.fit_intercept
7879
W = reshape(θ, div(length(θ), c), c)
@@ -90,7 +91,7 @@ end
9091
METADATA FOR ALL MODELS
9192
======================= =#
9293

93-
MLJBase.metadata_pkg.(ALL_MODELS,
94+
MMI.metadata_pkg.(ALL_MODELS,
9495
name="MLJLinearModels",
9596
uuid="6ee0df7b-362f-4a72-a706-9e79364fb692",
9697
url="https://github.com/alan-turing-institute/MLJLinearModels.jl",
@@ -100,21 +101,21 @@ MLJBase.metadata_pkg.(ALL_MODELS,
100101

101102
descr_(M) = descr(M) *
102103
"\n→ based on [MLJLinearModels](https://github.com/alan-turing-institute/MLJLinearModels.jl)" *
103-
"\n→ do `@load $(MLJBase.name(M)) pkg=\"MLJLinearModels\" to use the model.`" *
104-
"\n→ do `?$(MLJBase.name(M))` for documentation."
105-
lp_(M) = "MLJLinearModels.$(MLJBase.name(M))"
104+
"\n→ do `@load $(MMI.name(M)) pkg=\"MLJLinearModels\" to use the model.`" *
105+
"\n→ do `?$(MMI.name(M))` for documentation."
106+
lp_(M) = "MLJLinearModels.$(MMI.name(M))"
106107

107108
for M in REG_MODELS
108-
MLJBase.metadata_model(M,
109-
input=MLJBase.Table(MLJBase.Continuous),
110-
target=AbstractVector{MLJBase.Continuous},
109+
MMI.metadata_model(M,
110+
input=MMI.Table(MMI.Continuous),
111+
target=AbstractVector{MMI.Continuous},
111112
weights=false,
112113
descr=descr_(M), path=lp_(M))
113114
end
114115
for M in CLF_MODELS
115-
MLJBase.metadata_model(M,
116-
input=MLJBase.Table(MLJBase.Continuous),
117-
target=AbstractVector{<:MLJBase.Finite},
116+
MMI.metadata_model(M,
117+
input=MMI.Table(MMI.Continuous),
118+
target=AbstractVector{<:MMI.Finite},
118119
weights=false,
119120
descr=descr_(M), path=lp_(M))
120121
end

src/mlj/regressors.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
LINEAR REGRESSOR (OLS)
33
====================== =#
44

5-
@with_kw_noshow mutable struct LinearRegressor <: MLJBase.Deterministic
5+
@with_kw_noshow mutable struct LinearRegressor <: MMI.Deterministic
66
fit_intercept::Bool = true
77
solver::Option{Solver} = nothing
88
end
@@ -15,7 +15,7 @@ descr(::Type{LinearRegressor}) = "Regression with objective function ``|Xθ - y|
1515
RIDGE REGRESSOR
1616
=============== =#
1717

18-
@with_kw_noshow mutable struct RidgeRegressor <: MLJBase.Deterministic
18+
@with_kw_noshow mutable struct RidgeRegressor <: MMI.Deterministic
1919
lambda::Real = 1.0
2020
fit_intercept::Bool = true
2121
penalize_intercept::Bool = false
@@ -31,7 +31,7 @@ descr(::Type{RidgeRegressor}) = "Regression with objective function ``|Xθ - y|
3131
LASSO REGRESSOR
3232
=============== =#
3333

34-
@with_kw_noshow mutable struct LassoRegressor <: MLJBase.Deterministic
34+
@with_kw_noshow mutable struct LassoRegressor <: MMI.Deterministic
3535
lambda::Real = 1.0
3636
fit_intercept::Bool = true
3737
penalize_intercept::Bool = false
@@ -47,7 +47,7 @@ descr(::Type{LassoRegressor}) = "Regression with objective function ``|Xθ - y|
4747
ELASTIC NET REGRESSOR
4848
===================== =#
4949

50-
@with_kw_noshow mutable struct ElasticNetRegressor <: MLJBase.Deterministic
50+
@with_kw_noshow mutable struct ElasticNetRegressor <: MMI.Deterministic
5151
lambda::Real = 1.0
5252
gamma::Real = 0.0
5353
fit_intercept::Bool = true
@@ -65,7 +65,7 @@ descr(::Type{ElasticNetRegressor}) = "Regression with objective function ``|Xθ
6565
ROBUST REGRESSOR (General)
6666
========================== =#
6767

68-
@with_kw_noshow mutable struct RobustRegressor <: MLJBase.Deterministic
68+
@with_kw_noshow mutable struct RobustRegressor <: MMI.Deterministic
6969
rho::RobustRho = HuberRho(0.1)
7070
lambda::Real = 1.0
7171
gamma::Real = 0.0
@@ -85,7 +85,7 @@ descr(::Type{RobustRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
8585
HUBER REGRESSOR
8686
=============== =#
8787

88-
@with_kw_noshow mutable struct HuberRegressor <: MLJBase.Deterministic
88+
@with_kw_noshow mutable struct HuberRegressor <: MMI.Deterministic
8989
delta::Real = 0.5
9090
lambda::Real = 1.0
9191
gamma::Real = 0.0
@@ -105,7 +105,7 @@ descr(::Type{HuberRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
105105
QUANTILE REGRESSOR
106106
================== =#
107107

108-
@with_kw_noshow mutable struct QuantileRegressor <: MLJBase.Deterministic
108+
@with_kw_noshow mutable struct QuantileRegressor <: MMI.Deterministic
109109
delta::Real = 0.5
110110
lambda::Real = 1.0
111111
gamma::Real = 0.0
@@ -126,7 +126,7 @@ descr(::Type{QuantileRegressor}) = "Robust regression with objective ``∑ρ(Xθ
126126
LEAST ABSOLUTE DEVIATION REGRESSOR
127127
================================== =#
128128

129-
@with_kw_noshow mutable struct LADRegressor <: MLJBase.Deterministic
129+
@with_kw_noshow mutable struct LADRegressor <: MMI.Deterministic
130130
lambda::Real = 1.0
131131
gamma::Real = 0.0
132132
penalty::SymStr = :l2

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using MLJLinearModels, Test, LinearAlgebra, Random
2-
import MLJBase
3-
DO_COMPARISONS = false; include("testutils.jl")
2+
using MLJBase # not MLJModelInterface, to mimick the full interface
3+
4+
DO_COMPARISONS = true; include("testutils.jl")
45

56
m("UTILS"); include("utils.jl")
67

0 commit comments

Comments
 (0)