@@ -77,12 +77,27 @@ function MMI.fit(
     return fitresult, cache, report
 end

+# returns a dictionary of categorical elements keyed on ref integer:
 get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))

-MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1],
-     encoding=get_encoding(fitresult[2]),
-     features=fitresult[4])
+# given such a dictionary, return printable class labels, ordered by corresponding ref
+# integer:
+classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
+
+_node_or_leaf(r::DecisionTree.Root) = r.node
+_node_or_leaf(n::Any) = n
+
+function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
+    raw_tree = fitresult[1]
+    encoding = get_encoding(fitresult[2])
+    features = fitresult[4]
+    classlabels = MLJDecisionTreeInterface.classlabels(encoding)
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features, classlabels),
+    )
+    (; tree, raw_tree, encoding, features)
+end

 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
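For orientation, here is a minimal sketch (hypothetical code, not part of this diff) of what the two helpers above compose to, assuming `MLJModelInterface` is available as `MMI`:

```
using CategoricalArrays
import MLJModelInterface as MMI

y = categorical(["virginica", "setosa", "versicolor"])

# analogue of get_encoding: ref integer => categorical element
encoding = Dict(MMI.int(c) => c for c in MMI.classes(y[1]))

# analogue of classlabels: printable labels, ordered by ref integer
labels = [string(encoding[i]) for i in sort(collect(keys(encoding)))]
@assert labels == ["setosa", "versicolor", "virginica"]
```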
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
     cache = nothing

     report = (features=features,)
+    fitresult = (tree, features)

-    return tree, cache, report
+    return fitresult, cache, report
 end

-MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
+function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
+    raw_tree = fitresult[1]
+    features = fitresult[2]
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features),
+    )
+    (; tree, raw_tree, features)
+end

-MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
+MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)

 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true

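A rough sketch (hypothetical, driving DecisionTree.jl directly) of the new fitresult contract for the regressor: `fit` returns the `(tree, features)` tuple so `fitted_params` can attach feature names, and `predict` simply unpacks the first slot:

```
using DecisionTree

X = rand(100, 2)
y = rand(100)
raw_tree = build_tree(y, X)         # the object stored in fitresult[1]
fitresult = (raw_tree, [:x1, :x2])  # the shape now returned by MMI.fit
yhat = apply_tree(fitresult[1], rand(3, 2))  # what the new MMI.predict does
```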
@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)

 # get actual arguments needed for importance calculation from various fitresults.
 get_fitresult(
-    m::Union{DecisionTreeClassifier, RandomForestClassifier},
+    m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
     fitresult,
 ) = (fitresult[1],)
 get_fitresult(
-    m::Union{DecisionTreeRegressor, RandomForestRegressor},
+    m::RandomForestRegressor,
     fitresult,
 ) = (fitresult,)
 get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
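A toy illustration (mock types, not package code) of why `DecisionTreeRegressor` moves to the tuple-unpacking method: its fitresult is now `(tree, features)`, whereas a random forest's fitresult is still the bare forest:

```
struct MockDecisionTreeRegressor end
struct MockRandomForestRegressor end

# tuple-valued fitresult: extract the tree from slot 1
mock_get_fitresult(::MockDecisionTreeRegressor, fitresult) = (fitresult[1],)

# bare fitresult: pass it through whole
mock_get_fitresult(::MockRandomForestRegressor, fitresult) = (fitresult,)

@assert mock_get_fitresult(MockDecisionTreeRegressor(), (:tree, [:x1])) == (:tree,)
@assert mock_get_fitresult(MockRandomForestRegressor(), :forest) == (:forest,)
```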
@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.

 The fields of `fitted_params(mach)` are:

-- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+  algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+  interface; see "Examples" below

 - `encoding`: dictionary of target classes keyed on integers used
-  internally by DecisionTree.jl; needed to interpret pretty printing
-  of tree (obtained by calling `fit!(mach, verbosity=2)` or from
-  report - see below)
+  internally by DecisionTree.jl

 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
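Since `get_encoding` keys on the ref integer, the dictionary reads directly as integer code to class; for the iris example further below it looks roughly like this (illustrative output):

```
julia> fitted_params(mach).encoding
Dict{UInt32, CategoricalArrays.CategoricalValue{String, UInt32}} with 3 entries:
  0x00000002 => "versicolor"
  0x00000003 => "virginica"
  0x00000001 => "setosa"
```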
@@ -617,7 +643,7 @@ The fields of `report(mach)` are:

 - `classes_seen`: list of target classes actually observed in training

-- `print_tree`: method to print a pretty representation of the fitted
+- `print_tree`: alternative method to print the fitted
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).

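For example, given the fitted machine from the "Examples" section below, a hypothetical call mirroring this docstring would be:

```
report(mach).print_tree(3)  # native DecisionTree.jl pretty-print, truncated to depth 3
```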
@@ -629,11 +655,11 @@ The fields of `report(mach)` are:

 ```
 using MLJ
-Tree = @load DecisionTreeClassifier pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)

 X, y = @load_iris
-mach = machine(tree, X, y) |> fit!
+mach = machine(model, X, y) |> fit!

 Xnew = (sepal_length = [6.4, 7.2, 7.4],
         sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +669,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
 predict_mode(mach, Xnew) # point predictions
 pdf.(yhat, "virginica") # probabilities for the "virginica" class

-fitted_params(mach).tree # raw tree or stump object from DecisionTrees.jl
-
-julia> report(mach).print_tree(3)
-Feature 4, Threshold 0.8
-L-> 1 : 50/50
-R-> Feature 4, Threshold 1.75
-    L-> Feature 3, Threshold 4.95
-        L->
-        R->
-    R-> Feature 3, Threshold 4.85
-        L->
-        R-> 3 : 43/43
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+   ├─ petal_length < 4.95
+   │  ├─ versicolor (47/48)
+   │  └─ virginica (4/6)
+   └─ petal_length < 4.85
+      ├─ virginica (2/3)
+      └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
 ```

-To interpret the internal class labelling:
-
-```
-julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  0x00000003 => "virginica"
-  0x00000001 => "setosa"
-  0x00000002 => "versicolor"
-```
-
-See also
-[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
-the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
+See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
+unwrapped model type
+[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).

 """
 DecisionTreeClassifier
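Because the wrapped `tree` implements the AbstractTrees.jl interface, generic tree tooling also applies; a hypothetical continuation of the docstring example above:

```
using AbstractTrees

tree = fitted_params(mach).tree
print_tree(tree)                # AbstractTrees' generic printer
leaves = collect(Leaves(tree))  # materialize all leaves
for node in PreOrderDFS(tree)   # depth-first traversal of nodes and leaves
    # inspect `node` here
end
```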
@@ -903,7 +922,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`

-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`

 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -921,6 +941,8 @@ The fields of `fitted_params(mach)` are:
 - `tree`: the tree or stump object returned by the core
   DecisionTree.jl algorithm

+- `features`: the names of the features encountered in training
+

 # Report

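For instance, with the four-feature synthetic data of the example below, the new `features` field returns something like (illustrative):

```
julia> fitted_params(mach).features
4-element Vector{Symbol}:
 :x1
 :x2
 :x3
 :x4
```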
@@ -931,16 +953,31 @@ The fields of `fitted_params(mach)` are:

 ```
 using MLJ
-Tree = @load DecisionTreeRegressor pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
+model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)

-X, y = make_regression(100, 2) # synthetic data
-mach = machine(tree, X, y) |> fit!
+X, y = make_regression(100, 4; rng=123) # synthetic data
+mach = machine(model, X, y) |> fit!

-Xnew, _ = make_regression(3, 2)
+Xnew, _ = make_regression(3, 4; rng=123)
 yhat = predict(mach, Xnew) # new predictions

-fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
+julia> fitted_params(mach).tree
+x1 < 0.2758
+├─ x2 < 0.9137
+│  ├─ x1 < -0.9582
+│  │  ├─ 0.9189256882087312 (0/12)
+│  │  └─ -0.23180616021065256 (0/38)
+│  └─ -1.6461153800037722 (0/9)
+└─ x1 < 1.062
+   ├─ x2 < -0.4969
+   │  ├─ -0.9330755147107384 (0/5)
+   │  └─ -2.3287967825015548 (0/17)
+   └─ x2 < 0.4598
+      ├─ -2.931299926506291 (0/11)
+      └─ -4.726518740473489 (0/8)
+
+feature_importances(mach) # get feature importances
 ```

 See also