@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
 import MLJModelInterface
 using MLJModelInterface.ScientificTypesBase
 import DecisionTree
+import Tables

 using Random
 import Random.GLOBAL_RNG
@@ -50,7 +51,7 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     if schema === nothing
         features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
     else
-        features = schema.names
+        features = schema.names |> collect
     end

     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
@@ -84,7 +85,9 @@ function get_encoding(classes_seen)
 end

 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1], encoding=get_encoding(fitresult[2]), features=features)
+    (tree=fitresult[1],
+     encoding=get_encoding(fitresult[2]),
+     features=fitresult[4])

 function smooth(scores, smoothing)
     iszero(smoothing) && return scores
@@ -402,6 +405,9 @@ The fields of `fitted_params(mach)` are:
   of tree (obtained by calling `fit!(mach, verbosity=2)` or from
   report - see below)

+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+

 # Report

@@ -413,6 +419,9 @@ The fields of `report(mach)` are:
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).

+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+

 # Examples

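Below is a minimal usage sketch, not part of the diff itself, illustrating how the `features` entry that this change adds to `fitted_params` and `report` would be read back after training. The `@load_iris` dataset and the `max_depth` setting are illustrative assumptions.

```julia
# Illustrative sketch (assumed setup, not part of this commit): train a
# DecisionTreeClassifier through MLJ and read back the feature names
# now exposed by both fitted_params and report.
using MLJ

Tree = @load DecisionTreeClassifier pkg=DecisionTree
X, y = @load_iris                        # example dataset (assumption)
mach = machine(Tree(max_depth=3), X, y)
fit!(mach)

fitted_params(mach).features   # feature names seen during training
report(mach).features          # same names, in an order consistent with print_tree
```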