Skip to content

Commit f6003b1

Browse files
committed
Add performance testing framework
1 parent dd1cb04 commit f6003b1

File tree

5 files changed

+255
-0
lines changed

5 files changed

+255
-0
lines changed

test/perf/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.jld2

test/perf/Manifest.toml

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
[[Base64]]
4+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
5+
6+
[[BenchmarkTools]]
7+
deps = ["JSON", "Printf", "Statistics", "Test"]
8+
git-tree-sha1 = "5d1dd8577643ba9014574cd40d9c028cd5e4b85a"
9+
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
10+
version = "0.4.2"
11+
12+
[[BinaryProvider]]
13+
deps = ["Libdl", "Pkg", "SHA", "Test"]
14+
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
15+
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
16+
version = "0.5.3"
17+
18+
[[CodecZlib]]
19+
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
20+
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
21+
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
22+
version = "0.5.1"
23+
24+
[[Crayons]]
25+
deps = ["Test"]
26+
git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476"
27+
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
28+
version = "1.0.0"
29+
30+
[[DataStructures]]
31+
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
32+
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
33+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
34+
version = "0.15.0"
35+
36+
[[Dates]]
37+
deps = ["Printf"]
38+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
39+
40+
[[Distributed]]
41+
deps = ["Random", "Serialization", "Sockets"]
42+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
43+
44+
[[FileIO]]
45+
deps = ["Pkg", "Random", "Test"]
46+
git-tree-sha1 = "c94b0787956629036fb2b20fccde9e52b89d079a"
47+
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
48+
version = "1.0.5"
49+
50+
[[InteractiveUtils]]
51+
deps = ["Markdown"]
52+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
53+
54+
[[JLD2]]
55+
deps = ["CodecZlib", "DataStructures", "FileIO", "LinearAlgebra", "Mmap", "Printf", "Random", "Test"]
56+
git-tree-sha1 = "3ba90ff93e1d5b9b2103588051c2d349fae54dac"
57+
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
58+
version = "0.1.2"
59+
60+
[[JSON]]
61+
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
62+
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
63+
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
64+
version = "0.20.0"
65+
66+
[[LibGit2]]
67+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
68+
69+
[[Libdl]]
70+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
71+
72+
[[LinearAlgebra]]
73+
deps = ["Libdl"]
74+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
75+
76+
[[Logging]]
77+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
78+
79+
[[Markdown]]
80+
deps = ["Base64"]
81+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
82+
83+
[[Mmap]]
84+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
85+
86+
[[NNlib]]
87+
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "Test", "TimerOutputs"]
88+
path = "../.."
89+
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
90+
version = "0.4.3+"
91+
92+
[[OrderedCollections]]
93+
deps = ["Random", "Serialization", "Test"]
94+
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
95+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
96+
version = "1.0.2"
97+
98+
[[Pkg]]
99+
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
100+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
101+
102+
[[Printf]]
103+
deps = ["Unicode"]
104+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
105+
106+
[[REPL]]
107+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
108+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
109+
110+
[[Random]]
111+
deps = ["Serialization"]
112+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
113+
114+
[[Requires]]
115+
deps = ["Test"]
116+
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
117+
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
118+
version = "0.5.2"
119+
120+
[[SHA]]
121+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
122+
123+
[[Serialization]]
124+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
125+
126+
[[Sockets]]
127+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
128+
129+
[[SparseArrays]]
130+
deps = ["LinearAlgebra", "Random"]
131+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
132+
133+
[[Statistics]]
134+
deps = ["LinearAlgebra", "SparseArrays"]
135+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
136+
137+
[[Test]]
138+
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
139+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
140+
141+
[[TimerOutputs]]
142+
deps = ["Crayons", "Printf", "Test", "Unicode"]
143+
path = "/Users/sabae/.julia/dev/TimerOutputs"
144+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
145+
version = "0.4.0+"
146+
147+
[[TranscodingStreams]]
148+
deps = ["Pkg", "Random", "Test"]
149+
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
150+
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
151+
version = "0.8.1"
152+
153+
[[UUIDs]]
154+
deps = ["Random", "SHA"]
155+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
156+
157+
[[Unicode]]
158+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

test/perf/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
4+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
5+
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

test/perf/compare.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
using NNlib, BenchmarkTools, JLD2
2+
3+
@load "results.jld2" results

test/perf/perf_report.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using JLD2, NNlib, BenchmarkTools
2+
3+
results = Dict()
4+
5+
function add_result(val, keys...)
6+
r = results
7+
for k in keys[1:end-1]
8+
if !haskey(r, k)
9+
r[k] = Dict()
10+
end
11+
r = r[k]
12+
end
13+
r[keys[end]] = val
14+
return r
15+
end
16+
17+
for rank in (3, 2, 1),
18+
N in (10, 20, 40, 80),
19+
C_in in (1, 2, 4),
20+
C_out in (1, 2, 4),
21+
K in (3, 6, 12),
22+
stride in (1, 2, 4),
23+
dilation in (1, 2, 4),
24+
padding in (0, 2, 4)
25+
26+
for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in (
27+
(NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"),
28+
(NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"),
29+
(NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"),
30+
(NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"),
31+
)
32+
33+
x = zeros(Float32, repeat([N], rank)..., C_in, 1)
34+
if cT == DenseConvDims
35+
w = zeros(Float32, repeat([K], rank)..., C_in, C_out)
36+
else
37+
w = zeros(Float32, repeat([K], rank)..., C_out, C_in)
38+
end
39+
cdims = try
40+
cT(x, w; stride=stride, dilation=dilation, padding=padding)
41+
catch
42+
continue
43+
end
44+
y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)
45+
46+
dx = similar(x)
47+
dw = similar(w)
48+
dy = similar(y)
49+
50+
t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims)
51+
t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims)
52+
t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims)
53+
54+
add_result(t_fwd, "conv$(rank)d", backend, cdims)
55+
add_result(t_dx, "conv$(rank)d_data", backend, cdims)
56+
add_result(t_dw, "conv$(rank)d_filter", backend, cdims)
57+
58+
@show(cdims)
59+
@save "results.jld2" results
60+
end
61+
end
62+
63+
64+
for rank in (3, 2, 1),
65+
N in (10, 20, 40, 80),
66+
K in (2, 4),
67+
stride in (1, 2, 4)
68+
69+
x = zeros(Float32, repeat([N], rank)..., 1, 1)
70+
pdims = PoolDims(x, K; stride=stride)
71+
y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1)
72+
dx = similar(x)
73+
74+
for (pool, ∇pool, name) in (
75+
(NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
76+
(NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
77+
)
78+
79+
t_fwd = @benchmark pool( $y, $x, pdims)
80+
t_data = @benchmark ∇pool($dx, $y, $x, pdims)
81+
82+
add_result(t_fwd, "$(name)$(rank)d", "direct", pdims)
83+
add_result(t_data, "$(name)$(rank)d_data", "direct", pdims)
84+
85+
@show(pdims)
86+
@save "results.jld2" results
87+
end
88+
end

0 commit comments

Comments
 (0)