@@ -25,11 +25,9 @@ Base.show(stream::IO, c::TreePrinter) =
 # # DECISION TREE CLASSIFIER
 
 # The following meets the MLJ standard for a `Model` docstring and is
-# created without the use of interpolation so it can be used a
-# template for authors of other MLJ model interfaces. The other
+# created without the use of interpolation so it can be used a # template for authors of other MLJ model interfaces. The other
 # doc-strings, defined later, are generated using the `doc_header`
 # utility to automatically generate the header, another option.
-
 MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
     min_samples_leaf::Int = 1::(_ ≥ 0)
@@ -39,6 +37,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
@@ -73,8 +72,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     cache = nothing
     report = (classes_seen=classes_seen,
               print_tree=TreePrinter(tree),
-              features=features)
-
+              features=features,
+              )
     return fitresult, cache, report
 end
 
@@ -107,6 +106,8 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
+
 
 # # RANDOM FOREST CLASSIFIER
 
@@ -118,13 +119,21 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
 
@@ -138,7 +147,9 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
                              m.min_purity_increase;
                              rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (forest, classes_seen, integers_seen), cache, report
 end
 
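A minimal standalone sketch of the `Tables.schema` fallback used in the new `fit` methods: when a table's schema is unknown, `Tables.schema` returns `nothing` and the columns are named `x1, x2, ...`; otherwise the actual column names are collected. The two-column table below is hypothetical, chosen only for illustration.

    using Tables
    X = (sepal_length = rand(5), sepal_width = rand(5))   # hypothetical column table
    schema = Tables.schema(X)
    features = schema === nothing ?
        [Symbol("x$j") for j in 1:2] :       # fallback naming when no schema is available
        collect(schema.names)                # here: [:sepal_length, :sepal_width]
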
@@ -151,25 +162,38 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
+
 
 # # ADA BOOST STUMP CLASSIFIER
 
 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
 
     stumps, coefs =
         DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (stumps, coefs, classes_seen, integers_seen), cache, report
 end
 
@@ -184,6 +208,8 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true
+
 
 # # DECISION TREE REGRESSOR
 
@@ -195,11 +221,20 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     tree = DT.build_tree(float(y), Xmatrix,
                          m.n_subfeatures,
                          m.max_depth,
@@ -212,7 +247,9 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
     end
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return tree, cache, report
 end
 
@@ -223,6 +260,8 @@ function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
     return DT.apply_tree(tree, Xmatrix)
 end
 
+MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
+
 
 # # RANDOM FOREST REGRESSOR
 
@@ -234,11 +273,20 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     forest = DT.build_forest(float(y), Xmatrix,
                              m.n_subfeatures,
                              m.n_trees,
@@ -249,7 +297,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
                              m.min_purity_increase,
                              rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+    report = (features=features,)
+
     return forest, cache, report
 end
 
@@ -260,6 +309,34 @@ function MMI.predict(::RandomForestRegressor, forest, Xnew)
     return DT.apply_forest(forest, Xmatrix)
 end
 
+MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+
+
+# # Feature Importances
+
+# get actual arguments needed for importance calculation from various fitresults.
+get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],)
+get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,)
+get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
+
+function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report)
+    # generate feature importances for report
+    if m.feature_importance == :impurity
+        feature_importance_func = DT.impurity_importance
+    elseif m.feature_importance == :split
+        feature_importance_func = DT.split_importance
+    end
+
+    mdl = get_fitresult(m, fitresult)
+    features = report.features
+    fi = feature_importance_func(mdl..., normalize=true)
+    fi_pairs = Pair.(features, fi)
+    # sort descending
+    sort!(fi_pairs, by=x->-x[2])
+
+    return fi_pairs
+end
+
 
 # # METADATA (MODEL TRAITS)
 
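The `Pair.(features, fi)` / `sort!` idiom in the new `feature_importances` method returns a `feature => importance` list, largest first. A self-contained sketch with made-up, already-normalized importances:

    features = [:sepal_width, :petal_length, :petal_width]
    fi = [0.05, 0.60, 0.35]                 # hypothetical normalized importances
    fi_pairs = Pair.(features, fi)          # broadcast Pair construction
    sort!(fi_pairs, by = x -> -x[2])        # sort descending by importance
    # [:petal_length => 0.6, :petal_width => 0.35, :sepal_width => 0.05]
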
@@ -379,6 +456,8 @@ Train the machine using `fit!(mach, rows=...)`.
 
 - `display_depth=5`: max depth to show when displaying the tree
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 
@@ -512,6 +591,8 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 
@@ -587,6 +668,9 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `n_iter=10`: number of iterations of AdaBoost
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+  :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 # Operations
@@ -678,6 +762,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 
@@ -760,6 +846,8 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 