@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
 import MLJModelInterface
 using MLJModelInterface.ScientificTypesBase
 import DecisionTree
+import Tables
 
 using Random
 import Random.GLOBAL_RNG
@@ -37,15 +38,21 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain  = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
@@ -61,11 +68,12 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     end
     verbosity < 2 || DT.print_tree(tree, m.display_depth)
 
-    fitresult = (tree, classes_seen, integers_seen)
+    fitresult = (tree, classes_seen, integers_seen, features)
 
     cache  = nothing
     report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree))
+              print_tree=TreePrinter(tree),
+              features=features)
 
     return fitresult, cache, report
 end
@@ -76,7 +84,9 @@ function get_encoding(classes_seen)
 end
 
 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1], encoding=get_encoding(fitresult[2]))
+    (tree=fitresult[1],
+     encoding=get_encoding(fitresult[2]),
+     features=fitresult[4])
 
 function smooth(scores, smoothing)
     iszero(smoothing) && return scores
@@ -92,10 +102,9 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
     # retrieve the predicted scores
     scores = DT.apply_tree_proba(tree, Xmatrix, integers_seen)
-    # smooth if required
-    sm_scores = smooth(scores, m.pdf_smoothing)
+
     # return vector of UF
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end
 
 
@@ -109,7 +118,6 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
@@ -140,16 +148,14 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
     Xmatrix = MMI.matrix(Xnew)
     forest, classes_seen, integers_seen = fitresult
     scores = DT.apply_forest_proba(forest, Xmatrix, integers_seen)
-    sm_scores = smooth(scores, m.pdf_smoothing)
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end
 
 
 ## ADA BOOST STUMP CLASSIFIER
 
 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
 end
 
 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
@@ -174,8 +180,7 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
     stumps, coefs, classes_seen, integers_seen = fitresult
     scores = DT.apply_adaboost_stumps_proba(stumps, coefs,
                                             Xmatrix, integers_seen)
-    sm_scores = smooth(scores, m.pdf_smoothing)
-    return MMI.UnivariateFinite(classes_seen, sm_scores)
+    return MMI.UnivariateFinite(classes_seen, scores)
 end
 
 
@@ -228,7 +233,6 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
-    pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
@@ -364,14 +368,6 @@ Train the machine using `fit!(mach, rows=...)`.
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores.
-  Raw leaf-based probabilities are smoothed as follows: If `n` is the
-  number of observed classes, then each class probability is replaced
-  by `pdf_smoothing/n`, if it falls below that ratio, and the
-  resulting vector of probabilities is renormalized. Smoothing is only
-  applied to classes actually observed in training. Unseen classes
-  retain zero-probability predictions.
-
 
 # Operations
 
@@ -394,6 +390,9 @@ The fields of `fitted_params(mach)` are:
   of tree (obtained by calling `fit!(mach, verbosity=2)` or from
   report - see below)
 
+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+
 
 # Report
 
@@ -405,6 +404,9 @@ The fields of `report(mach)` are:
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).
 
+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+
 
 # Examples
 
@@ -495,9 +497,6 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores of
-  each tree. See [`DecisionTreeClassifier`](@ref)
-
 
 # Operations
 
@@ -569,9 +568,6 @@ Train the machine with `fit!(mach, rows=...)`.
 
 - `n_iter=10`: number of iterations of AdaBoost
 
-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores.
-  See [`DecisionTreeClassifier`](@ref)
-
 
 # Operations
 
0 commit comments