Skip to content

Commit f60ae1e

Browse files
committed
Update perf testing script
1 parent 36b4d9b commit f60ae1e

File tree

2 files changed

+25
-28
lines changed

2 files changed

+25
-28
lines changed

test/perf/compare.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/perf/perf_report.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using JLD2, NNlib, BenchmarkTools
22

3+
# We need things to go quickly here
4+
BenchmarkTools.DEFAULT_PARAMETERS.samples = 20
5+
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5
6+
37
results = Dict()
48

59
function add_result(val, keys...)
@@ -14,14 +18,15 @@ function add_result(val, keys...)
1418
return r
1519
end
1620

17-
for rank in (3, 2, 1),
18-
N in (10, 20, 40, 80),
19-
C_in in (1, 2, 4),
20-
C_out in (1, 2, 4),
21-
K in (3, 6, 12),
22-
stride in (1, 2, 4),
23-
dilation in (1, 2, 4),
24-
padding in (0, 2, 4)
21+
# Modify these as needed
22+
for rank in (2,),
23+
N in (20, 40, 80),
24+
C_in in (1,),
25+
C_out in (1,),
26+
K in (3,),
27+
stride in (1,),
28+
dilation in (1,),
29+
padding in (0, 2)
2530

2631
for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in (
2732
(NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"),
@@ -41,7 +46,12 @@ for rank in (3, 2, 1),
4146
catch
4247
continue
4348
end
44-
y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)
49+
50+
if cT == DenseConvDims
51+
y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)
52+
else
53+
y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1)
54+
end
4555

4656
dx = similar(x)
4757
dw = similar(w)
@@ -61,8 +71,9 @@ for rank in (3, 2, 1),
6171
end
6272

6373

64-
for rank in (3, 2, 1),
65-
N in (10, 20, 40, 80),
74+
# Modify these as needed
75+
for rank in (2,),
76+
N in (20,),
6677
K in (2, 4),
6778
stride in (1, 2, 4)
6879

@@ -76,13 +87,13 @@ for rank in (3, 2, 1),
7687
(NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
7788
)
7889

79-
t_fwd = @benchmark pool( $y, $x, pdims)
80-
t_data = @benchmark ∇pool($dx, $y, $x, pdims)
90+
t_fwd = @benchmark $(pool)( $y, $x, $pdims)
91+
t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims)
8192

8293
add_result(t_fwd, "$(name)$(rank)d", "direct", pdims)
8394
add_result(t_data, "$(name)$(rank)d_data", "direct", pdims)
8495

8596
@show(pdims)
8697
@save "results.jld2" results
8798
end
88-
end
99+
end

0 commit comments

Comments
 (0)