Skip to content

Commit c64b296

Browse files
committed
fix tests in regression
1 parent f98b40f commit c64b296

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

src/regression/bench.jl

Whitespace-only changes.

test/bench/trees.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
3+
include("../../src/DecisionTree.jl")
4+
5+
function loaddata()
6+
f = open("data/digits.csv")
7+
data = readlines(f)[2:end]
8+
data = [[parse(Float32, i)
9+
for i in split(row, ",")]
10+
for row in data]
11+
data = hcat(data...)
12+
Y = Int.(data[1, 1:end]) .+ 1
13+
X = convert(Matrix, transpose(data[2:end, 1:end]))
14+
return X, Y
15+
end
16+
17+
num_leaves(node::DecisionTree.Node) = num_leaves(node.left) + num_leaves(node.right)
18+
num_leaves(node::DecisionTree.Leaf) = 1
19+
20+
X, Y = loaddata()
21+
22+
# for compilation
23+
for i in 1:10
24+
t = DecisionTree.build_tree(Y, X)
25+
end
26+
27+
println("[ === CLASSIFICATION BENCHMARK === ]")
28+
for j in 1:3
29+
@time for i in 1:100
30+
tree = DecisionTree.build_tree(Y, X)
31+
end
32+
end
33+
34+
35+
Y = float.(Y) # labels/targets to Float to enable regression
36+
37+
# for compilation
38+
for i in 1:10
39+
t = DecisionTree.build_tree(Y, X)
40+
end
41+
42+
println("[ === REGRESSION BENCHMARK === ]")
43+
for j in 1:3
44+
@time for i in 1:100
45+
tree = DecisionTree.build_tree(Y, X)
46+
end
47+
end

test/regression/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ min_samples_leaf = 5; max_depth = 3; n_subfeatures = 0;
2929
model = build_tree(labels, features, min_samples_leaf, n_subfeatures, max_depth)
3030
@test depth(model) == max_depth
3131

32-
min_samples_leaf = 1; n_subfeatures = 0; max_depth = -1; min_samples_split = 100;
32+
min_samples_leaf = 1; n_subfeatures = 0; max_depth = -1; min_samples_split = 300;
3333
model = build_tree(labels, features, min_samples_leaf, n_subfeatures, max_depth, min_samples_split)
3434
preds = apply_tree(model, features);
3535
@test R2(labels, preds) < 0.8

0 commit comments

Comments
 (0)