@@ -15,9 +15,10 @@ const PKG = "MLJDecisionTreeInterface"
1515
1616struct TreePrinter{T}
1717 tree:: T
18+ features:: Vector{Symbol}
1819end
19- (c:: TreePrinter )(depth) = DT. print_tree (c. tree, depth)
20- (c:: TreePrinter )() = DT. print_tree (c. tree, 5 )
20+ (c:: TreePrinter )(depth) = DT. print_tree (c. tree, depth, feature_names = c . features )
21+ (c:: TreePrinter )() = DT. print_tree (c. tree, 5 , feature_names = c . features )
2122
2223Base. show (stream:: IO , c:: TreePrinter ) =
2324 print (stream, " TreePrinter object (call with display depth)" )
@@ -71,7 +72,7 @@ function MMI.fit(
7172 cache = nothing
7273 report = (
7374 classes_seen= classes_seen,
74- print_tree= TreePrinter (tree),
75+ print_tree= TreePrinter (tree, features ),
7576 features= features,
7677 )
7778 return fitresult, cache, report
@@ -765,6 +766,8 @@ The fields of `fitted_params(mach)` are:
765766
766767# Report
767768
769+ The fields of `report(mach)` are:
770+
768771- `features`: the names of the features encountered in training
769772
770773
@@ -862,6 +865,8 @@ The fields of `fitted_params(mach)` are:
862865
863866# Report
864867
868+ The fields of `report(mach)` are:
869+
865870- `features`: the names of the features encountered in training
866871
867872
@@ -968,6 +973,8 @@ The fields of `fitted_params(mach)` are:
968973
969974# Report
970975
976+ The fields of `report(mach)` are:
977+
971978- `features`: the names of the features encountered in training
972979
973980
@@ -1079,6 +1086,8 @@ The fields of `fitted_params(mach)` are:
10791086
10801087# Report
10811088
1089+ The fields of `report(mach)` are:
1090+
10821091- `features`: the names of the features encountered in training
10831092
10841093
0 commit comments