Skip to content

Commit aa7c4a9

Browse files
authored
Symstr (#45)
* penalties for basic regressions can be specified with strings like sklearn * patch release
1 parent 289a373 commit aa7c4a9

File tree

5 files changed

+25
-13
lines changed

5 files changed

+25
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/mlj/classifiers.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
@with_kw_noshow mutable struct LogisticClassifier <: MLJBase.Probabilistic
66
lambda::Real = 1.0
77
gamma::Real = 0.0
8-
penalty::Symbol = :l2
8+
penalty::SymStr = :l2
99
fit_intercept::Bool = true
1010
penalize_intercept::Bool = false
1111
solver::Option{Solver} = nothing
1212
multi_class::Bool = false
1313
end
1414

15-
glr(m::LogisticClassifier) = LogisticRegression(m.lambda, m.gamma; penalty=m.penalty,
15+
glr(m::LogisticClassifier) = LogisticRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
1616
multi_class=m.multi_class,
1717
fit_intercept=m.fit_intercept,
1818
penalize_intercept=m.penalize_intercept)
@@ -26,13 +26,13 @@ descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss functi
2626
@with_kw_noshow mutable struct MultinomialClassifier <: MLJBase.Probabilistic
2727
lambda::Real = 1.0
2828
gamma::Real = 0.0
29-
penalty::Symbol = :l2
29+
penalty::SymStr = :l2
3030
fit_intercept::Bool = true
3131
penalize_intercept::Bool = false
3232
solver::Option{Solver} = nothing
3333
end
3434

35-
glr(m::MultinomialClassifier) = MultinomialRegression(m.lambda, m.gamma; penalty=m.penalty,
35+
glr(m::MultinomialClassifier) = MultinomialRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
3636
fit_intercept=m.fit_intercept,
3737
penalize_intercept=m.penalize_intercept)
3838

src/mlj/interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ export LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,
22
RobustRegressor, HuberRegressor, QuantileRegressor, LADRegressor,
33
LogisticClassifier, MultinomialClassifier
44

5+
const SymStr = Union{Symbol,String}
6+
57
include("regressors.jl")
68
include("classifiers.jl")
79

src/mlj/regressors.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ descr(::Type{ElasticNetRegressor}) = "Regression with objective function ``|Xθ
6969
rho::RobustRho = HuberRho(0.1)
7070
lambda::Real = 1.0
7171
gamma::Real = 0.0
72-
penalty::Symbol = :l2
72+
penalty::SymStr = :l2
7373
fit_intercept::Bool = true
7474
penalize_intercept::Bool = false
7575
solver::Option{Solver} = nothing
7676
end
7777

78-
glr(m::RobustRegressor) = RobustRegression(m.rho, m.lambda, m.gamma; penalty=m.penalty,
78+
glr(m::RobustRegressor) = RobustRegression(m.rho, m.lambda, m.gamma; penalty=Symbol(m.penalty),
7979
fit_intercept=m.fit_intercept,
8080
penalize_intercept=m.penalize_intercept)
8181

@@ -89,13 +89,13 @@ descr(::Type{RobustRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
8989
delta::Real = 0.5
9090
lambda::Real = 1.0
9191
gamma::Real = 0.0
92-
penalty::Symbol = :l2
92+
penalty::SymStr = :l2
9393
fit_intercept::Bool = true
9494
penalize_intercept::Bool = false
9595
solver::Option{Solver} = nothing
9696
end
9797

98-
glr(m::HuberRegressor) = HuberRegression(m.delta, m.lambda, m.gamma; penalty=m.penalty,
98+
glr(m::HuberRegressor) = HuberRegression(m.delta, m.lambda, m.gamma; penalty=Symbol(m.penalty),
9999
fit_intercept=m.fit_intercept,
100100
penalize_intercept=m.penalize_intercept)
101101

@@ -109,13 +109,14 @@ descr(::Type{HuberRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
109109
delta::Real = 0.5
110110
lambda::Real = 1.0
111111
gamma::Real = 0.0
112-
penalty::Symbol = :l2
112+
penalty::SymStr = :l2
113113
fit_intercept::Bool = true
114114
penalize_intercept::Bool = false
115115
solver::Option{Solver} = nothing
116116
end
117117

118-
glr(m::QuantileRegressor) = QuantileRegression(m.delta, m.lambda, m.gamma; penalty=m.penalty,
118+
glr(m::QuantileRegressor) = QuantileRegression(m.delta, m.lambda, m.gamma;
119+
penalty=Symbol(m.penalty),
119120
fit_intercept=m.fit_intercept,
120121
penalize_intercept=m.penalize_intercept)
121122

@@ -128,13 +129,13 @@ descr(::Type{QuantileRegressor}) = "Robust regression with objective ``∑ρ(Xθ
128129
@with_kw_noshow mutable struct LADRegressor <: MLJBase.Deterministic
129130
lambda::Real = 1.0
130131
gamma::Real = 0.0
131-
penalty::Symbol = :l2
132+
penalty::SymStr = :l2
132133
fit_intercept::Bool = true
133134
penalize_intercept::Bool = false
134135
solver::Option{Solver} = nothing
135136
end
136137

137-
glr(m::LADRegressor) = LADRegression(m.lambda, m.gamma; penalty=m.penalty,
138+
glr(m::LADRegressor) = LADRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
138139
fit_intercept=m.fit_intercept,
139140
penalize_intercept=m.penalize_intercept)
140141

test/interface/fitpredict.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ end
6060
mcr = MLJBase.misclassification_rate(ŷ, yc)
6161
@test mcr 0.2
6262
end
63+
64+
# see issue https://github.com/alan-turing-institute/MLJ.jl/issues/387
65+
@testset "String-Symbol" begin
66+
model = LogisticClassifier(penalty="l1")
67+
@test model.penalty == "l1"
68+
gr = MLJLinearModels.glr(model)
69+
@test gr isa GLR
70+
@test gr.penalty isa ScaledPenalty{L1Penalty}
71+
end

0 commit comments

Comments
 (0)