Skip to content

Commit fd9c327

Browse files
authored
Fix typo in maxpool backprop (#134)
Fix typo in `maxpool` backprop
2 parents 5998b47 + 2d9ac54 commit fd9c327

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/impl/pooling_direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ for name in (:max, :mean)
200200
w in w_region
201201

202202
# Grab the incoming gradient at this index for future use
203-
y_idx = dy[w, h, d, c, batch_idx]
203+
y_idx = y[w, h, d, c, batch_idx]
204204
dy_idx = dy[w, h, d, c, batch_idx]
205205
maxpool_already_chose = false
206206

test/pooling.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,14 @@ x = rand(10, 10, 3, 10)
303303
@test size(maxpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)
304304
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
305305
@test size(meanpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)
306+
307+
# Add another test for 2d maxpool that uses an odd-length size:
308+
@testset "Issue #133" begin
309+
x = reshape([(1.:9.)...], 3, 3, 1, 1)
310+
pdims = PoolDims(size(x), (2,2), padding = (1,1), stride = (2,2))
311+
y = maxpool(x, pdims)
312+
313+
dy = y .* 0 .+ 1
314+
dx = ∇maxpool(dy, y, x, pdims)
315+
@test dx[:,:,1,1] == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
316+
end

0 commit comments

Comments
 (0)