Skip to content

Commit 4c1f0fc

Browse files
authored
Merge pull request #14 from JuliaAI/expose-feature-names
Expose feature names
2 parents 0914878 + 172e944 commit 4c1f0fc

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ version = "0.1.4"
77
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
88
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
10+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1011

1112
[compat]
1213
DecisionTree = "0.10"
1314
MLJModelInterface = "^0.3,^0.4, 1.0"
15+
Tables = "1.6"
1416
julia = "1.6"
1517

1618
[extras]

src/MLJDecisionTreeInterface.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module MLJDecisionTreeInterface
33
import MLJModelInterface
44
using MLJModelInterface.ScientificTypesBase
55
import DecisionTree
6+
import Tables
67

78
using Random
89
import Random.GLOBAL_RNG
@@ -43,9 +44,16 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
4344
end
4445

4546
function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
47+
schema = Tables.schema(X)
4648
Xmatrix = MMI.matrix(X)
4749
yplain = MMI.int(y)
4850

51+
if schema === nothing
52+
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
53+
else
54+
features = schema.names |> collect
55+
end
56+
4957
classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
5058
integers_seen = MMI.int(classes_seen)
5159

@@ -61,11 +69,12 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
6169
end
6270
verbosity < 2 || DT.print_tree(tree, m.display_depth)
6371

64-
fitresult = (tree, classes_seen, integers_seen)
72+
fitresult = (tree, classes_seen, integers_seen, features)
6573

6674
cache = nothing
6775
report = (classes_seen=classes_seen,
68-
print_tree=TreePrinter(tree))
76+
print_tree=TreePrinter(tree),
77+
features=features)
6978

7079
return fitresult, cache, report
7180
end
@@ -76,7 +85,9 @@ function get_encoding(classes_seen)
7685
end
7786

7887
MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
79-
(tree=fitresult[1], encoding=get_encoding(fitresult[2]))
88+
(tree=fitresult[1],
89+
encoding=get_encoding(fitresult[2]),
90+
features=fitresult[4])
8091

8192
function smooth(scores, smoothing)
8293
iszero(smoothing) && return scores
@@ -394,6 +405,9 @@ The fields of `fitted_params(mach)` are:
394405
of tree (obtained by calling `fit!(mach, verbosity=2)` or from
395406
report - see below)
396407
408+
- `features`: the names of the features encountered in training, in an
409+
order consistent with the output of `print_tree` (see below)
410+
397411
398412
# Report
399413
@@ -405,6 +419,9 @@ The fields of `report(mach)` are:
405419
tree, with single argument the tree depth; interpretation requires
406420
internal integer-class encoding (see "Fitted parameters" above).
407421
422+
- `features`: the names of the features encountered in training, in an
423+
order consistent with the output of `print_tree` (see below)
424+
408425
409426
# Examples
410427

test/runtests.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@ yhat = MLJBase.predict(baretree, fitresult, X);
3737
yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
3838
@test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1])
3939

40+
# check report and fitresult fields:
41+
@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))
42+
@test Set(report.classes_seen) == Set(levels(y))
43+
@test report.print_tree(2) === nothing # :-(
44+
@test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
45+
fp = fitted_params(baretree, fitresult)
46+
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
47+
@test fp.features == report.features
4048

41-
# # testing machine interface:
42-
# tree = machine(baretree, X, y)
43-
# fit!(tree)
44-
# yyhat = predict_mode(tree, MLJBase.selectrows(X, 1:3))
4549
using Random: seed!
4650
seed!(0)
4751

0 commit comments

Comments
 (0)