
Commit a80f7a3

wrap trees for better display to close #45

1 parent ebd0c7c commit a80f7a3

File tree: 2 files changed (+87, -50 lines)

src/MLJDecisionTreeInterface.jl

Lines changed: 86 additions & 49 deletions
````diff
@@ -77,12 +77,27 @@ function MMI.fit(
     return fitresult, cache, report
 end
 
+# returns a dictionary of categorical elements keyed on ref integer:
 get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
 
-MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1],
-     encoding=get_encoding(fitresult[2]),
-     features=fitresult[4])
+# given such a dictionary, return printable class labels, ordered by corresponding ref
+# integer:
+classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
+
+_node_or_leaf(r::DecisionTree.Root) = r.node
+_node_or_leaf(n::Any) = n
+
+function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
+    raw_tree = fitresult[1]
+    encoding = get_encoding(fitresult[2])
+    features = fitresult[4]
+    classlabels = MLJDecisionTreeInterface.classlabels(encoding)
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (featurenames=features, classlabels),
+    )
+    (; tree, raw_tree, encoding, features)
+end
 
 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
````
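To illustrate the new helper in isolation: given an encoding like the one `get_encoding` produces, `classlabels` returns the printable labels ordered by ref integer. A minimal sketch; the literal `Dict` below is hypothetical and only mirrors the iris encoding from the old docstring example:

```julia
# Hypothetical iris-style encoding: ref integer => class label.
encoding = Dict(0x01 => "setosa", 0x02 => "versicolor", 0x03 => "virginica")

# Definition as added in this commit:
classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]

classlabels(encoding)  # ["setosa", "versicolor", "virginica"]
```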
````diff
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
     cache = nothing
 
     report = (features=features,)
+    fitresult = (tree, features)
 
-    return tree, cache, report
+    return fitresult, cache, report
 end
 
-MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
+function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
+    raw_tree = fitresult[1]
+    features = fitresult[2]
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features),
+    )
+    (; tree, raw_tree)
+end
 
-MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
+MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)
 
 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
 
````
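The regressor's `fitted_params` now performs the same wrapping step as the classifier's. A standalone sketch of that step, assuming DecisionTree.jl's `build_tree` and `wrap` as already relied on by the diff; the data and feature names here are invented:

```julia
using AbstractTrees, DecisionTree

# Toy regression data: 100 rows, 4 features.
X = rand(100, 4)
y = rand(100)
raw_tree = DecisionTree.build_tree(y, X)

# Wrap with feature names, as in the new `fitted_params` method:
features = [:x1, :x2, :x3, :x4]
tree = DecisionTree.wrap(raw_tree, (; featurenames = features))

AbstractTrees.print_tree(tree)  # splits displayed with the given names
```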
````diff
@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)
 
 # get actual arguments needed for importance calculation from various fitresults.
 get_fitresult(
-    m::Union{DecisionTreeClassifier, RandomForestClassifier},
+    m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
     fitresult,
 ) = (fitresult[1],)
 get_fitresult(
-    m::Union{DecisionTreeRegressor, RandomForestRegressor},
+    m::RandomForestRegressor,
     fitresult,
 ) = (fitresult,)
 get_fitresult(m::AdaBoostStumpClassifier, fitresult)= (fitresult[1], fitresult[2])
````
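The regressor moves into the first `get_fitresult` method because its fitresult is now the tuple `(tree, features)`, so the raw tree sits at `fitresult[1]`. For orientation, a hedged sketch of how such tuples are typically consumed; this is an assumed pattern, not code from this commit:

```julia
# Assumption: the tuple returned by `get_fitresult` is splatted into
# DecisionTree.jl's importance functions, selected by the model's
# `feature_importance` hyperparameter.
args = get_fitresult(model, fitresult)
importances = model.feature_importance == :impurity ?
    DT.impurity_importance(args...) :
    DT.split_importance(args...)
```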
````diff
@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.
 
 The fields of `fitted_params(mach)` are:
 
-- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+  algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+  interface; see "Examples" below
 
 - `encoding`: dictionary of target classes keyed on integers used
-  internally by DecisionTree.jl; needed to interpret pretty printing
-  of tree (obtained by calling `fit!(mach, verbosity=2)` or from
-  report - see below)
+  internally by DecisionTree.jl
 
 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
````
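A quick look at the resulting field layout; a hypothetical REPL session, assuming the fitted iris machine `mach` from the Examples section below:

```julia
fp = fitted_params(mach)
keys(fp)      # (:tree, :raw_tree, :encoding, :features)
fp.raw_tree   # raw `Node`/`Leaf`/`Root` from DecisionTree.jl
fp.tree       # wrapped, AbstractTrees-compatible version
```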
````diff
@@ -617,7 +643,7 @@ The fields of `report(mach)` are:
 
 - `classes_seen`: list of target classes actually observed in training
 
-- `print_tree`: method to print a pretty representation of the fitted
+- `print_tree`: alternative method to print the fitted
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).
@@ -629,11 +655,11 @@ The fields of `report(mach)` are:
 
 ```
 using MLJ
-Tree = @load DecisionTreeClassifier pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
 
 X, y = @load_iris
-mach = machine(tree, X, y) |> fit!
+mach = machine(model, X, y) |> fit!
 
 Xnew = (sepal_length = [6.4, 7.2, 7.4],
         sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +669,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
 predict_mode(mach, Xnew) # point predictions
 pdf.(yhat, "virginica") # probabilities for the "verginica" class
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTrees.jl
-
-julia> report(mach).print_tree(3)
-Feature 4, Threshold 0.8
-L-> 1 : 50/50
-R-> Feature 4, Threshold 1.75
-    L-> Feature 3, Threshold 4.95
-        L->
-        R->
-    R-> Feature 3, Threshold 4.85
-        L->
-        R-> 3 : 43/43
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+   ├─ petal_length < 4.95
+   │  ├─ versicolor (47/48)
+   │  └─ virginica (4/6)
+   └─ petal_length < 4.85
+      ├─ virginica (2/3)
+      └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
 ```
 
-To interpret the internal class labelling:
-
-```
-julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  0x00000003 => "virginica"
-  0x00000001 => "setosa"
-  0x00000002 => "versicolor"
-```
-
-See also
-[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
-the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
+See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
+unwrapped model type
+[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
 
 """
 DecisionTreeClassifier
````
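Because the wrapped `tree` implements the AbstractTrees.jl interface, generic tree tooling applies beyond the display shown above. A minimal sketch, again assuming the fitted iris machine `mach` from the Examples:

```julia
using AbstractTrees

tree = fitted_params(mach).tree
AbstractTrees.print_tree(tree)       # the indented rendering shown above
collect(AbstractTrees.Leaves(tree))  # iterate over the leaf nodes
```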
````diff
@@ -903,7 +922,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`
 
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -921,6 +941,8 @@ The fields of `fitted_params(mach)` are:
 - `tree`: the tree or stump object returned by the core
   DecisionTree.jl algorithm
 
+- `features`: the names of the features encountered in training
+
 
 # Report
 
@@ -931,16 +953,31 @@ The fields of `fitted_params(mach)` are:
 
 ```
 using MLJ
-Tree = @load DecisionTreeRegressor pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
+model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)
 
-X, y = make_regression(100, 2) # synthetic data
-mach = machine(tree, X, y) |> fit!
+X, y = make_regression(100, 4; rng=123) # synthetic data
+mach = machine(model, X, y) |> fit!
 
-Xnew, _ = make_regression(3, 2)
+Xnew, _ = make_regression(3, 2; rng=123)
 yhat = predict(mach, Xnew) # new predictions
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
+julia> fitted_params(mach).tree
+x1 < 0.2758
+├─ x2 < 0.9137
+│  ├─ x1 < -0.9582
+│  │  ├─ 0.9189256882087312 (0/12)
+│  │  └─ -0.23180616021065256 (0/38)
+│  └─ -1.6461153800037722 (0/9)
+└─ x1 < 1.062
+   ├─ x2 < -0.4969
+   │  ├─ -0.9330755147107384 (0/5)
+   │  └─ -2.3287967825015548 (0/17)
+   └─ x2 < 0.4598
+      ├─ -2.931299926506291 (0/11)
+      └─ -4.726518740473489 (0/8)
+
+feature_importances(mach) # get feature importances
 ```
 
 See also
````
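The regressor example stops at `feature_importances`, but presumably the same TreeRecipe-based plotting shown in the classifier example applies to the regressor's wrapped tree as well. A sketch, assuming the fitted machine `mach` above:

```julia
using Plots, TreeRecipe

tree = fitted_params(mach).tree  # wrapped regressor tree
plot(tree)                       # graphical rendering, as in the classifier example
```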

test/runtests.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -82,7 +82,7 @@ yyhat = predict_mode(baretree, fitresult, X[1:3, :])
 @test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
 
 fp = fitted_params(baretree, fitresult)
-@test Set([:tree, :encoding, :features]) == Set(keys(fp))
+@test Set([:tree, :encoding, :features, :raw_tree]) == Set(keys(fp))
 @test fp.features == report.features
 enc = fp.encoding
 @test Set(values(enc)) == Set(["virginica", "setosa", "versicolor"])
````
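One might extend the test to cover the regressor's new layout as well. A hypothetical sketch, with names invented here and not part of this commit:

```julia
# Hypothetical follow-on test (not in this commit): the regressor's
# fitted_params should now expose the wrapped and raw trees.
fp_reg = fitted_params(regressor_model, regressor_fitresult)
@test Set([:tree, :raw_tree]) == Set(keys(fp_reg))
```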
