@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
 import MLJModelInterface
 using MLJModelInterface.ScientificTypesBase
 import DecisionTree
+import Tables
 
 using Random
 import Random.GLOBAL_RNG
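
The new `Tables` import is what lets `fit` recover column names from the input table. A minimal sketch of the behaviour relied on below (the table `X` here is illustrative, not taken from the diff): `Tables.schema` returns a `Tables.Schema` whose `names` field holds the column names, or `nothing` when the table cannot report its schema.

    import Tables

    X = (sepal_length = [5.1, 4.9], sepal_width = [3.5, 3.0])  # any Tables.jl-compatible table
    schema = Tables.schema(X)
    schema === nothing || collect(schema.names)                # => [:sepal_length, :sepal_width]
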
@@ -43,9 +44,16 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
 end
 
 function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain = MMI.int(y)
 
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)
 
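
When `Tables.schema` returns `nothing`, the hunk above falls back to synthetic names matching the columns of `Xmatrix`. A standalone illustration of that fallback (the matrix is made up):

    Xmatrix = rand(10, 3)
    features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]  # => [:x1, :x2, :x3]
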
@@ -61,11 +69,12 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     end
     verbosity < 2 || DT.print_tree(tree, m.display_depth)
 
-    fitresult = (tree, classes_seen, integers_seen)
+    fitresult = (tree, classes_seen, integers_seen, features)
 
     cache = nothing
     report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree))
+              print_tree=TreePrinter(tree),
+              features=features)
 
     return fitresult, cache, report
 end
@@ -76,7 +85,9 @@ function get_encoding(classes_seen)
 end
 
 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1], encoding=get_encoding(fitresult[2]))
+    (tree=fitresult[1],
+     encoding=get_encoding(fitresult[2]),
+     features=fitresult[4])
 
 function smooth(scores, smoothing)
     iszero(smoothing) && return scores
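
With the fourth tuple entry in place, the feature names can be read back from the fitted parameters. Assuming a model `m` and the `fitresult` returned by the `fit` method above (variable names are illustrative):

    fp = MMI.fitted_params(m, fitresult)
    fp.tree        # the raw DecisionTree.jl tree
    fp.features    # e.g. [:sepal_length, :sepal_width, ...]

In MLJ itself the same named tuple is reached as `fitted_params(mach)` for a trained machine `mach`.
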
@@ -394,6 +405,9 @@ The fields of `fitted_params(mach)` are:
   of tree (obtained by calling `fit!(mach, verbosity=2)` or from
   report - see below)
 
+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+
 
 # Report
 
@@ -405,6 +419,9 @@ The fields of `report(mach)` are:
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).
 
+- `features`: the names of the features encountered in training, in an
+  order consistent with the output of `print_tree` (see below)
+
 
 # Examples
 
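
Putting the two documented fields together, a rough end-to-end sketch of how `features` would surface through MLJ (the data set and machine are illustrative, not part of this diff):

    using MLJ
    DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
    X, y = @load_iris
    mach = machine(DecisionTreeClassifier(), X, y) |> fit!
    fitted_params(mach).features  # feature names, ordered consistently with print_tree
    report(mach).features         # the same names, also exposed in the report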