Skip to content

Commit 13b1919

Browse files
Fix MLJ Serialization, add test for single class classifiers, fix single class classifiers predict (#35)
* Fix MLJ serialization
* Reformat
* Add test for single-class classifiers; fix single-class classifier predict
1 parent 445ba4b commit 13b1919

File tree

3 files changed

+75
-19
lines changed

3 files changed

+75
-19
lines changed

src/mlj_catboostclassifier.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true
100100
function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
101101
if fitresult[1] === nothing
102102
# Always predict the single class
103-
n = nrow(X_pool)
103+
n = pyconvert(Int, X_pool.shape[0])
104104
classes = [fitresult.single_class]
105105
probs = ones(n, 1)
106-
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
106+
pool = MMI.categorical([fitresult.y_first])
107+
return MMI.UnivariateFinite(classes, probs; pool=pool)
107108
end
108109

109110
model, y_first = fitresult
@@ -116,8 +117,8 @@ end
116117
function MMI.predict_mode(mlj_model::CatBoostClassifier, fitresult, X_pool)
117118
if fitresult[1] === nothing
118119
# Always predict the single class
119-
n = nrow(X_pool)
120-
return hcat(ones(n), zeros(n))
120+
n = pyconvert(Int, X_pool.shape[0])
121+
return fill(fitresult.y_first, n)
121122
end
122123

123124
model, y_first = fitresult

src/mlj_serialization.jl

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,41 @@
11
# Taken from https://github.com/JuliaAI/MLJXGBoostInterface.jl
22
# It is likely also not the optimal method for serializing models, but it works
3+
34
"""
4-
_persistent(booster)
5+
_persistent(model::CatBoostModels, fitresult)
56
67
Private method.
78
89
Return a persistent (ie, Julia-serializable) representation of the
9-
CatBoost.jl model `booster`.
10+
CatBoost.jl model `fitresult`.
1011
11-
Restore the model with [`booster`](@ref)
12+
Restore the model with [`_booster`](@ref).
1213
"""
13-
function _persistent(booster)
14+
function _persistent(::CatBoostRegressor, fitresult)
1415
ctb_file, io = mktemp()
1516
close(io)
1617

17-
booster.save_model(ctb_file)
18+
fitresult.save_model(ctb_file)
1819
persistent_booster = read(ctb_file)
1920
rm(ctb_file)
2021
return persistent_booster
2122
end
23+
function _persistent(::CatBoostClassifier, fitresult)
24+
model, y_first = fitresult
25+
if model === nothing
26+
# Case 1: Single unique class
27+
return (nothing, fitresult.single_class, y_first)
28+
else
29+
# Case 2: Multiple unique classes
30+
ctb_file, io = mktemp()
31+
close(io)
32+
33+
model.save_model(ctb_file)
34+
persistent_booster = read(ctb_file)
35+
rm(ctb_file)
36+
return (persistent_booster, y_first)
37+
end
38+
end
2239

2340
"""
2441
_booster(persistent)
@@ -28,24 +45,45 @@ Private method.
2845
Return the CatBoost.jl model which has `persistent` as its persistent
2946
(Julia-serializable) representation. See the [`_persistent`](@ref) method.
3047
"""
31-
function _booster(persistent)
48+
function _booster(::CatBoostRegressor, persistent)
49+
ctb_file, io = mktemp()
50+
write(io, persistent)
51+
close(io)
52+
53+
booster = catboost.CatBoostRegressor().load_model(ctb_file)
54+
55+
rm(ctb_file)
56+
57+
return booster
58+
end
59+
function _booster(::CatBoostClassifier, persistent)
3260
ctb_file, io = mktemp()
3361
write(io, persistent)
3462
close(io)
3563

36-
booster = catboost.CatBoost().load_model(ctb_file)
64+
booster = catboost.CatBoostClassifier().load_model(ctb_file)
3765

3866
rm(ctb_file)
3967

4068
return booster
4169
end
4270

43-
function MMI.save(::CatBoostModels, fr; kw...)
44-
(booster, a_target_element) = fr
45-
return (_persistent(booster), a_target_element)
71+
function MMI.save(model::CatBoostModels, fitresult; kwargs...)
72+
return _persistent(model, fitresult)
73+
end
74+
75+
function MMI.restore(model::CatBoostRegressor, serializable_fitresult)
76+
return _booster(model, serializable_fitresult)
4677
end
4778

48-
function MMI.restore(::CatBoostModels, fr)
49-
(persistent, a_target_element) = fr
50-
return (_booster(persistent), a_target_element)
79+
function MMI.restore(model::CatBoostClassifier, serializable_fitresult)
80+
if serializable_fitresult[1] === nothing
81+
# Case 1: Single unique class
82+
return (model=nothing, single_class=serializable_fitresult[2],
83+
y_first=serializable_fitresult[3])
84+
else
85+
# Case 2: Multiple unique classes
86+
persistent_booster, y_first = serializable_fitresult
87+
return (_booster(model, persistent_booster), y_first)
88+
end
5189
end

test/mlj_interface.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,24 @@
2222
preds = MLJBase.predict(mach, X)
2323
probs = MLJBase.predict_mode(mach, X)
2424

25-
serializable_fitresult = MLJBase.save(mach, mach.fitresult)
25+
serializable_fitresult = MLJBase.save(mach, mach)
26+
restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
27+
end
28+
29+
@testset "CatBoostClassifier - single class" begin
30+
X = (; a=[1, 4, 5, 6], b=[4, 5, 6, 7])
31+
y = [0, 0, 0, 0]
32+
33+
# MLJ Interface
34+
model = CatBoostClassifier(; iterations=5)
35+
mach = machine(model, X, y)
36+
MLJBase.fit!(mach)
37+
preds = MLJBase.predict(mach, X)
38+
println(preds)
39+
probs = MLJBase.predict_mode(mach, X)
40+
println(probs)
41+
42+
serializable_fitresult = MLJBase.save(mach, mach)
2643
restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
2744
end
2845

@@ -36,7 +53,7 @@
3653
MLJBase.fit!(mach)
3754
preds = MLJBase.predict(mach, X)
3855

39-
serializable_fitresult = MLJBase.save(mach, mach.fitresult)
56+
serializable_fitresult = MLJBase.save(mach, mach)
4057
restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
4158
end
4259

0 commit comments

Comments
 (0)