1
1
using JLD2, NNlib, BenchmarkTools
2
2
3
+ # We need things to go quickly here
4
+ BenchmarkTools. DEFAULT_PARAMETERS. samples = 20
5
+ BenchmarkTools. DEFAULT_PARAMETERS. seconds = 2.5
6
+
3
7
results = Dict ()
4
8
5
9
function add_result (val, keys... )
@@ -14,14 +18,15 @@ function add_result(val, keys...)
14
18
return r
15
19
end
16
20
17
- for rank in (3 , 2 , 1 ),
18
- N in (10 , 20 , 40 , 80 ),
19
- C_in in (1 , 2 , 4 ),
20
- C_out in (1 , 2 , 4 ),
21
- K in (3 , 6 , 12 ),
22
- stride in (1 , 2 , 4 ),
23
- dilation in (1 , 2 , 4 ),
24
- padding in (0 , 2 , 4 )
21
+ # Modify these as needed
22
+ for rank in (2 ,),
23
+ N in (20 , 40 , 80 ),
24
+ C_in in (1 ,),
25
+ C_out in (1 ,),
26
+ K in (3 ,),
27
+ stride in (1 ,),
28
+ dilation in (1 ,),
29
+ padding in (0 , 2 )
25
30
26
31
for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in (
27
32
(NNlib. conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, " direct" ),
@@ -41,7 +46,12 @@ for rank in (3, 2, 1),
41
46
catch
42
47
continue
43
48
end
44
- y = zeros (Float32, NNlib. output_size (cdims)... , C_out, 1 )
49
+
50
+ if cT == DenseConvDims
51
+ y = zeros (Float32, NNlib. output_size (cdims)... , C_out, 1 )
52
+ else
53
+ y = zeros (Float32, NNlib. output_size (cdims)... , C_out* C_in, 1 )
54
+ end
45
55
46
56
dx = similar (x)
47
57
dw = similar (w)
@@ -61,8 +71,9 @@ for rank in (3, 2, 1),
61
71
end
62
72
63
73
64
- for rank in (3 , 2 , 1 ),
65
- N in (10 , 20 , 40 , 80 ),
74
+ # Modify these as needed
75
+ for rank in (2 ,),
76
+ N in (20 ,),
66
77
K in (2 , 4 ),
67
78
stride in (1 , 2 , 4 )
68
79
@@ -76,13 +87,13 @@ for rank in (3, 2, 1),
76
87
(NNlib. meanpool!, NNlib.∇meanpool!, " meanpool" ),
77
88
)
78
89
79
- t_fwd = @benchmark pool ( $ y, $ x, pdims)
80
- t_data = @benchmark ∇pool ($ dx, $ y, $ x, pdims)
90
+ t_fwd = @benchmark $ ( pool) ( $ y, $ x, $ pdims)
91
+ t_data = @benchmark $ ( ∇pool) ($ dx, $ y, $ y, $ x, $ pdims)
81
92
82
93
add_result (t_fwd, " $(name)$(rank) d" , " direct" , pdims)
83
94
add_result (t_data, " $(name)$(rank) d_data" , " direct" , pdims)
84
95
85
96
@show (pdims)
86
97
@save " results.jld2" results
87
98
end
88
- end
99
+ end
0 commit comments