|
56 | 56 | is_leaf(l::Leaf) = true |
57 | 57 | is_leaf(n::Node) = false |
58 | 58 |
|
59 | | -zero(::Type{String}) = "" |
60 | | -convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)])) |
| 59 | +_zero(::Type{String}) = "" |
| 60 | +_zero(x::Any) = zero(x) |
| 61 | +convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)])) |
61 | 62 | convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[]) |
62 | 63 | convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node |
63 | 64 | promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T} |
@@ -95,21 +96,21 @@ depth(leaf::Leaf) = 0 |
95 | 96 | depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right)) |
96 | 97 | depth(tree::Root) = depth(tree.node) |
97 | 98 |
|
98 | | -function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing) |
| 99 | +function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
99 | 100 | n_matches = count(leaf.values .== leaf.majority) |
100 | 101 | ratio = string(n_matches, "/", length(leaf.values)) |
101 | 102 | println(io, "$(leaf.majority) : $(ratio)") |
102 | 103 | end |
103 | | -function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) |
104 | | - return print_tree(stdout, leaf, depth, indent; feature_names=feature_names) |
| 104 | +function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 105 | + return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names) |
105 | 106 | end |
106 | 107 |
|
107 | 108 |
|
108 | | -function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
109 | | - return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 109 | +function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 110 | + return print_tree(io, tree.node, depth, indent; sigdigits, feature_names) |
110 | 111 | end |
111 | | -function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
112 | | - return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 112 | +function print_tree(tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 113 | + return print_tree(stdout, tree, depth, indent; sigdigits, feature_names) |
113 | 114 | end |
114 | 115 |
|
115 | 116 | """ |
@@ -137,26 +138,26 @@ Feature 3 < -28.15 ? |
137 | 138 |
|
138 | 139 | To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object, |
139 | 140 | `DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the |
140 | | -AbstractTrees.jl interface. See [`wrap`](@ref)` for details. |
| 141 | +AbstractTrees.jl interface. See [`wrap`](@ref)` for details. |
141 | 142 | """ |
142 | 143 | function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
143 | 144 | if depth == indent |
144 | 145 | println(io) |
145 | 146 | return |
146 | 147 | end |
147 | | - featval = round(tree.featval; sigdigits=sigdigits) |
| 148 | + featval = round(tree.featval; sigdigits) |
148 | 149 | if feature_names === nothing |
149 | 150 | println(io, "Feature $(tree.featid) < $featval ?") |
150 | 151 | else |
151 | 152 | println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?") |
152 | 153 | end |
153 | 154 | print(io, " " ^ indent * "├─ ") |
154 | | - print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names) |
| 155 | + print_tree(io, tree.left, depth, indent + 1; sigdigits, feature_names) |
155 | 156 | print(io, " " ^ indent * "└─ ") |
156 | | - print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names) |
| 157 | + print_tree(io, tree.right, depth, indent + 1; sigdigits, feature_names) |
157 | 158 | end |
158 | | -function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
159 | | - return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 159 | +function print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 160 | + return print_tree(stdout, tree, depth, indent; sigdigits, feature_names) |
160 | 161 | end |
161 | 162 |
|
162 | 163 | function show(io::IO, leaf::Leaf) |
|
0 commit comments