1
1
# Taken from https://github.com/JuliaAI/MLJXGBoostInterface.jl
2
2
# It is likely also not the optimal method for serializing models, but it works
3
+
3
4
"""
4
- _persistent(booster )
5
+ _persistent(model::CatBoostModels, fitresult )
5
6
6
7
Private method.
7
8
8
9
Return a persistent (ie, Julia-serializable) representation of the
9
- CatBoost.jl model `booster `.
10
+ CatBoost.jl model `fitresult `.
10
11
11
- Restore the model with [`booster `](@ref)
12
+ Restore the model with [`fitresult `](@ref)
12
13
"""
13
- function _persistent (booster )
14
+ function _persistent (:: CatBoostRegressor , fitresult )
14
15
ctb_file, io = mktemp ()
15
16
close (io)
16
17
17
- booster . save_model (ctb_file)
18
+ fitresult . save_model (ctb_file)
18
19
persistent_booster = read (ctb_file)
19
20
rm (ctb_file)
20
21
return persistent_booster
21
22
end
23
+ function _persistent (:: CatBoostClassifier , fitresult)
24
+ model, y_first = fitresult
25
+ if model === nothing
26
+ # Case 1: Single unique class
27
+ return (nothing , fitresult. single_class, y_first)
28
+ else
29
+ # Case 2: Multiple unique classes
30
+ ctb_file, io = mktemp ()
31
+ close (io)
32
+
33
+ model. save_model (ctb_file)
34
+ persistent_booster = read (ctb_file)
35
+ rm (ctb_file)
36
+ return (persistent_booster, y_first)
37
+ end
38
+ end
22
39
23
40
"""
24
41
_booster(persistent)
@@ -28,24 +45,45 @@ Private method.
28
45
Return the CatBoost.jl model which has `persistent` as its persistent
29
46
(Julia-serializable) representation. See [`persistent`](@ref) method.
30
47
"""
31
- function _booster (persistent)
48
+ function _booster (:: CatBoostRegressor , persistent)
49
+ ctb_file, io = mktemp ()
50
+ write (io, persistent)
51
+ close (io)
52
+
53
+ booster = catboost. CatBoostRegressor (). load_model (ctb_file)
54
+
55
+ rm (ctb_file)
56
+
57
+ return booster
58
+ end
59
+ function _booster (:: CatBoostClassifier , persistent)
32
60
ctb_file, io = mktemp ()
33
61
write (io, persistent)
34
62
close (io)
35
63
36
- booster = catboost. CatBoost (). load_model (ctb_file)
64
+ booster = catboost. CatBoostClassifier (). load_model (ctb_file)
37
65
38
66
rm (ctb_file)
39
67
40
68
return booster
41
69
end
42
70
43
- function MMI. save (:: CatBoostModels , fr; kw... )
44
- (booster, a_target_element) = fr
45
- return (_persistent (booster), a_target_element)
71
+ function MMI. save (model:: CatBoostModels , fitresult; kwargs... )
72
+ return _persistent (model, fitresult)
73
+ end
74
+
75
+ function MMI. restore (model:: CatBoostRegressor , serializable_fitresult)
76
+ return _booster (model, serializable_fitresult)
46
77
end
47
78
48
- function MMI. restore (:: CatBoostModels , fr)
49
- (persistent, a_target_element) = fr
50
- return (_booster (persistent), a_target_element)
79
+ function MMI. restore (model:: CatBoostClassifier , serializable_fitresult)
80
+ if serializable_fitresult[1 ] === nothing
81
+ # Case 1: Single unique class
82
+ return (model= nothing , single_class= serializable_fitresult[2 ],
83
+ y_first= serializable_fitresult[3 ])
84
+ else
85
+ # Case 2: Multiple unique classes
86
+ persistent_booster, y_first = serializable_fitresult
87
+ return (_booster (model, persistent_booster), y_first)
88
+ end
51
89
end
0 commit comments