Skip to content

Commit d94339b

Browse files
committed
Refactor benchmark
1 parent a748ebf commit d94339b

File tree

1 file changed

+71
-38
lines changed

1 file changed

+71
-38
lines changed

benchmark/benchmarks.jl

Lines changed: 71 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,88 @@ function benchmark_evaluation()
1010
operators = OperatorEnum(;
1111
binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true
1212
)
13-
for T in (ComplexF32, ComplexF64, Float32, Float64)
14-
if !(T <: Real) && PACKAGE_VERSION < v"0.5.0" && PACKAGE_VERSION != v"0.0.0"
15-
continue
16-
end
17-
suite[T] = BenchmarkGroup()
1813

19-
n = 1_000
14+
config_options = [
15+
[
16+
(turbo=turbo, T=T, n=n, derivative=derivative) for turbo in (false, true) for
17+
T in (ComplexF32, ComplexF64, Float32, Float64) for n in (100, 1_000, 10_000)
18+
for derivative in (false, true)
19+
]...,
20+
]
2021

21-
#! format: off
22-
for turbo in (false, true)
23-
if turbo && !(T in (Float32, Float64))
24-
continue
25-
end
26-
extra_key = turbo ? "_turbo" : ""
22+
config_options = filter!(config_options) do config
23+
!(config.T <: Real) &&
24+
PACKAGE_VERSION < v"0.5.0" &&
25+
PACKAGE_VERSION != v"0.0.0" &&
26+
return false
27+
28+
config.turbo && !(config.T in (Float32, Float64)) && return false
29+
30+
config.T != Float32 && config.n != 1_000 && return false
31+
32+
config.T != Float32 && config.derivative && return false
33+
34+
return true
35+
end
36+
37+
for config in config_options
38+
T = config.T
39+
turbo = config.turbo
40+
n = config.n
41+
derivative = config.derivative
42+
43+
derivative_s = derivative ? "derivative" : "evaluation"
44+
turbo_s = turbo ? "turbo" : "standard"
45+
46+
haskey(suite, derivative_s) || (suite[derivative_s] = BenchmarkGroup())
47+
haskey(suite[derivative_s], T) || (suite[derivative_s][T] = BenchmarkGroup())
48+
haskey(suite[derivative_s][T], n) || (suite[derivative_s][T][n] = BenchmarkGroup())
49+
haskey(suite[derivative_s][T][n], turbo_s) ||
50+
(suite[derivative_s][T][n][turbo_s] = BenchmarkGroup())
51+
52+
if derivative
53+
eval_grad_tree_array(
54+
gen_random_tree_fixed_size(20, operators, 5, T),
55+
randn(MersenneTwister(0), T, 5, n),
56+
operators;
57+
variable=true,
58+
turbo=turbo,
59+
)
60+
suite[derivative_s][T][n][turbo_s] = @benchmarkable(
61+
[
62+
eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo)
63+
for tree in trees
64+
],
65+
setup = (
66+
X = randn(MersenneTwister(0), $T, 5, $n);
67+
treesize = 20;
68+
ntrees = 100;
69+
trees = [
70+
gen_random_tree_fixed_size(treesize, $operators, 5, $T) for
71+
_ in 1:ntrees
72+
]
73+
)
74+
)
75+
else
2776
eval_tree_array(
2877
gen_random_tree_fixed_size(20, operators, 5, T),
2978
randn(MersenneTwister(0), T, 5, n),
3079
operators;
31-
turbo=turbo
80+
turbo=turbo,
3281
)
33-
suite[T]["evaluation$(extra_key)"] = @benchmarkable(
82+
suite[derivative_s][T][n][turbo_s] = @benchmarkable(
3483
[eval_tree_array(tree, X, $operators; turbo=$turbo) for tree in trees],
35-
setup=(
36-
X=randn(MersenneTwister(0), $T, 5, $n);
37-
treesize=20;
38-
ntrees=100;
39-
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
84+
setup = (
85+
X = randn(MersenneTwister(0), $T, 5, $n);
86+
treesize = 20;
87+
ntrees = 100;
88+
trees = [
89+
gen_random_tree_fixed_size(treesize, $operators, 5, $T) for
90+
_ in 1:ntrees
91+
]
4092
)
4193
)
42-
if T <: Real
43-
eval_grad_tree_array(
44-
gen_random_tree_fixed_size(20, operators, 5, T),
45-
randn(MersenneTwister(0), T, 5, n),
46-
operators;
47-
variable=true,
48-
turbo=turbo
49-
)
50-
suite[T]["derivative$(extra_key)"] = @benchmarkable(
51-
[eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo) for tree in trees],
52-
setup=(
53-
X=randn(MersenneTwister(0), $T, 5, $n);
54-
treesize=20;
55-
ntrees=100;
56-
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
57-
)
58-
)
59-
end
6094
end
61-
#! format: on
6295
end
6396
return suite
6497
end

0 commit comments

Comments
 (0)