Skip to content

Commit 5e054fa

Browse files
committed
added tests
1 parent f3708f9 commit 5e054fa

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

test/pooling.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ maxpool_answer_dict = Dict(
2626
3 13 18;
2727
5 15 20.
2828
],
29-
29+
3030
"dx" => [
3131
0 0 0 0;
3232
0 7 0 17;
@@ -58,7 +58,7 @@ maxpool_answer_dict = Dict(
5858
27, 28, 29, 30,
5959
32, 33, 34, 35,
6060
37, 38, 39, 40,
61-
61+
6262
47, 48, 49, 50,
6363
52, 53, 54, 55,
6464
57, 58, 59, 60.
@@ -67,18 +67,18 @@ maxpool_answer_dict = Dict(
6767
1, 3, 5,
6868
11, 13, 15,
6969
16, 18, 20,
70-
70+
7171
41, 43, 45,
7272
51, 53, 55,
7373
56, 58, 60.
7474
], (3, 3, 2)),
75-
75+
7676
"dx" => reshape([
7777
0, 0, 0, 0, 0,
7878
0, 0, 0, 0, 0,
7979
0, 0, 0, 0, 0,
8080
0, 0, 0, 0, 0,
81-
81+
8282
0, 0, 0, 0, 0,
8383
0, 27, 0, 29, 0,
8484
0, 0, 0, 0, 0,
@@ -94,12 +94,12 @@ maxpool_answer_dict = Dict(
9494
0, 0, 0, 0, 0,
9595
0, 0, 0, 0, 0,
9696
0, 0, 0, 0, 0,
97-
97+
9898
0, 0, 0, 0, 0,
9999
0, 27, 28, 29, 30,
100100
0, 32, 33, 34, 35,
101101
0, 37, 38, 39, 40,
102-
102+
103103
0, 0, 0, 0, 0,
104104
0, 47, 48, 49, 50,
105105
0, 52, 53, 54, 55,
@@ -110,12 +110,12 @@ maxpool_answer_dict = Dict(
110110
0, 0, 0, 0, 0,
111111
11, 0, 13, 0, 15,
112112
16, 0, 18, 0, 20,
113-
113+
114114
0, 0, 0, 0, 0,
115115
0, 0, 0, 0, 0,
116116
0, 0, 0, 0, 0,
117117
0, 0, 0, 0, 0,
118-
118+
119119
41, 0, 43, 0, 45,
120120
0, 0, 0, 0, 0,
121121
51, 0, 53, 0, 55,
@@ -150,7 +150,7 @@ meanpool_answer_dict = Dict(
150150
1.25 10.0 8.75
151151
2.25 12.0 9.75
152152
],
153-
153+
154154
"dx" => [
155155
1.0 1.0 3.5 3.5;
156156
1.0 1.0 3.5 3.5;
@@ -182,7 +182,7 @@ meanpool_answer_dict = Dict(
182182
14.0, 15.0, 16.0, 17.0,
183183
19.0, 20.0, 21.0, 22.0,
184184
24.0, 25.0, 26.0, 27.0,
185-
185+
186186
34.0, 35.0, 36.0, 37.0,
187187
39.0, 40.0, 41.0, 42.0,
188188
44.0, 45.0, 46.0, 47.0
@@ -191,31 +191,31 @@ meanpool_answer_dict = Dict(
191191
0.125, 0.625, 1.125,
192192
2.125, 5.0, 6.0,
193193
2.0, 4.375, 4.875,
194-
194+
195195
7.75, 16.25, 17.25,
196196
19.25, 40.0, 42.0,
197197
11.5, 23.75, 24.75,
198198
], (3, 3, 2)),
199-
199+
200200
"dx" => reshape([
201201
1.75, 1.75, 2.0, 2.0, 0.0,
202202
1.75, 1.75, 2.0, 2.0, 0.0,
203203
3.0, 3.0, 3.25, 3.25, 0.0,
204204
3.0, 3.0, 3.25, 3.25, 0.0,
205-
205+
206206
1.75, 1.75, 2.0, 2.0, 0.0,
207207
1.75, 1.75, 2.0, 2.0, 0.0,
208208
3.0, 3.0, 3.25, 3.25, 0.0,
209209
3.0, 3.0, 3.25, 3.25, 0.0,
210-
210+
211211
0.0, 0.0, 0.0, 0.0, 0.0,
212212
0.0, 0.0, 0.0, 0.0, 0.0,
213213
0.0, 0.0, 0.0, 0.0, 0.0,
214214
0.0, 0.0, 0.0, 0.0, 0.0,
215215
], (5, 4, 3)),
216216
"dx_nostride" => reshape([
217217
1.75, 3.625, 3.875, 4.125, 2.125,
218-
4.125, 8.5, 9.0, 9.5, 4.875,
218+
4.125, 8.5, 9.0, 9.5, 4.875,
219219
5.375, 11.0, 11.5, 12.0, 6.125,
220220
3.0, 6.125, 6.375, 6.625, 3.375,
221221

@@ -234,12 +234,12 @@ meanpool_answer_dict = Dict(
234234
0.265625, 0.625, 0.625, 0.75, 0.75,
235235
0.265625, 0.625, 0.625, 0.75, 0.75,
236236
0.25, 0.546875, 0.546875, 0.609375, 0.609375,
237-
237+
238238
0.96875, 2.03125, 2.03125, 2.15625, 2.15625,
239239
2.40625, 5.0, 5.0, 5.25, 5.25,
240240
2.40625, 5.0, 5.0, 5.25, 5.25,
241241
1.4375, 2.96875, 2.96875, 3.09375, 3.09375,
242-
242+
243243
0.96875, 2.03125, 2.03125, 2.15625, 2.15625,
244244
2.40625, 5.0, 5.0, 5.25, 5.25,
245245
2.40625, 5.0, 5.0, 5.25, 5.25,
@@ -296,4 +296,10 @@ for rank in (1, 2, 3)
296296
end
297297
end
298298
end
299-
end
299+
end
300+
301+
x = rand(10, 10, 3, 10)
302+
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
303+
@test size(maxpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)
304+
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
305+
@test size(meanpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)

0 commit comments

Comments
 (0)