
Commit fb88fff

Univariate Finite 2 (#77)
1 parent fa19b24 commit fb88fff

4 files changed (+15, −25 lines)

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "MLJLinearModels"
 uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 authors = ["Thibaut Lienart <[email protected]>"]
-version = "0.4.0"
+version = "0.5.0"
 
 [deps]
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,7 +16,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 DocStringExtensions = "^0.8"
 IterativeSolvers = "^0.8"
 LinearMaps = "^2.6"
-MLJModelInterface = "^0.1,^0.2"
+MLJModelInterface = "^0.3"
 Optim = "^0.20,^0.21"
 Parameters = "^0.12"
 julia = "^1"
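As a side note on the compat change: a caret entry such as `"^0.3"` admits any `0.3.x` release while excluding `0.1`/`0.2`, which is what forces the newer constructor behaviour relied on in the interface changes below. A quick way to check this is sketched here; it uses Pkg's internal `semver_spec` helper, so the exact path is an assumption about Pkg internals rather than a documented API:

```julia
using Pkg

# Pkg.Types.semver_spec parses a compat entry into a VersionSpec (internal helper)
spec = Pkg.Types.semver_spec("^0.3")   # the new MLJModelInterface compat entry

v"0.3.4" in spec    # true  -- any 0.3.x release satisfies the entry
v"0.2.9" in spec    # false -- versions matching the old "^0.1,^0.2" entry are excluded
```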

src/mlj/classifiers.jl

Lines changed: 5 additions & 10 deletions
@@ -25,8 +25,6 @@ the strength of the L2 (resp. L1) regularisation components.
 * `fit_intercept` (Bool): whether to fit an intercept (Default: `true`)
 * `penalize_intercept` (Bool): whether to penalize intercept (Default: `false`)
 * `solver` (Solver): type of solver to use, default if `nothing`.
-* `multi_class` (Bool): whether it's a binary or multi class classification
-  problem. This is usually set automatically.
 """
 @with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
     lambda::Real = 1.0
@@ -35,17 +33,15 @@ the strength of the L2 (resp. L1) regularisation components.
     fit_intercept::Bool = true
     penalize_intercept::Bool = false
     solver::Option{Solver} = nothing
-    multi_class::Bool = false
-    nclasses::Int = 2
 end
 
-glr(m::LogisticClassifier) =
+glr(m::LogisticClassifier, nclasses::Integer) =
     LogisticRegression(m.lambda, m.gamma;
                        penalty=Symbol(m.penalty),
-                       multi_class=m.multi_class,
+                       multi_class=(nclasses > 2),
                        fit_intercept=m.fit_intercept,
                        penalize_intercept=m.penalize_intercept,
-                       nclasses=m.nclasses)
+                       nclasses=nclasses)
 
 descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the logistic loss."
 
@@ -66,15 +62,14 @@ to `true` by default. The other parameters are the same.
     fit_intercept::Bool = true
     penalize_intercept::Bool = false
     solver::Option{Solver} = nothing
-    nclasses::Int = 2 # leave to 2, cf LogisticRegression
 end
 
-glr(m::MultinomialClassifier) =
+glr(m::MultinomialClassifier, nclasses::Integer) =
     MultinomialRegression(m.lambda, m.gamma;
                           penalty=Symbol(m.penalty),
                           fit_intercept=m.fit_intercept,
                           penalize_intercept=m.penalize_intercept,
-                          nclasses=m.nclasses)
+                          nclasses=nclasses)
 
 descr(::Type{MultinomialClassifier}) =
     "Classifier corresponding to the loss function " *

src/mlj/interface.jl

Lines changed: 7 additions & 12 deletions
@@ -54,36 +54,31 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
         yplain[yplain .== 1] .= -1
         yplain[yplain .== 2] .= 1
         # force the binary case
-        m.multi_class = false
-        m.nclasses = 0
-    else # > 2
-        m.nclasses = nclasses
+        nclasses = 0
     end
     # NOTE: here the number of classes is either 0 or > 2
-    clf = glr(m)
+    clf = glr(m, nclasses)
     solver = m.solver === nothing ? _solver(clf, size(Xmatrix)) : m.solver
     # get the parameters
     θ = fit(clf, Xmatrix, yplain, solver=solver)
     # return
-    return (θ, features, classes), nothing, NamedTuple{}()
+    return (θ, features, classes, nclasses), nothing, NamedTuple{}()
 end
 
-function MMI.predict(m::Union{CLF_MODELS...}, (θ, features, classes), Xnew)
+function MMI.predict(m::Union{CLF_MODELS...}, (θ, features, classes, c), Xnew)
     Xmatrix = MMI.matrix(Xnew)
-    c = m.nclasses
     preds = apply_X(Xmatrix, θ, c)
     if c > 2 # multiclass
         preds .= softmax(preds)
     else # binary (necessarily c==0)
         preds .= sigmoid.(preds)
         preds = hcat(1.0 .- preds, preds) # scores for -1 and 1
-        return [MMI.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
+        return MMI.UnivariateFinite(classes, preds)
     end
-    return [MMI.UnivariateFinite(classes, preds[i, :]) for i in 1:size(Xmatrix,1)]
+    return MMI.UnivariateFinite(classes, preds)
 end
 
-function MMI.fitted_params(m::Union{CLF_MODELS...}, (θ, features, classes))
-    c = m.nclasses
+function MMI.fitted_params(m::Union{CLF_MODELS...}, (θ, features, classes, c))
     # helper function to assemble the results
     _fitted_params(coefs, features, intercept) =
         (classes=classes, coefs=coef_vec(coefs, features), intercept=intercept)
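The `predict` rewrite leans on the vectorised `UnivariateFinite` constructor available from MLJModelInterface 0.3 onwards (hence the compat bump): handed the class pool and a probability matrix with one row per observation and one column per class, it builds the whole vector of distributions in one call, which is what replaces the row-by-row comprehension. A rough illustration with made-up labels and probabilities, assuming MLJBase is loaded to provide the concrete implementation behind `MMI.UnivariateFinite`:

```julia
using MLJBase, CategoricalArrays

classes = categorical(["no", "yes"])        # class pool, as stored in the fit result
probs   = [0.2 0.8;                         # one row per observation,
           0.7 0.3]                         # one column per class (rows sum to 1)

dists = UnivariateFinite(classes, probs)    # a vector of two UnivariateFinite distributions
pdf(dists[1], classes[2])                   # probability of "yes" for the first row, i.e. 0.8
```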

test/interface/fitpredict.jl

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ end
 @testset "String-Symbol" begin
     model = LogisticClassifier(penalty="l1")
     @test model.penalty == "l1"
-    gr = MLJLinearModels.glr(model)
+    gr = MLJLinearModels.glr(model, 2)
     @test gr isa GLR
     @test gr.penalty isa ScaledPenalty{L1Penalty}
 end

0 commit comments
