@@ -4,6 +4,7 @@ import MLJModelInterface
 using MLJModelInterface.ScientificTypesBase
 import DecisionTree
 import Tables
+using CategoricalArrays
 
 using Random
 import Random.GLOBAL_RNG
@@ -21,14 +22,13 @@
 Base.show(stream::IO, c::TreePrinter) =
     print(stream, "TreePrinter object (call with display depth)")
 
+function classes(y)
+    p = CategoricalArrays.pool(y)
+    [p[i] for i in 1:length(p)]
+end
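+# `pool(y)` carries *all* levels of `y`, so `classes(y)` should return every
+# level, in order, including any that do not appear in `y` itself; e.g.
+# `classes(categorical(["a", "a"], levels=["a", "b"]))` gives `["a", "b"]`.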
 
 ## DECISION TREE CLASSIFIER
 
-# The following meets the MLJ standard for a `Model` docstring and is
-# created without the use of interpolation so it can be used as a
-# template for authors of other MLJ model interfaces. The other
-# doc-strings, defined later, are generated using the `doc_header`
-# utility to automatically generate the header, another option.
 MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
     min_samples_leaf::Int = 1::(_ ≥ 0)
@@ -41,19 +42,17 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
-function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::DecisionTreeClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
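+    # The positional arguments match what the data front end at the bottom of
+    # this file returns from `reformat(::Classifier, X, y)`, so `fit` no
+    # longer converts the user-supplied table and target itself.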
 
-    classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)
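+    # `MMI.decoder(classes)` inverts `MMI.int`: it maps the integer codes
+    # found in `yplain` back to the corresponding categorical levels, so
+    # `classes_seen` lists only the levels actually present in the target.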
 
     tree = DT.build_tree(yplain, Xmatrix,
                          m.n_subfeatures,
@@ -70,40 +69,26 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     fitresult = (tree, classes_seen, integers_seen, features)
 
     cache = nothing
-    report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree),
-              features=features,
-              )
+    report = (
+        classes_seen=classes_seen,
+        print_tree=TreePrinter(tree),
+        features=features,
+    )
     return fitresult, cache, report
 end
 
-function get_encoding(classes_seen)
-    a_cat_element = classes_seen[1]
-    return Dict(MMI.int(c) => c for c in MMI.classes(a_cat_element))
-end
+get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
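+# Builds the integer-code => level dictionary for the full pool of
+# `classes_seen`, via the new `classes` helper rather than `MMI.classes`
+# applied to a single element.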
 
 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
     (tree=fitresult[1],
      encoding=get_encoding(fitresult[2]),
      features=fitresult[4])
 
-function smooth(scores, smoothing)
-    iszero(smoothing) && return scores
-    threshold = smoothing / size(scores, 2)
-    # clip low values
-    scores[scores .< threshold] .= threshold
-    # normalize
-    return scores ./ sum(scores, dims=2)
-end
-
 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     tree, classes_seen, integers_seen = fitresult
     # retrieve the predicted scores
-    scores = DT.apply_tree_proba(tree, Xmatrix, integers_seen)
-
-    # return vector of UF
-    return MMI.UnivariateFinite(classes_seen, scores)
+    scores = DT.apply_tree_proba(tree, Xnew, integers_seen)
+    MMI.UnivariateFinite(classes_seen, scores)
 end
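+# `Xnew` arrives here already as a matrix: the data front end's
+# `reformat(::TreeModel, X)` (defined below) runs before `predict`, which is
+# why the old `MMI.matrix(Xnew)` conversion is gone.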
 
 MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
@@ -123,19 +108,17 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
-function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::RandomForestClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
 
-    classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)
 
     forest = DT.build_forest(yplain, Xmatrix,
                              m.n_subfeatures,
@@ -156,10 +139,9 @@
 MMI.fitted_params(::RandomForestClassifier, (forest,_)) = (forest=forest,)
 
 function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     forest, classes_seen, integers_seen = fitresult
-    scores = DT.apply_forest_proba(forest, Xmatrix, integers_seen)
-    return MMI.UnivariateFinite(classes_seen, scores)
+    scores = DT.apply_forest_proba(forest, Xnew, integers_seen)
+    MMI.UnivariateFinite(classes_seen, scores)
 end
 
 MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
@@ -173,20 +155,17 @@ MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
-function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-    yplain = MMI.int(y)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(
+    m::AdaBoostStumpClassifier,
+    verbosity::Int,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
 
-
-    classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
-    integers_seen = MMI.int(classes_seen)
+    integers_seen = unique(yplain)
+    classes_seen = MMI.decoder(classes)(integers_seen)
 
     stumps, coefs =
         DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
@@ -201,10 +180,13 @@ MMI.fitted_params(::AdaBoostStumpClassifier, (stumps,coefs,_)) =
     (stumps=stumps,coefs=coefs)
 
 function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
     stumps, coefs, classes_seen, integers_seen = fitresult
-    scores = DT.apply_adaboost_stumps_proba(stumps, coefs,
-                                            Xmatrix, integers_seen)
+    scores = DT.apply_adaboost_stumps_proba(
+        stumps,
+        coefs,
+        Xnew,
+        integers_seen,
+    )
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
@@ -225,23 +207,18 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
-function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
+function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
 
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
-
-    tree = DT.build_tree(float(y), Xmatrix,
-                         m.n_subfeatures,
-                         m.max_depth,
-                         m.min_samples_leaf,
-                         m.min_samples_split,
-                         m.min_purity_increase;
-                         rng=m.rng)
+    tree = DT.build_tree(
+        y,
+        Xmatrix,
+        m.n_subfeatures,
+        m.max_depth,
+        m.min_samples_leaf,
+        m.min_samples_split,
+        m.min_purity_increase;
+        rng=m.rng
+    )
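+    # `y` already arrives as a float vector: `reformat(::Regressor, X, y)`
+    # below applies `float(y)` once up front, so the old `float(y)` call here
+    # became redundant.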
 
     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
@@ -255,10 +232,7 @@
 
 MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
 
-function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
-    return DT.apply_tree(tree, Xmatrix)
-end
+MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
 
 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
 
@@ -277,25 +251,21 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
-function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
-    schema = Tables.schema(X)
-    Xmatrix = MMI.matrix(X)
-
-    if schema === nothing
-        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
-    else
-        features = schema.names |> collect
-    end
+function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
+
+    forest = DT.build_forest(
+        y,
+        Xmatrix,
+        m.n_subfeatures,
+        m.n_trees,
+        m.sampling_fraction,
+        m.max_depth,
+        m.min_samples_leaf,
+        m.min_samples_split,
+        m.min_purity_increase,
+        rng=m.rng
+    )
 
-    forest = DT.build_forest(float(y), Xmatrix,
-                             m.n_subfeatures,
-                             m.n_trees,
-                             m.sampling_fraction,
-                             m.max_depth,
-                             m.min_samples_leaf,
-                             m.min_samples_split,
-                             m.min_purity_increase,
-                             rng=m.rng)
     cache = nothing
     report = (features=features,)
 
@@ -304,10 +274,7 @@
 
 MMI.fitted_params(::RandomForestRegressor, forest) = (forest=forest,)
 
-function MMI.predict(::RandomForestRegressor, forest, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
-    return DT.apply_forest(forest, Xmatrix)
-end
+MMI.predict(::RandomForestRegressor, forest, Xnew) = DT.apply_forest(forest, Xnew)
 
 MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
 
@@ -327,7 +294,39 @@ const IterativeModel = Union{
     AdaBoostStumpClassifier,
 }
 
-const RandomForestModel = Union{DecisionTreeClassifier, RandomForestClassifier}
+const Classifier = Union{
+    DecisionTreeClassifier,
+    RandomForestClassifier,
+    AdaBoostStumpClassifier,
+}
+
+const Regressor = Union{
+    DecisionTreeRegressor,
+    RandomForestRegressor,
+}
+
+const RandomForestModel = Union{
+    DecisionTreeClassifier,
+    RandomForestClassifier,
+}
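+# `Classifier` and `Regressor` partition the five models so the data front
+# end below can dispatch on whether the target needs integer coding.
+# (`TreeModel`, used below, is assumed to be defined elsewhere in this file
+# as a union covering all of these models.)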
+
+
+## DATA FRONT END
+
+_columnnames(X) = Tables.columnnames(Tables.columns(X)) |> collect
+
+# for fit:
+MMI.reformat(::Classifier, X, y) =
+    (Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y))
+MMI.reformat(::Regressor, X, y) =
+    (Tables.matrix(X), float(y), _columnnames(X))
+MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) =
+    (view(Xmatrix, I, :), view(y, I), meta...)
+
+# for predict:
+MMI.reformat(::TreeModel, X) = (Tables.matrix(X),)
+MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)
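+
+# A rough sketch of how MLJ is expected to drive this front end: `reformat`
+# runs once when a machine is created, and `selectrows` then slices the
+# precomputed arrays for each train/test split, so resampling never
+# re-converts the table:
+#
+#     data = MMI.reformat(model, X, y)   # e.g. (Xmatrix, yplain, features, classes)
+#     train = MMI.selectrows(model, train_rows, data...)
+#     MMI.fit(model, verbosity, train...)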
+
 
 ## FEATURE IMPORTANCES
 