Skip to content

Commit 7de2230

Browse files
Merge pull request #207 from yiyuezhuo/fix_maxpool_gradient
Fix maxpool gradient
2 parents d0cee46 + fca5b64 commit 7de2230

File tree

2 files changed

+470
-2
lines changed

2 files changed

+470
-2
lines changed

src/impl/pooling_direct.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ for name in (:max, :mean)
66
@eval function $((Symbol("$(name)pool_direct!")))(
77
y::AbstractArray{T,5}, x::AbstractArray{T,5},
88
pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T}
9+
@assert beta == T(0) "beta not supported yet"
910
check_dims(size(x), size(y), pdims)
1011

1112
width, height, depth = input_size(pdims)
@@ -176,7 +177,7 @@ for name in (:max, :mean)
176177
# If it's equal; this is the one we chose. We only choose one per
177178
# kernel window, all other elements of dx must be zero.
178179
if y_idx == x[x_idxs...] && !maxpool_already_chose
179-
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
180+
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
180181
maxpool_already_chose = true
181182
# Maxpooling does not support `beta` right now. :(
182183
#else
@@ -228,7 +229,7 @@ for name in (:max, :mean)
228229
x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)
229230
if $(name == :max)
230231
if y_idx == x[x_idxs...] && !maxpool_already_chose
231-
dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...]
232+
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
232233
maxpool_already_chose = true
233234
#else
234235
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]

0 commit comments

Comments
 (0)