Skip to content

Commit ba0c224

Browse files
committed
closes #71
1 parent 957c166 commit ba0c224

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/mlj/interface.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
5555
else
5656
c = nclasses
5757
end
58-
clf = glr(m)
58+
clf = glr(m)
59+
# allow logclf to become multiclf
60+
if m isa LogisticClassifier
61+
m.multi_class = c > 1
62+
end
63+
5964
solver = m.solver === nothing ? _solver(clf, size(Xmatrix)) : m.solver
6065
# get the parameters
6166
θ = fit(clf, Xmatrix, yplain, solver=solver)

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ $SIGNATURES
44
Return nothing if the number of rows of `X` and `y` match and throws a
55
`DimensionMismatch` error otherwise.
66
"""
7-
function check_nrows(X::Matrix, y::VecOrMat)::Nothing
7+
function check_nrows(X::AbstractMatrix, y::AbstractVecOrMat)::Nothing
88
size(X, 1) == size(y, 1) && return nothing
99
throw(DimensionMismatch("`X` and `y` must have the same number of rows."))
1010
end
@@ -21,7 +21,7 @@ $SIGNATURES
2121
2222
Given a matrix `X`, append a column of ones if `fit_intercept` is true.
2323
"""
24-
function augment_X(X::Matrix{<:Real}, fit_intercept::Bool)
24+
function augment_X(X::AbstractMatrix{<:Real}, fit_intercept::Bool)
2525
fit_intercept || return X
2626
return hcat(X, ones(eltype(X), size(X, 1)))
2727
end

test/interface/fitpredict.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,13 @@ end
6767
@test gr isa GLR
6868
@test gr.penalty isa ScaledPenalty{L1Penalty}
6969
end
70+
71+
# see issue #71
72+
@testset "Logistic-m" begin
73+
X, y = MLJBase.make_blobs(centers=3)
74+
model = LogisticClassifier()
75+
mach = MLJBase.machine(model, X, y)
76+
fit!(mach)
77+
fp = MLJBase.fitted_params(mach)
78+
@test unique(fp.classes) == [1,2,3]
79+
end

0 commit comments

Comments
 (0)