Skip to content

Commit b7eb8fc

Browse files
committed
fix for fitted_params in multinomial case with no intercept
1 parent 6d4b488 commit b7eb8fc

File tree

4 files changed

+9
-3
lines changed

4 files changed

+9
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*DS_Store
22
Manifest.toml
33
docs/build
4+
sandbox.jl

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.5.4"
4+
version = "0.5.5"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/mlj/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function MMI.fitted_params(m::Union{CLF_MODELS...}, (θ, features, classes, c))
8888
if m.fit_intercept
8989
return _fitted_params(W, features, W[end, :])
9090
end
91-
return _fitted_params(W[1:end-1, :], features, nothing)
91+
return _fitted_params(W, features, nothing)
9292
end
9393
# single class (necessarily c==0)
9494
m.fit_intercept && return _fitted_params(θ[1:end-1], features, θ[end])

test/interface/fitpredict.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,14 @@ end
4949
Xt = MLJBase.table(X)
5050
yc = MLJBase.categorical(y1)
5151

52-
mc = MultinomialClassifier(lambda=λ, gamma=γ)
52+
mc = MultinomialClassifier(lambda=λ, gamma=γ, fit_intercept=false)
5353
fr, = MLJBase.fit(mc, 1, Xt, yc)
5454

55+
mach = MLJBase.machine(mc, Xt, yc)
56+
MLJBase.fit!(mach)
57+
fp = MLJBase.fitted_params(mach)
58+
@test length(fp.coefs) == 5
59+
5560
= MLJBase.predict(mc, fr, Xt)
5661
= MLJBase.mode.(ŷ)
5762

0 commit comments

Comments
 (0)