Skip to content

Commit 172e944

Browse files
committed
expose feature names in fitresult and report
1 parent 8625f28 commit 172e944

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
33
import MLJModelInterface
44
using MLJModelInterface.ScientificTypesBase
55
import DecisionTree
6+
import Tables
67

78
using Random
89
import Random.GLOBAL_RNG
@@ -50,7 +51,7 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
5051
if schema === nothing
5152
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
5253
else
53-
features = schema.names
54+
features = schema.names |> collect
5455
end
5556

5657
classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
@@ -84,7 +85,9 @@ function get_encoding(classes_seen)
8485
end
8586

8687
MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
87-
(tree=fitresult[1], encoding=get_encoding(fitresult[2]), features=features)
88+
(tree=fitresult[1],
89+
encoding=get_encoding(fitresult[2]),
90+
features=fitresult[4])
8891

8992
function smooth(scores, smoothing)
9093
iszero(smoothing) && return scores
@@ -402,6 +405,9 @@ The fields of `fitted_params(mach)` are:
402405
of tree (obtained by calling `fit!(mach, verbosity=2)` or from
403406
report - see below)
404407
408+
- `features`: the names of the features encountered in training, in an
409+
order consistent with the output of `print_tree` (see below)
410+
405411
406412
# Report
407413
@@ -413,6 +419,9 @@ The fields of `report(mach)` are:
413419
tree, with single argument the tree depth; interpretation requires
414420
internal integer-class encoding (see "Fitted parameters" above).
415421
422+
- `features`: the names of the features encountered in training, in an
423+
order consistent with the output of `print_tree` (see below)
424+
416425
417426
# Examples
418427

test/runtests.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@ yhat = MLJBase.predict(baretree, fitresult, X);
3737
yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
3838
@test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1])
3939

40+
# check report and fitresult fields:
41+
@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))
42+
@test Set(report.classes_seen) == Set(levels(y))
43+
@test report.print_tree(2) === nothing # :-(
44+
@test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
45+
fp = fitted_params(baretree, fitresult)
46+
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
47+
@test fp.features == report.features
4048

41-
# # testing machine interface:
42-
# tree = machine(baretree, X, y)
43-
# fit!(tree)
44-
# yyhat = predict_mode(tree, MLJBase.selectrows(X, 1:3))
4549
using Random: seed!
4650
seed!(0)
4751

0 commit comments

Comments
 (0)