
Commit 10147c8

Merge pull request #29 from JuliaAI/dev
For a 0.2.4 release
2 parents 9fbfad6 + 3735f6f · commit 10147c8

File tree: 3 files changed, +177 -10 lines


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJDecisionTreeInterface"
 uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.2.3"
+version = "0.2.4"

 [deps]
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
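The only substantive change here is the version bump. To try the release locally once it is registered, something like the following should work (a sketch; it assumes the General registry has already picked up 0.2.4):

using Pkg
Pkg.update("MLJDecisionTreeInterface")                      # or pin the exact release:
Pkg.add(name="MLJDecisionTreeInterface", version="0.2.4")
Pkg.status("MLJDecisionTreeInterface")                      # confirm the resolved version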

src/MLJDecisionTreeInterface.jl

Lines changed: 97 additions & 9 deletions
@@ -25,11 +25,9 @@ Base.show(stream::IO, c::TreePrinter) =
 
 # # DECISION TREE CLASSIFIER
 
 # The following meets the MLJ standard for a `Model` docstring and is
-# created without the use of interpolation so it can be used a
-# template for authors of other MLJ model interfaces. The other
+# created without the use of interpolation so it can be used a # template for authors of other MLJ model interfaces. The other
 # doc-strings, defined later, are generated using the `doc_header`
 # utility to automatically generate the header, another option.
-
 MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
     min_samples_leaf::Int = 1::(_ ≥ 0)
@@ -39,6 +37,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
@@ -73,8 +72,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     cache = nothing
     report = (classes_seen=classes_seen,
               print_tree=TreePrinter(tree),
-              features=features)
-
+              features=features,
+              )
     return fitresult, cache, report
 end
 
@@ -107,6 +106,8 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
+
 
 # # RANDOM FOREST CLASSIFIER
 
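This trait declaration is how MLJ discovers that the model can report feature importances; the new tests below query it on model instances. A quick check might look like this (assuming MLJBase re-exports the trait, as in test/runtests.jl):

using MLJBase, MLJDecisionTreeInterface

reports_feature_importances(DecisionTreeClassifier())   # true after this change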
@@ -118,13 +119,21 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
 
@@ -138,7 +147,9 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
         m.min_purity_increase;
         rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (forest, classes_seen, integers_seen), cache, report
 end
 
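The `Tables.schema` call above is what recovers column names for the importance report; the `nothing` branch covers tables that cannot state their schema, in which case synthetic names `:x1, :x2, ...` are generated to match the matrix columns. A small sketch of the two cases (illustrative only; `X_named` is a throwaway example table):

using Tables

X_named = (sepal_length = rand(5), sepal_width = rand(5))   # a column table with known names
Tables.schema(X_named).names                                 # (:sepal_length, :sepal_width)

# For a source whose schema is unknown, Tables.schema may return `nothing`,
# and the fit methods above then fall back to [:x1, :x2, ...].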
@@ -151,25 +162,38 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
+
 
 # # ADA BOOST STUMP CLASSIFIER
 
 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
 
     stumps, coefs =
         DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (stumps, coefs, classes_seen, integers_seen), cache, report
 end
 

@@ -184,6 +208,8 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end
 
+MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true
+
 
 # # DECISION TREE REGRESSOR
 

@@ -195,11 +221,20 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     tree = DT.build_tree(float(y), Xmatrix,
         m.n_subfeatures,
         m.max_depth,
@@ -212,7 +247,9 @@
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
     end
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return tree, cache, report
 end
 

@@ -223,6 +260,8 @@ function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
     return DT.apply_tree(tree, Xmatrix)
 end
 
+MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
+
 
 # # RANDOM FOREST REGRESSOR
 

@@ -234,11 +273,20 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     forest = DT.build_forest(float(y), Xmatrix,
         m.n_subfeatures,
         m.n_trees,
@@ -249,7 +297,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
         m.min_purity_increase,
         rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+    report = (features=features,)
+
     return forest, cache, report
 end
 

@@ -260,6 +309,34 @@ function MMI.predict(::RandomForestRegressor, forest, Xnew)
     return DT.apply_forest(forest, Xmatrix)
 end
 
+MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+
+
+# # Feature Importances
+
+# get actual arguments needed for importance calculation from various fitresults.
+get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],)
+get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,)
+get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
+
+function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report)
+    # generate feature importances for report
+    if m.feature_importance == :impurity
+        feature_importance_func = DT.impurity_importance
+    elseif m.feature_importance == :split
+        feature_importance_func = DT.split_importance
+    end
+
+    mdl = get_fitresult(m, fitresult)
+    features = report.features
+    fi = feature_importance_func(mdl..., normalize=true)
+    fi_pairs = Pair.(features, fi)
+    # sort descending
+    sort!(fi_pairs, by=x->-x[2])
+
+    return fi_pairs
+end
+
 
 # # METADATA (MODEL TRAITS)
 
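`MMI.feature_importances` returns a vector of `feature => importance` pairs, sorted in descending order of importance. A usage sketch mirroring the new tests (it assumes MLJBase provides `machine`, `fit!`, `report`, and `make_blobs`, as in test/runtests.jl):

using MLJBase, MLJDecisionTreeInterface

X, y = MLJBase.make_blobs(100, 3)                           # toy 3-feature classification data
model = RandomForestClassifier(feature_importance=:split)
mach = machine(model, X, y)
fit!(mach)

fi = MLJBase.feature_importances(model, mach.fitresult, report(mach))
# fi is a Vector of feature-name => importance pairs, largest first
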
@@ -379,6 +456,8 @@ Train the machine using `fit!(mach, rows=...)`.
 
 - `display_depth=5`: max depth to show when displaying the tree
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -512,6 +591,8 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -587,6 +668,9 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `n_iter=10`: number of iterations of AdaBoost
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+  :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
 # Operations
@@ -678,6 +762,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -760,6 +846,8 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 

test/runtests.jl

Lines changed: 79 additions & 0 deletions
@@ -45,6 +45,7 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
 @test Set(report.classes_seen) == Set(levels(y))
 @test report.print_tree(2) === nothing # :-(
 @test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
+
 fp = fitted_params(baretree, fitresult)
 @test Set([:tree, :encoding, :features]) == Set(keys(fp))
 @test fp.features == report.features
@@ -155,3 +156,81 @@ end
 @test reproducibility(model, X, y, loss)
 end
 end
+
+
+@testset "feature importance defined" begin
+    for model in [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
+        @test reports_feature_importances(model) == true
+    end
+end
+
+
+@testset "impurity importance" begin
+
+    X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
+
+    for model in [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi, 1) == 3
+    end
+
+    X, y = make_regression(100, 3; rng=stable_rng());
+    for model in [
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi, 1) == 3
+    end
+end
+
+
+@testset "split importance" begin
+    X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
+
+    for model in [
+        DecisionTreeClassifier(feature_importance=:split),
+        RandomForestClassifier(feature_importance=:split),
+        AdaBoostStumpClassifier(feature_importance=:split),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi, 1) == 3
+    end
+
+    X, y = make_regression(100, 3; rng=stable_rng());
+    for model in [
+        DecisionTreeRegressor(feature_importance=:split),
+        RandomForestRegressor(feature_importance=:split),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi, 1) == 3
+    end
+end
+
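To run these new testsets locally, the standard package test entry point should suffice (assuming the package is `dev`ed or added to the active environment):

using Pkg
Pkg.test("MLJDecisionTreeInterface")   # executes test/runtests.jl, including the importance testsets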
