Skip to content

Commit 4f8e799

Browse files
committed
avoid ambiguity in counting number of classes
1 parent 655a87a commit 4f8e799

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/mlj/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
4545
sch = MMI.schema(X)
4646
features = (sch === nothing) ? nothing : sch.names
4747
yplain = convert.(Int, MMI.int(y))
48-
classes = MMI.classes(y[1])
48+
classes = MMI.classes(y[1])[unique(yplain)]
4949
nclasses = length(classes)
5050
if nclasses == 2
5151
# recode
@@ -69,7 +69,7 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
6969
end
7070

7171
function MMI.predict(m::Union{CLF_MODELS...}, (θ, features, c, classes), Xnew)
72-
Xmatrix = MMI.matrix(Xnew)
72+
Xmatrix = MMI.matrix(Xnew)
7373
preds = apply_X(Xmatrix, θ, c)
7474
# binary classification
7575
if c == 1

0 commit comments

Comments
 (0)