Commit b925c00

Merge pull request #17 from JuliaAI/dev

For a 0.2.0 release

2 parents 274b2a9 + fd7ceed

4 files changed: +38 −34 lines

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
@@ -3,9 +3,11 @@ on:
   pull_request:
     branches:
       - master
+      - dev
   push:
     branches:
       - master
+      - dev
     tags: '*'
 jobs:
   test:

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,16 +1,18 @@
 name = "MLJDecisionTreeInterface"
 uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.1.4"
+version = "0.2.0"

 [deps]
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

 [compat]
 DecisionTree = "0.10"
 MLJModelInterface = "^0.3,^0.4, 1.0"
+Tables = "1.6"
 julia = "1.6"

 [extras]
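Note: with the version bump and the new Tables dependency, an existing environment picks up the release with an explicit add. A minimal sketch, assuming 0.2.0 is registered (adjust to your own environment):

using Pkg
Pkg.add(name="MLJDecisionTreeInterface", version="0.2.0")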

src/MLJDecisionTreeInterface.jl

Lines changed: 24 additions & 28 deletions
@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
 import MLJModelInterface
 using MLJModelInterface.ScientificTypesBase
 import DecisionTree
+import Tables

 using Random
 import Random.GLOBAL_RNG

@@ -37,15 +38,21 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)

+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
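The new feature-name logic above leans on Tables.jl: for a table whose schema is known eagerly, `Tables.schema` exposes the column names; otherwise it returns `nothing` and the synthetic names `:x1, :x2, ...` are used. A rough illustration, not part of the diff:

using Tables

# a NamedTuple of vectors is a Tables.jl column table:
X = (sepal_length = [5.1, 4.9], sepal_width = [3.5, 3.0])

sch = Tables.schema(X)
sch === nothing || collect(sch.names)  # => [:sepal_length, :sepal_width]

# for sources with no eagerly known schema, Tables.schema returns nothing,
# and the fit method falls back to :x1, :x2, ...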

@@ -61,11 +68,12 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     end
     verbosity < 2 || DT.print_tree(tree, m.display_depth)

-    fitresult = (tree, classes_seen, integers_seen)
+    fitresult = (tree, classes_seen, integers_seen, features)

     cache = nothing
     report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree))
+              print_tree=TreePrinter(tree),
+              features=features)

     return fitresult, cache, report
 end

@@ -76,7 +84,9 @@ function get_encoding(classes_seen)
 end

 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1], encoding=get_encoding(fitresult[2]))
+    (tree=fitresult[1],
+     encoding=get_encoding(fitresult[2]),
+     features=fitresult[4])

 function smooth(scores, smoothing)
     iszero(smoothing) && return scores

@@ -92,10 +102,9 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
     # retrieve the predicted scores
     scores = DT.apply_tree_proba(tree, Xmatrix, integers_seen)
-    # smooth if required
-    sm_scores = smooth(scores, m.pdf_smoothing)
+
     # return vector of UF
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end

@@ -109,7 +118,6 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

@@ -140,16 +148,14 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
     Xmatrix = MMI.matrix(Xnew)
     forest, classes_seen, integers_seen = fitresult
     scores = DT.apply_forest_proba(forest, Xmatrix, integers_seen)
-    sm_scores = smooth(scores, m.pdf_smoothing)
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end


 # # ADA BOOST STUMP CLASSIFIER

 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
 end

 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)

@@ -174,8 +180,7 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
     stumps, coefs, classes_seen, integers_seen = fitresult
     scores = DT.apply_adaboost_stumps_proba(stumps, coefs,
                                             Xmatrix, integers_seen)
-    sm_scores = smooth(scores, m.pdf_smoothing)
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end

@@ -228,7 +233,6 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

@@ -364,14 +368,6 @@ Train the machine using `fit!(mach, rows=...)`.

 - `rng=Random.GLOBAL_RNG`: random number generator or seed

-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores.
-  Raw leaf-based probabilities are smoothed as follows: If `n` is the
-  number of observed classes, then each class probability is replaced
-  by `pdf_smoothing/n`, if it falls below that ratio, and the
-  resulting vector of probabilities is renormalized. Smoothing is only
-  applied to classes actually observed in training. Unseen classes
-  retain zero-probability predictions.
-

 # Operations

@@ -394,6 +390,9 @@ The fields of `fitted_params(mach)` are:
   of tree (obtained by calling `fit!(mach, verbosity=2)` or from
   report - see below)

+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+

 # Report

@@ -405,6 +404,9 @@ The fields of `report(mach)` are:
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).

+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+

 # Examples

@@ -495,9 +497,6 @@ Train the machine with `fit!(mach, rows=...)`.

 - `rng=Random.GLOBAL_RNG`: random number generator or seed

-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores of
-  each tree. See [`DecisionTreeClassifier`](@ref)
-

 # Operations

@@ -569,9 +568,6 @@ Train the machine with `fit!(mach, rows=...)`.

 - `n_iter=10`: number of iterations of AdaBoost

-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores.
-  See [`DecisionTreeClassifier`](@ref)
-

 # Operations
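Taken together, the source changes surface feature names through both the report and the fitted parameters. A sketch using the MLJ machine interface, assuming MLJ.jl is installed and `X`, `y` are a Tables-compatible table and a categorical vector (illustrative only):

using MLJ

# load the model type provided by this interface package:
Tree = @load DecisionTreeClassifier pkg=DecisionTree
mach = machine(Tree(), X, y)
fit!(mach)

report(mach).features         # feature names, in print_tree order
fitted_params(mach).features  # the same names, via fitted_params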

test/runtests.jl

Lines changed: 9 additions & 5 deletions
@@ -37,11 +37,15 @@ yhat = MLJBase.predict(baretree, fitresult, X);
 yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
 @test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1])

+# check report and fitresult fields:
+@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))
+@test Set(report.classes_seen) == Set(levels(y))
+@test report.print_tree(2) === nothing # :-(
+@test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
+fp = fitted_params(baretree, fitresult)
+@test Set([:tree, :encoding, :features]) == Set(keys(fp))
+@test fp.features == report.features

-# # testing machine interface:
-# tree = machine(baretree, X, y)
-# fit!(tree)
-# yyhat = predict_mode(tree, MLJBase.selectrows(X, 1:3))
 using Random: seed!
 seed!(0)

@@ -79,7 +83,7 @@ rgs = DecisionTreeRegressor()
 fitresult, _, _ = MLJBase.fit(rgs, 1, X, ycont)
 @test rms(predict(rgs, fitresult, X), ycont) < 1.5

-clf = DecisionTreeClassifier(pdf_smoothing=0)
+clf = DecisionTreeClassifier()
 fitresult, _, _ = MLJBase.fit(clf, 1, X, yfinite)
 @test sum(predict(clf, fitresult, X) .== yfinite) == 0 # perfect prediction
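Because `pdf_smoothing` is gone from all three classifier structs, code that relied on smoothed scores must now post-process raw probabilities itself. A standalone sketch of the behaviour the removed docstring described (a hypothetical helper, not part of this package):

# Raise every class probability below smoothing/n up to smoothing/n,
# where n is the number of observed classes, then renormalize each row:
function smooth_scores(scores::AbstractMatrix{<:Real}, smoothing::Real)
    iszero(smoothing) && return scores
    threshold = smoothing / size(scores, 2)
    smoothed = max.(scores, threshold)
    return smoothed ./ sum(smoothed, dims=2)
end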
